diff --git a/pkg/api/transactions.go b/pkg/api/transactions.go index 282bac3b..dd3968ae 100644 --- a/pkg/api/transactions.go +++ b/pkg/api/transactions.go @@ -161,6 +161,13 @@ func (a *TransactionsApi) TransactionCreateHandler(c *core.Context) (interface{} return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) } + tagIds, err := utils.StringArrayToInt64Array(transactionCreateReq.TagIds) + + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionCreateHandler] parse tag ids failed, because %s", err.Error()) + return nil, errs.ErrTransactionTagIdInvalid + } + if transactionCreateReq.Type < models.TRANSACTION_TYPE_MODIFY_BALANCE || transactionCreateReq.Type > models.TRANSACTION_TYPE_TRANSFER { log.WarnfWithRequestId(c, "[transactions.TransactionCreateHandler] transaction type is invalid") return nil, errs.ErrTransactionTypeInvalid @@ -187,7 +194,7 @@ func (a *TransactionsApi) TransactionCreateHandler(c *core.Context) (interface{} uid := c.GetCurrentUid() transaction := a.createNewTransactionModel(uid, &transactionCreateReq) - err = a.transactions.CreateTransaction(transaction, transactionCreateReq.TagIds) + err = a.transactions.CreateTransaction(transaction, tagIds) if err != nil { log.ErrorfWithRequestId(c, "[transactions.TransactionCreateHandler] failed to create transaction \"id:%d\" for user \"uid:%d\", because %s", transaction.TransactionId, uid, err.Error()) @@ -211,6 +218,13 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} return nil, errs.NewIncompleteOrIncorrectSubmissionError(err) } + tagIds, err := utils.StringArrayToInt64Array(transactionModifyReq.TagIds) + + if err != nil { + log.WarnfWithRequestId(c, "[transactions.TransactionModifyHandler] parse tag ids failed, because %s", err.Error()) + return nil, errs.ErrTransactionTagIdInvalid + } + uid := c.GetCurrentUid() transaction, err := a.transactions.GetTransactionByTransactionId(uid, transactionModifyReq.Id) @@ -227,8 +241,8 @@ func (a *TransactionsApi) TransactionModifyHandler(c *core.Context) (interface{} } transactionTagIds := allTransactionTagIds[transaction.TransactionId] - addTransactionTagIds := utils.Int64SliceMinus(transactionModifyReq.TagIds, transactionTagIds) - removeTransactionTagIds := utils.Int64SliceMinus(transactionTagIds, transactionModifyReq.TagIds) + addTransactionTagIds := utils.Int64SliceMinus(tagIds, transactionTagIds) + removeTransactionTagIds := utils.Int64SliceMinus(transactionTagIds, tagIds) newTransaction := &models.Transaction{ TransactionId: transaction.TransactionId, diff --git a/pkg/models/transaction.go b/pkg/models/transaction.go index 93534ae6..039df7f8 100644 --- a/pkg/models/transaction.go +++ b/pkg/models/transaction.go @@ -40,21 +40,21 @@ type TransactionCreateRequest struct { DestinationAccountId int64 `json:"destinationAccountId,string" binding:"required,min=1"` SourceAmount int64 `json:"sourceAmount" binding:"min=-99999999999,max=99999999999"` DestinationAmount int64 `json:"destinationAmount" binding:"min=-99999999999,max=99999999999"` - TagIds []int64 `json:"tagIds,string"` + TagIds []string `json:"tagIds"` Comment string `json:"comment" binding:"max=255"` } // TransactionModifyRequest represents all parameters of transaction modification request type TransactionModifyRequest struct { - Id int64 `json:"id,string" binding:"required,min=1"` - CategoryId int64 `json:"categoryId,string"` - Time int64 `json:"time" binding:"required,min=1"` - SourceAccountId int64 `json:"sourceAccountId,string" binding:"required,min=1"` - DestinationAccountId int64 `json:"destinationAccountId,string" binding:"required,min=1"` - SourceAmount int64 `json:"sourceAmount" binding:"min=-99999999999,max=99999999999"` - DestinationAmount int64 `json:"destinationAmount" binding:"min=-99999999999,max=99999999999"` - TagIds []int64 `json:"tagIds,string"` - Comment string `json:"comment" binding:"max=255"` + Id int64 `json:"id,string" binding:"required,min=1"` + CategoryId int64 `json:"categoryId,string"` + Time int64 `json:"time" binding:"required,min=1"` + SourceAccountId int64 `json:"sourceAccountId,string" binding:"required,min=1"` + DestinationAccountId int64 `json:"destinationAccountId,string" binding:"required,min=1"` + SourceAmount int64 `json:"sourceAmount" binding:"min=-99999999999,max=99999999999"` + DestinationAmount int64 `json:"destinationAmount" binding:"min=-99999999999,max=99999999999"` + TagIds []string `json:"tagIds"` + Comment string `json:"comment" binding:"max=255"` } // TransactionListByMaxTimeRequest represents all parameters of transaction listing by max time request @@ -92,7 +92,7 @@ type TransactionInfoResponse struct { DestinationAccountId int64 `json:"destinationAccountId,string"` SourceAmount int64 `json:"sourceAmount"` DestinationAmount int64 `json:"destinationAmount"` - TagIds []int64 `json:"tagIds,string"` + TagIds []string `json:"tagIds"` Comment string `json:"comment"` } @@ -114,7 +114,7 @@ func (c *Transaction) ToTransactionInfoResponse(tagIds []int64) *TransactionInfo DestinationAccountId: c.DestinationAccountId, SourceAmount: c.SourceAmount, DestinationAmount: c.DestinationAmount, - TagIds: tagIds, + TagIds: utils.Int64ArrayToStringArray(tagIds), Comment: c.Comment, } } diff --git a/pkg/utils/converter.go b/pkg/utils/converter.go index f5a5f4a8..5eb8512a 100644 --- a/pkg/utils/converter.go +++ b/pkg/utils/converter.go @@ -29,11 +29,39 @@ func Int64ToString(num int64) string { return strconv.FormatInt(num, 10) } +// Int64ArrayToStringArray returns a array of textual representation of these numbers +func Int64ArrayToStringArray(num []int64) []string { + ret := make([]string, 0, len(num)) + + for i := 0; i < len(num); i++ { + ret = append(ret, Int64ToString(num[i])) + } + + return ret +} + // StringToInt64 parses a textual representation of the number to int64 func StringToInt64(str string) (int64, error) { return strconv.ParseInt(str, 10, 64) } +// StringArrayToInt64Array parses a series textual representations of the numbers to int64 array +func StringArrayToInt64Array(strs []string) ([]int64, error) { + ret := make([]int64, 0, len(strs)) + + for i := 0; i < len(strs); i++ { + val, err := StringToInt64(strs[i]) + + if err != nil { + return nil, err + } + + ret = append(ret, val) + } + + return ret, nil +} + // StringTryToInt64 parses a textual representation of the number to int64 if str is valid, // or returns the default value func StringTryToInt64(str string, defaultValue int64) int64 {