use integers to calculate formulas for beancount amount formula

This commit is contained in:
MaysWind
2025-09-01 01:16:58 +08:00
parent 989183c8be
commit bcf11631d6
2 changed files with 86 additions and 22 deletions
@@ -1,15 +1,20 @@
package beancount package beancount
import ( import (
"fmt" "math/big"
"strconv"
"strings" "strings"
"github.com/mayswind/ezbookkeeping/pkg/core" "github.com/mayswind/ezbookkeeping/pkg/core"
"github.com/mayswind/ezbookkeeping/pkg/errs" "github.com/mayswind/ezbookkeeping/pkg/errs"
"github.com/mayswind/ezbookkeeping/pkg/log" "github.com/mayswind/ezbookkeeping/pkg/log"
"github.com/mayswind/ezbookkeeping/pkg/utils"
) )
const maxAllowedDecimalCount = 6
const normalizeFactor = int64(1000000)
const normalizedDecimalsMaxZeroString = "000000"
const normalizedNumberToAmountFactor = int64(10000) // 1000000 / 100
var operatorPriority = map[rune]int{ var operatorPriority = map[rune]int{
'+': 1, '+': 1,
'-': 1, '-': 1,
@@ -17,6 +22,44 @@ var operatorPriority = map[rune]int{
'/': 2, '/': 2,
} }
func normalizeNumber(textualNumber string) (*big.Int, error) {
decimalSeparatorPos := strings.Index(textualNumber, ".")
if decimalSeparatorPos < 0 {
result := big.NewInt(0)
_, ok := result.SetString(textualNumber+normalizedDecimalsMaxZeroString, 10)
if !ok {
return nil, errs.ErrAmountInvalid
}
return result, nil
}
integer := utils.SubString(textualNumber, 0, decimalSeparatorPos)
decimals := utils.SubString(textualNumber, decimalSeparatorPos+1, len(textualNumber))
if len(decimals) > maxAllowedDecimalCount {
return nil, errs.ErrAmountInvalid
}
paddedDecimals := utils.SubString(decimals+normalizedDecimalsMaxZeroString, 0, maxAllowedDecimalCount)
result := big.NewInt(0)
_, ok := result.SetString(integer+paddedDecimals, 10)
if !ok {
return nil, errs.ErrAmountInvalid
}
return result, nil
}
func denormalizeNumberToTextualAmount(num *big.Int) string {
result := big.NewInt(0).Add(num, big.NewInt(0)) // make a copy of num
result = result.Div(result, big.NewInt(normalizedNumberToAmountFactor))
return utils.FormatAmount(result.Int64())
}
func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) { func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
finalTokens := make([]string, 0) finalTokens := make([]string, 0)
operatorStack := make([]rune, 0) operatorStack := make([]rune, 0)
@@ -117,8 +160,8 @@ func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
return finalTokens, nil return finalTokens, nil
} }
func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) { func evaluatePostfixExpr(ctx core.Context, tokens []string) (*big.Int, error) {
stack := make([]float64, 0) stack := make([]*big.Int, 0)
for i := 0; i < len(tokens); i++ { for i := 0; i < len(tokens); i++ {
token := tokens[i] token := tokens[i]
@@ -127,7 +170,7 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
case "+", "-", "*", "/": // operators case "+", "-", "*", "/": // operators
if len(stack) < 2 { if len(stack) < 2 {
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " ")) log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " "))
return 0, errs.ErrInvalidAmountExpression return nil, errs.ErrInvalidAmountExpression
} }
// pop the top two operands // pop the top two operands
@@ -138,39 +181,41 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
stack = stack[:len(stack)-1] stack = stack[:len(stack)-1]
// evaluate the operation // evaluate the operation
var result float64 result := big.NewInt(0)
switch token { switch token {
case "+": case "+":
result = a + b result.Add(a, b)
case "-": case "-":
result = a - b result.Sub(a, b)
case "*": case "*":
result = a * b result.Mul(a, b)
result.Div(result, big.NewInt(normalizeFactor))
case "/": case "/":
if b == 0 { if b.Int64() == 0 {
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " ")) log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " "))
return 0, errs.ErrInvalidAmountExpression return nil, errs.ErrInvalidAmountExpression
} }
result = a / b result.Mul(a, big.NewInt(normalizeFactor))
result.Div(result, b)
} }
// push the result back to the stack // push the result back to the stack
stack = append(stack, result) stack = append(stack, result)
default: // operands default: // operands
num, err := strconv.ParseFloat(token, 64) normalizedNum, err := normalizeNumber(token)
if err != nil { if err != nil {
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " ")) log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " "))
return 0, errs.ErrInvalidAmountExpression return nil, errs.ErrInvalidAmountExpression
} }
stack = append(stack, num) stack = append(stack, normalizedNum)
} }
} }
if len(stack) != 1 { if len(stack) != 1 {
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " ")) log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " "))
return 0, errs.ErrInvalidAmountExpression return nil, errs.ErrInvalidAmountExpression
} }
return stack[0], nil return stack[0], nil
@@ -193,5 +238,5 @@ func evaluateBeancountAmountExpression(ctx core.Context, expr string) (string, e
return "", err return "", err
} }
return fmt.Sprintf("%.2f", result), nil return denormalizeNumberToTextualAmount(result), nil
} }
@@ -1,6 +1,7 @@
package beancount package beancount
import ( import (
"math/big"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -97,23 +98,23 @@ func TestEvaluatePostfixExpr_ValidExpression(t *testing.T) {
result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"}) result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, float64(3), result) assert.Equal(t, big.NewInt(3000000), result)
result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"}) result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, float64(2), result) assert.Equal(t, big.NewInt(2000000), result)
result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"}) result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, float64(12), result) assert.Equal(t, big.NewInt(12000000), result)
result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"}) result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, float64(3), result) assert.Equal(t, big.NewInt(3000000), result)
result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"}) result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, float64(5), result) assert.Equal(t, big.NewInt(5000000), result)
} }
func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) { func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) {
@@ -179,6 +180,18 @@ func TestEvaluateBeancountAmountExpression_ValidExpression(t *testing.T) {
result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))") result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))")
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "10.00", result) assert.Equal(t, "10.00", result)
result, err = evaluateBeancountAmountExpression(context, "3.5+0.1")
assert.Nil(t, err)
assert.Equal(t, "3.60", result)
result, err = evaluateBeancountAmountExpression(context, "3.55+0.11")
assert.Nil(t, err)
assert.Equal(t, "3.66", result)
result, err = evaluateBeancountAmountExpression(context, "3.555+0.111")
assert.Nil(t, err)
assert.Equal(t, "3.66", result)
} }
func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) { func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
@@ -213,4 +226,10 @@ func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
_, err = evaluateBeancountAmountExpression(context, "1)*(2") _, err = evaluateBeancountAmountExpression(context, "1)*(2")
assert.Equal(t, errs.ErrInvalidAmountExpression, err) assert.Equal(t, errs.ErrInvalidAmountExpression, err)
_, err = evaluateBeancountAmountExpression(context, "0.abcd+1")
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
_, err = evaluateBeancountAmountExpression(context, "0.1234567+1")
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
} }