243 lines
6.6 KiB
Go
243 lines
6.6 KiB
Go
package beancount
|
|
|
|
import (
|
|
"math/big"
|
|
"strings"
|
|
|
|
"github.com/mayswind/ezbookkeeping/pkg/core"
|
|
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
|
"github.com/mayswind/ezbookkeeping/pkg/log"
|
|
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
|
)
|
|
|
|
const maxAllowedDecimalCount = 6
|
|
const normalizeFactor = int64(1000000)
|
|
const normalizedDecimalsMaxZeroString = "000000"
|
|
const normalizedNumberToAmountFactor = int64(10000) // 1000000 / 100
|
|
|
|
var operatorPriority = map[rune]int{
|
|
'+': 1,
|
|
'-': 1,
|
|
'*': 2,
|
|
'/': 2,
|
|
}
|
|
|
|
func normalizeNumber(textualNumber string) (*big.Int, error) {
|
|
decimalSeparatorPos := strings.Index(textualNumber, ".")
|
|
|
|
if decimalSeparatorPos < 0 {
|
|
result := big.NewInt(0)
|
|
_, ok := result.SetString(textualNumber+normalizedDecimalsMaxZeroString, 10)
|
|
|
|
if !ok {
|
|
return nil, errs.ErrAmountInvalid
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
integer := utils.SubString(textualNumber, 0, decimalSeparatorPos)
|
|
decimals := utils.SubString(textualNumber, decimalSeparatorPos+1, len(textualNumber))
|
|
|
|
if len(decimals) > maxAllowedDecimalCount {
|
|
return nil, errs.ErrAmountInvalid
|
|
}
|
|
|
|
paddedDecimals := utils.SubString(decimals+normalizedDecimalsMaxZeroString, 0, maxAllowedDecimalCount)
|
|
result := big.NewInt(0)
|
|
_, ok := result.SetString(integer+paddedDecimals, 10)
|
|
|
|
if !ok {
|
|
return nil, errs.ErrAmountInvalid
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func denormalizeNumberToTextualAmount(num *big.Int) string {
|
|
result := big.NewInt(0).Add(num, big.NewInt(0)) // make a copy of num
|
|
result = result.Div(result, big.NewInt(normalizedNumberToAmountFactor))
|
|
return utils.FormatAmount(result.Int64())
|
|
}
|
|
|
|
func toPostfixExprTokens(ctx core.Context, expr string) ([]string, error) {
|
|
finalTokens := make([]string, 0)
|
|
operatorStack := make([]rune, 0)
|
|
currentNumberBuilder := strings.Builder{}
|
|
isLastTokenOperator := true
|
|
|
|
expr = strings.ReplaceAll(expr, " ", "")
|
|
|
|
for i := 0; i < len(expr); i++ {
|
|
ch := rune(expr[i])
|
|
|
|
// number
|
|
if '0' <= ch && ch <= '9' || ch == '.' {
|
|
currentNumberBuilder.WriteRune(ch)
|
|
continue
|
|
} else if ch == '-' && i+1 < len(expr) && '0' <= expr[i+1] && expr[i+1] <= '9' && currentNumberBuilder.Len() == 0 && isLastTokenOperator {
|
|
currentNumberBuilder.WriteRune(ch)
|
|
continue
|
|
}
|
|
|
|
// operator or parenthesis
|
|
if currentNumberBuilder.Len() > 0 {
|
|
finalTokens = append(finalTokens, currentNumberBuilder.String())
|
|
currentNumberBuilder.Reset()
|
|
isLastTokenOperator = false
|
|
}
|
|
|
|
switch ch {
|
|
case '+', '-', '*', '/':
|
|
if ch == '-' && isLastTokenOperator {
|
|
currentNumberBuilder.WriteRune(ch)
|
|
continue
|
|
}
|
|
|
|
for len(operatorStack) > 0 {
|
|
topOperator := operatorStack[len(operatorStack)-1]
|
|
|
|
if topOperator == '(' {
|
|
break
|
|
}
|
|
|
|
if operatorPriority[topOperator] >= operatorPriority[ch] {
|
|
finalTokens = append(finalTokens, string(topOperator))
|
|
operatorStack = operatorStack[:len(operatorStack)-1]
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
|
|
operatorStack = append(operatorStack, ch)
|
|
isLastTokenOperator = true
|
|
case '(':
|
|
operatorStack = append(operatorStack, ch)
|
|
isLastTokenOperator = true
|
|
case ')':
|
|
hasLeftParenthesis := false
|
|
|
|
for len(operatorStack) > 0 {
|
|
topOperator := operatorStack[len(operatorStack)-1]
|
|
operatorStack = operatorStack[:len(operatorStack)-1]
|
|
|
|
if topOperator == '(' {
|
|
hasLeftParenthesis = true
|
|
break
|
|
}
|
|
|
|
finalTokens = append(finalTokens, string(topOperator))
|
|
}
|
|
|
|
if !hasLeftParenthesis {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.toPostfixExprTokens] cannot parse expression \"%s\", because missing left parenthesis", expr)
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
|
|
isLastTokenOperator = false
|
|
default:
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.toPostfixExprTokens] cannot parse expression \"%s\", because containing unknown token \"%c\"", expr, ch)
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
}
|
|
|
|
if currentNumberBuilder.Len() > 0 {
|
|
finalTokens = append(finalTokens, currentNumberBuilder.String())
|
|
}
|
|
|
|
for len(operatorStack) > 0 {
|
|
topOperator := operatorStack[len(operatorStack)-1]
|
|
operatorStack = operatorStack[:len(operatorStack)-1]
|
|
|
|
if topOperator == '(' {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.toPostfixExprTokens] cannot parse expression \"%s\", because missing right parenthesis", expr)
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
|
|
finalTokens = append(finalTokens, string(topOperator))
|
|
}
|
|
|
|
return finalTokens, nil
|
|
}
|
|
|
|
func evaluatePostfixExpr(ctx core.Context, tokens []string) (*big.Int, error) {
|
|
stack := make([]*big.Int, 0)
|
|
|
|
for i := 0; i < len(tokens); i++ {
|
|
token := tokens[i]
|
|
|
|
switch token {
|
|
case "+", "-", "*", "/": // operators
|
|
if len(stack) < 2 {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because not enough operands", strings.Join(tokens, " "))
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
|
|
// pop the top two operands
|
|
b := stack[len(stack)-1]
|
|
stack = stack[:len(stack)-1]
|
|
|
|
a := stack[len(stack)-1]
|
|
stack = stack[:len(stack)-1]
|
|
|
|
// evaluate the operation
|
|
result := big.NewInt(0)
|
|
switch token {
|
|
case "+":
|
|
result.Add(a, b)
|
|
case "-":
|
|
result.Sub(a, b)
|
|
case "*":
|
|
result.Mul(a, b)
|
|
result.Div(result, big.NewInt(normalizeFactor))
|
|
case "/":
|
|
if b.Int64() == 0 {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because division by zero", strings.Join(tokens, " "))
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
result.Mul(a, big.NewInt(normalizeFactor))
|
|
result.Div(result, b)
|
|
}
|
|
|
|
// push the result back to the stack
|
|
stack = append(stack, result)
|
|
default: // operands
|
|
normalizedNum, err := normalizeNumber(token)
|
|
|
|
if err != nil {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because containing invalid number", strings.Join(tokens, " "))
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
|
|
stack = append(stack, normalizedNum)
|
|
}
|
|
}
|
|
|
|
if len(stack) != 1 {
|
|
log.Warnf(ctx, "[beancount_amount_expression_evaluator.evaluatePostfixExpr] cannot evaluate expression \"%s\", because missing operator", strings.Join(tokens, " "))
|
|
return nil, errs.ErrInvalidAmountExpression
|
|
}
|
|
|
|
return stack[0], nil
|
|
}
|
|
|
|
func evaluateBeancountAmountExpression(ctx core.Context, expr string) (string, error) {
|
|
if expr == "" {
|
|
return "", nil
|
|
}
|
|
|
|
postfixExprTokens, err := toPostfixExprTokens(ctx, expr)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
result, err := evaluatePostfixExpr(ctx, postfixExprTokens)
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return denormalizeNumberToTextualAmount(result), nil
|
|
}
|