From bcf11631d67299fe0a101609c2c14771f6a72ec4 Mon Sep 17 00:00:00 2001 From: MaysWind Date: Mon, 1 Sep 2025 01:16:58 +0800 Subject: [PATCH] use integers to calculate formulas for beancount amount formula --- .../beancount_amount_expression_evaluator.go | 79 +++++++++++++++---- ...ncount_amount_expression_evaluator_test.go | 29 +++++-- 2 files changed, 86 insertions(+), 22 deletions(-) diff --git a/pkg/converters/beancount/beancount_amount_expression_evaluator.go b/pkg/converters/beancount/beancount_amount_expression_evaluator.go index c7ccea3b..c642e8ab 100644 --- a/pkg/converters/beancount/beancount_amount_expression_evaluator.go +++ b/pkg/converters/beancount/beancount_amount_expression_evaluator.go @@ -1,15 +1,20 @@ package beancount import ( - "fmt" - "strconv" + "math/big" "strings" "github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/utils" ) +const maxAllowedDecimalCount = 6 +const normalizeFactor = int64(1000000) +const normalizedDecimalsMaxZeroString = "000000" +const normalizedNumberToAmountFactor = int64(10000) // 1000000 / 100 + var operatorPriority = map[rune]int{ '+': 1, '-': 1, @@ -17,6 +22,44 @@ var operatorPriority = map[rune]int{ '/': 2, } +func normalizeNumber(textualNumber string) (*big.Int, error) { + decimalSeparatorPos := strings.Index(textualNumber, ".") + + if decimalSeparatorPos < 0 { + result := big.NewInt(0) + _, ok := result.SetString(textualNumber+normalizedDecimalsMaxZeroString, 10) + + if !ok { + return nil, errs.ErrAmountInvalid + } + + return result, nil + } + + integer := utils.SubString(textualNumber, 0, decimalSeparatorPos) + decimals := utils.SubString(textualNumber, decimalSeparatorPos+1, len(textualNumber)) + + if len(decimals) > maxAllowedDecimalCount { + return nil, errs.ErrAmountInvalid + } + + paddedDecimals := utils.SubString(decimals+normalizedDecimalsMaxZeroString, 0, maxAllowedDecimalCount) + result := big.NewInt(0) + _, ok := result.SetString(integer+paddedDecimals, 10) + + if !ok { + return nil, errs.ErrAmountInvalid + } + + return result, nil +} + +func denormalizeNumberToTextualAmount(num *big.Int) string { + result := big.NewInt(0).Add(num, big.NewInt(0)) // make a copy of num + result = result.Div(result, big.NewInt(normalizedNumberToAmountFactor)) + return utils.FormatAmount(result.Int64()) +} + func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) { finalTokens := make([]string, 0) operatorStack := make([]rune, 0) @@ -117,8 +160,8 @@ func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) { return finalTokens, nil } -func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) { - stack := make([]float64, 0) +func evaluatePostfixExpr(ctx core.Context, tokens []string) (*big.Int, error) { + stack := make([]*big.Int, 0) for i := 0; i < len(tokens); i++ { token := tokens[i] @@ -127,7 +170,7 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) { case "+", "-", "*", "/": // operators if len(stack) < 2 { log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " ")) - return 0, errs.ErrInvalidAmountExpression + return nil, errs.ErrInvalidAmountExpression } // pop the top two operands @@ -138,39 +181,41 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) { stack = stack[:len(stack)-1] // evaluate the operation - var result float64 + result := big.NewInt(0) switch token { case "+": - result = a + b + result.Add(a, b) case "-": - result = a - b + result.Sub(a, b) case "*": - result = a * b + result.Mul(a, b) + result.Div(result, big.NewInt(normalizeFactor)) case "/": - if b == 0 { + if b.Int64() == 0 { log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " ")) - return 0, errs.ErrInvalidAmountExpression + return nil, errs.ErrInvalidAmountExpression } - result = a / b + result.Mul(a, big.NewInt(normalizeFactor)) + result.Div(result, b) } // push the result back to the stack stack = append(stack, result) default: // operands - num, err := strconv.ParseFloat(token, 64) + normalizedNum, err := normalizeNumber(token) if err != nil { log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " ")) - return 0, errs.ErrInvalidAmountExpression + return nil, errs.ErrInvalidAmountExpression } - stack = append(stack, num) + stack = append(stack, normalizedNum) } } if len(stack) != 1 { log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " ")) - return 0, errs.ErrInvalidAmountExpression + return nil, errs.ErrInvalidAmountExpression } return stack[0], nil @@ -193,5 +238,5 @@ func evaluateBeancountAmountExpression(ctx core.Context, expr string) (string, e return "", err } - return fmt.Sprintf("%.2f", result), nil + return denormalizeNumberToTextualAmount(result), nil } diff --git a/pkg/converters/beancount/beancount_amount_expression_evaluator_test.go b/pkg/converters/beancount/beancount_amount_expression_evaluator_test.go index c49acaea..806041f3 100644 --- a/pkg/converters/beancount/beancount_amount_expression_evaluator_test.go +++ b/pkg/converters/beancount/beancount_amount_expression_evaluator_test.go @@ -1,6 +1,7 @@ package beancount import ( + "math/big" "testing" "github.com/stretchr/testify/assert" @@ -97,23 +98,23 @@ func TestEvaluatePostfixExpr_ValidExpression(t *testing.T) { result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"}) assert.Nil(t, err) - assert.Equal(t, float64(3), result) + assert.Equal(t, big.NewInt(3000000), result) result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"}) assert.Nil(t, err) - assert.Equal(t, float64(2), result) + assert.Equal(t, big.NewInt(2000000), result) result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"}) assert.Nil(t, err) - assert.Equal(t, float64(12), result) + assert.Equal(t, big.NewInt(12000000), result) result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"}) assert.Nil(t, err) - assert.Equal(t, float64(3), result) + assert.Equal(t, big.NewInt(3000000), result) result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"}) assert.Nil(t, err) - assert.Equal(t, float64(5), result) + assert.Equal(t, big.NewInt(5000000), result) } func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) { @@ -179,6 +180,18 @@ func TestEvaluateBeancountAmountExpression_ValidExpression(t *testing.T) { result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))") assert.Nil(t, err) assert.Equal(t, "10.00", result) + + result, err = evaluateBeancountAmountExpression(context, "3.5+0.1") + assert.Nil(t, err) + assert.Equal(t, "3.60", result) + + result, err = evaluateBeancountAmountExpression(context, "3.55+0.11") + assert.Nil(t, err) + assert.Equal(t, "3.66", result) + + result, err = evaluateBeancountAmountExpression(context, "3.555+0.111") + assert.Nil(t, err) + assert.Equal(t, "3.66", result) } func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) { @@ -213,4 +226,10 @@ func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) { _, err = evaluateBeancountAmountExpression(context, "1)*(2") assert.Equal(t, errs.ErrInvalidAmountExpression, err) + + _, err = evaluateBeancountAmountExpression(context, "0.abcd+1") + assert.Equal(t, errs.ErrInvalidAmountExpression, err) + + _, err = evaluateBeancountAmountExpression(context, "0.1234567+1") + assert.Equal(t, errs.ErrInvalidAmountExpression, err) }