use integers to calculate formulas for beancount amount formula
This commit is contained in:
@@ -1,15 +1,20 @@
|
|||||||
package beancount
|
package beancount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"math/big"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||||
|
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxAllowedDecimalCount = 6
|
||||||
|
const normalizeFactor = int64(1000000)
|
||||||
|
const normalizedDecimalsMaxZeroString = "000000"
|
||||||
|
const normalizedNumberToAmountFactor = int64(10000) // 1000000 / 100
|
||||||
|
|
||||||
var operatorPriority = map[rune]int{
|
var operatorPriority = map[rune]int{
|
||||||
'+': 1,
|
'+': 1,
|
||||||
'-': 1,
|
'-': 1,
|
||||||
@@ -17,6 +22,44 @@ var operatorPriority = map[rune]int{
|
|||||||
'/': 2,
|
'/': 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func normalizeNumber(textualNumber string) (*big.Int, error) {
|
||||||
|
decimalSeparatorPos := strings.Index(textualNumber, ".")
|
||||||
|
|
||||||
|
if decimalSeparatorPos < 0 {
|
||||||
|
result := big.NewInt(0)
|
||||||
|
_, ok := result.SetString(textualNumber+normalizedDecimalsMaxZeroString, 10)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, errs.ErrAmountInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
integer := utils.SubString(textualNumber, 0, decimalSeparatorPos)
|
||||||
|
decimals := utils.SubString(textualNumber, decimalSeparatorPos+1, len(textualNumber))
|
||||||
|
|
||||||
|
if len(decimals) > maxAllowedDecimalCount {
|
||||||
|
return nil, errs.ErrAmountInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
paddedDecimals := utils.SubString(decimals+normalizedDecimalsMaxZeroString, 0, maxAllowedDecimalCount)
|
||||||
|
result := big.NewInt(0)
|
||||||
|
_, ok := result.SetString(integer+paddedDecimals, 10)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, errs.ErrAmountInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func denormalizeNumberToTextualAmount(num *big.Int) string {
|
||||||
|
result := big.NewInt(0).Add(num, big.NewInt(0)) // make a copy of num
|
||||||
|
result = result.Div(result, big.NewInt(normalizedNumberToAmountFactor))
|
||||||
|
return utils.FormatAmount(result.Int64())
|
||||||
|
}
|
||||||
|
|
||||||
func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
||||||
finalTokens := make([]string, 0)
|
finalTokens := make([]string, 0)
|
||||||
operatorStack := make([]rune, 0)
|
operatorStack := make([]rune, 0)
|
||||||
@@ -117,8 +160,8 @@ func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
|||||||
return finalTokens, nil
|
return finalTokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
func evaluatePostfixExpr(ctx core.Context, tokens []string) (*big.Int, error) {
|
||||||
stack := make([]float64, 0)
|
stack := make([]*big.Int, 0)
|
||||||
|
|
||||||
for i := 0; i < len(tokens); i++ {
|
for i := 0; i < len(tokens); i++ {
|
||||||
token := tokens[i]
|
token := tokens[i]
|
||||||
@@ -127,7 +170,7 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
|||||||
case "+", "-", "*", "/": // operators
|
case "+", "-", "*", "/": // operators
|
||||||
if len(stack) < 2 {
|
if len(stack) < 2 {
|
||||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " "))
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " "))
|
||||||
return 0, errs.ErrInvalidAmountExpression
|
return nil, errs.ErrInvalidAmountExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
// pop the top two operands
|
// pop the top two operands
|
||||||
@@ -138,39 +181,41 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
|||||||
stack = stack[:len(stack)-1]
|
stack = stack[:len(stack)-1]
|
||||||
|
|
||||||
// evaluate the operation
|
// evaluate the operation
|
||||||
var result float64
|
result := big.NewInt(0)
|
||||||
switch token {
|
switch token {
|
||||||
case "+":
|
case "+":
|
||||||
result = a + b
|
result.Add(a, b)
|
||||||
case "-":
|
case "-":
|
||||||
result = a - b
|
result.Sub(a, b)
|
||||||
case "*":
|
case "*":
|
||||||
result = a * b
|
result.Mul(a, b)
|
||||||
|
result.Div(result, big.NewInt(normalizeFactor))
|
||||||
case "/":
|
case "/":
|
||||||
if b == 0 {
|
if b.Int64() == 0 {
|
||||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " "))
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " "))
|
||||||
return 0, errs.ErrInvalidAmountExpression
|
return nil, errs.ErrInvalidAmountExpression
|
||||||
}
|
}
|
||||||
result = a / b
|
result.Mul(a, big.NewInt(normalizeFactor))
|
||||||
|
result.Div(result, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// push the result back to the stack
|
// push the result back to the stack
|
||||||
stack = append(stack, result)
|
stack = append(stack, result)
|
||||||
default: // operands
|
default: // operands
|
||||||
num, err := strconv.ParseFloat(token, 64)
|
normalizedNum, err := normalizeNumber(token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " "))
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " "))
|
||||||
return 0, errs.ErrInvalidAmountExpression
|
return nil, errs.ErrInvalidAmountExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
stack = append(stack, num)
|
stack = append(stack, normalizedNum)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(stack) != 1 {
|
if len(stack) != 1 {
|
||||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " "))
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " "))
|
||||||
return 0, errs.ErrInvalidAmountExpression
|
return nil, errs.ErrInvalidAmountExpression
|
||||||
}
|
}
|
||||||
|
|
||||||
return stack[0], nil
|
return stack[0], nil
|
||||||
@@ -193,5 +238,5 @@ func evaluateBeancountAmountExpression(ctx core.Context, expr string) (string, e
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%.2f", result), nil
|
return denormalizeNumberToTextualAmount(result), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package beancount
|
package beancount
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"math/big"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -97,23 +98,23 @@ func TestEvaluatePostfixExpr_ValidExpression(t *testing.T) {
|
|||||||
|
|
||||||
result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"})
|
result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, float64(3), result)
|
assert.Equal(t, big.NewInt(3000000), result)
|
||||||
|
|
||||||
result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"})
|
result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, float64(2), result)
|
assert.Equal(t, big.NewInt(2000000), result)
|
||||||
|
|
||||||
result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"})
|
result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, float64(12), result)
|
assert.Equal(t, big.NewInt(12000000), result)
|
||||||
|
|
||||||
result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"})
|
result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, float64(3), result)
|
assert.Equal(t, big.NewInt(3000000), result)
|
||||||
|
|
||||||
result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"})
|
result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"})
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, float64(5), result)
|
assert.Equal(t, big.NewInt(5000000), result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) {
|
func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) {
|
||||||
@@ -179,6 +180,18 @@ func TestEvaluateBeancountAmountExpression_ValidExpression(t *testing.T) {
|
|||||||
result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))")
|
result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))")
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "10.00", result)
|
assert.Equal(t, "10.00", result)
|
||||||
|
|
||||||
|
result, err = evaluateBeancountAmountExpression(context, "3.5+0.1")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "3.60", result)
|
||||||
|
|
||||||
|
result, err = evaluateBeancountAmountExpression(context, "3.55+0.11")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "3.66", result)
|
||||||
|
|
||||||
|
result, err = evaluateBeancountAmountExpression(context, "3.555+0.111")
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "3.66", result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
|
func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
|
||||||
@@ -213,4 +226,10 @@ func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
|
|||||||
|
|
||||||
_, err = evaluateBeancountAmountExpression(context, "1)*(2")
|
_, err = evaluateBeancountAmountExpression(context, "1)*(2")
|
||||||
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||||
|
|
||||||
|
_, err = evaluateBeancountAmountExpression(context, "0.abcd+1")
|
||||||
|
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||||
|
|
||||||
|
_, err = evaluateBeancountAmountExpression(context, "0.1234567+1")
|
||||||
|
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user