use integers to calculate formulas for beancount amount formula
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
package beancount
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
||||
)
|
||||
|
||||
const maxAllowedDecimalCount = 6
|
||||
const normalizeFactor = int64(1000000)
|
||||
const normalizedDecimalsMaxZeroString = "000000"
|
||||
const normalizedNumberToAmountFactor = int64(10000) // 1000000 / 100
|
||||
|
||||
var operatorPriority = map[rune]int{
|
||||
'+': 1,
|
||||
'-': 1,
|
||||
@@ -17,6 +22,44 @@ var operatorPriority = map[rune]int{
|
||||
'/': 2,
|
||||
}
|
||||
|
||||
func normalizeNumber(textualNumber string) (*big.Int, error) {
|
||||
decimalSeparatorPos := strings.Index(textualNumber, ".")
|
||||
|
||||
if decimalSeparatorPos < 0 {
|
||||
result := big.NewInt(0)
|
||||
_, ok := result.SetString(textualNumber+normalizedDecimalsMaxZeroString, 10)
|
||||
|
||||
if !ok {
|
||||
return nil, errs.ErrAmountInvalid
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
integer := utils.SubString(textualNumber, 0, decimalSeparatorPos)
|
||||
decimals := utils.SubString(textualNumber, decimalSeparatorPos+1, len(textualNumber))
|
||||
|
||||
if len(decimals) > maxAllowedDecimalCount {
|
||||
return nil, errs.ErrAmountInvalid
|
||||
}
|
||||
|
||||
paddedDecimals := utils.SubString(decimals+normalizedDecimalsMaxZeroString, 0, maxAllowedDecimalCount)
|
||||
result := big.NewInt(0)
|
||||
_, ok := result.SetString(integer+paddedDecimals, 10)
|
||||
|
||||
if !ok {
|
||||
return nil, errs.ErrAmountInvalid
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func denormalizeNumberToTextualAmount(num *big.Int) string {
|
||||
result := big.NewInt(0).Add(num, big.NewInt(0)) // make a copy of num
|
||||
result = result.Div(result, big.NewInt(normalizedNumberToAmountFactor))
|
||||
return utils.FormatAmount(result.Int64())
|
||||
}
|
||||
|
||||
func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
||||
finalTokens := make([]string, 0)
|
||||
operatorStack := make([]rune, 0)
|
||||
@@ -117,8 +160,8 @@ func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
||||
return finalTokens, nil
|
||||
}
|
||||
|
||||
func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
||||
stack := make([]float64, 0)
|
||||
func evaluatePostfixExpr(ctx core.Context, tokens []string) (*big.Int, error) {
|
||||
stack := make([]*big.Int, 0)
|
||||
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
token := tokens[i]
|
||||
@@ -127,7 +170,7 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
||||
case "+", "-", "*", "/": // operators
|
||||
if len(stack) < 2 {
|
||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " "))
|
||||
return 0, errs.ErrInvalidAmountExpression
|
||||
return nil, errs.ErrInvalidAmountExpression
|
||||
}
|
||||
|
||||
// pop the top two operands
|
||||
@@ -138,39 +181,41 @@ func evaluatePostfixExpr(ctx core.Context, tokens []string) (float64, error) {
|
||||
stack = stack[:len(stack)-1]
|
||||
|
||||
// evaluate the operation
|
||||
var result float64
|
||||
result := big.NewInt(0)
|
||||
switch token {
|
||||
case "+":
|
||||
result = a + b
|
||||
result.Add(a, b)
|
||||
case "-":
|
||||
result = a - b
|
||||
result.Sub(a, b)
|
||||
case "*":
|
||||
result = a * b
|
||||
result.Mul(a, b)
|
||||
result.Div(result, big.NewInt(normalizeFactor))
|
||||
case "/":
|
||||
if b == 0 {
|
||||
if b.Int64() == 0 {
|
||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " "))
|
||||
return 0, errs.ErrInvalidAmountExpression
|
||||
return nil, errs.ErrInvalidAmountExpression
|
||||
}
|
||||
result = a / b
|
||||
result.Mul(a, big.NewInt(normalizeFactor))
|
||||
result.Div(result, b)
|
||||
}
|
||||
|
||||
// push the result back to the stack
|
||||
stack = append(stack, result)
|
||||
default: // operands
|
||||
num, err := strconv.ParseFloat(token, 64)
|
||||
normalizedNum, err := normalizeNumber(token)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " "))
|
||||
return 0, errs.ErrInvalidAmountExpression
|
||||
return nil, errs.ErrInvalidAmountExpression
|
||||
}
|
||||
|
||||
stack = append(stack, num)
|
||||
stack = append(stack, normalizedNum)
|
||||
}
|
||||
}
|
||||
|
||||
if len(stack) != 1 {
|
||||
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " "))
|
||||
return 0, errs.ErrInvalidAmountExpression
|
||||
return nil, errs.ErrInvalidAmountExpression
|
||||
}
|
||||
|
||||
return stack[0], nil
|
||||
@@ -193,5 +238,5 @@ func evaluateBeancountAmountExpression(ctx core.Context, expr string) (string, e
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.2f", result), nil
|
||||
return denormalizeNumberToTextualAmount(result), nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package beancount
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -97,23 +98,23 @@ func TestEvaluatePostfixExpr_ValidExpression(t *testing.T) {
|
||||
|
||||
result, err := evaluatePostfixExpr(context, []string{"1", "2", "+"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, float64(3), result)
|
||||
assert.Equal(t, big.NewInt(3000000), result)
|
||||
|
||||
result, err = evaluatePostfixExpr(context, []string{"5", "3", "-"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, float64(2), result)
|
||||
assert.Equal(t, big.NewInt(2000000), result)
|
||||
|
||||
result, err = evaluatePostfixExpr(context, []string{"4", "3", "*"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, float64(12), result)
|
||||
assert.Equal(t, big.NewInt(12000000), result)
|
||||
|
||||
result, err = evaluatePostfixExpr(context, []string{"6", "2", "/"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, float64(3), result)
|
||||
assert.Equal(t, big.NewInt(3000000), result)
|
||||
|
||||
result, err = evaluatePostfixExpr(context, []string{"1", "2", "3", "*", "+", "4", "2", "/", "-"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, float64(5), result)
|
||||
assert.Equal(t, big.NewInt(5000000), result)
|
||||
}
|
||||
|
||||
func TestEvaluatePostfixExpr_InvalidExpression(t *testing.T) {
|
||||
@@ -179,6 +180,18 @@ func TestEvaluateBeancountAmountExpression_ValidExpression(t *testing.T) {
|
||||
result, err = evaluateBeancountAmountExpression(context, "(((2+3)))*(((((-5+7)))))")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "10.00", result)
|
||||
|
||||
result, err = evaluateBeancountAmountExpression(context, "3.5+0.1")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "3.60", result)
|
||||
|
||||
result, err = evaluateBeancountAmountExpression(context, "3.55+0.11")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "3.66", result)
|
||||
|
||||
result, err = evaluateBeancountAmountExpression(context, "3.555+0.111")
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "3.66", result)
|
||||
}
|
||||
|
||||
func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
|
||||
@@ -213,4 +226,10 @@ func TestEvaluateBeancountAmountExpression_InvalidExpression(t *testing.T) {
|
||||
|
||||
_, err = evaluateBeancountAmountExpression(context, "1)*(2")
|
||||
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||
|
||||
_, err = evaluateBeancountAmountExpression(context, "0.abcd+1")
|
||||
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||
|
||||
_, err = evaluateBeancountAmountExpression(context, "0.1234567+1")
|
||||
assert.Equal(t, errs.ErrInvalidAmountExpression, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user