mirror of
https://github.com/mayswind/ezbookkeeping.git
synced 2026-05-17 08:14:25 +08:00
support Google AI LLM provider
This commit is contained in:
@@ -172,7 +172,7 @@ transaction_from_ai_image_recognition = false
|
||||
max_ai_recognition_picture_size = 10485760
|
||||
|
||||
[llm_image_recognition]
|
||||
# Large Language Model (LLM) provider for receipt image recognition, supports the following types: "openai", "openai_compatible", "openrouter", "ollama"
|
||||
# Large Language Model (LLM) provider for receipt image recognition, supports the following types: "openai", "openai_compatible", "openrouter", "ollama", "google_ai"
|
||||
llm_provider =
|
||||
|
||||
# For "openai" llm provider only, OpenAI API secret key, please visit https://platform.openai.com/api-keys for more information
|
||||
@@ -202,6 +202,12 @@ ollama_server_url =
|
||||
# For "ollama" llm provider only, receipt image recognition model for creating transactions from images
|
||||
ollama_model_id =
|
||||
|
||||
# For "google_ai" llm provider only, Google AI Studio API key, please visit https://aistudio.google.com/apikey for more information
|
||||
google_ai_api_key =
|
||||
|
||||
# For "google_ai" llm provider only, receipt image recognition model for creating transactions from images
|
||||
google_ai_model_id =
|
||||
|
||||
# Timeout for requesting the large language model API (0 - 4294967295 milliseconds)
|
||||
# Set to 0 to disable timeout for requesting large language model api, default is 60000 (60 seconds)
|
||||
request_timeout = 60000
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/google_ai"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/googleai"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/ollama"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/openai"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
@@ -46,7 +46,7 @@ func initializeLargeLanguageModelProvider(llmConfig *settings.LLMConfig) (provid
|
||||
} else if llmConfig.LLMProvider == settings.OllamaLLMProvider {
|
||||
return ollama.NewOllamaLargeLanguageModelProvider(llmConfig), nil
|
||||
} else if llmConfig.LLMProvider == settings.GoogleAILLMProvider {
|
||||
return google_ai.NewGoogleAILargeLanguageModelProvider(llmConfig), nil
|
||||
return googleai.NewGoogleAILargeLanguageModelProvider(llmConfig), nil
|
||||
} else if llmConfig.LLMProvider == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
package googleai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/data"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/provider/common"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||
)
|
||||
|
||||
const googleAIGenerateContentAPIFormat = "https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent"
|
||||
|
||||
// GoogleAILargeLanguageModelAdapter defines the structure of Google AI large language model adapter
|
||||
type GoogleAILargeLanguageModelAdapter struct {
|
||||
common.HttpLargeLanguageModelAdapter
|
||||
GoogleAIAPIKey string
|
||||
GoogleAIModelID string
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentRequest defines the structure of Google AI generate content request
|
||||
type GoogleAIGenerateContentRequest struct {
|
||||
Contents []*GoogleAIGenerateContentRequestContent `json:"contents"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentRequestContent defines the structure of Google AI generate content request content
|
||||
type GoogleAIGenerateContentRequestContent struct {
|
||||
Parts []*GoogleAIGenerateContentRequestContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentRequestContentPart defines the structure of Google AI generate content request content part
|
||||
type GoogleAIGenerateContentRequestContentPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GoogleAIGenerateContentRequestInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentRequestInlineData defines the structure of Google AI generate content request inline data
|
||||
type GoogleAIGenerateContentRequestInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentResponse defines the structure of Google AI generate content response
|
||||
type GoogleAIGenerateContentResponse struct {
|
||||
Candidates []*GoogleAIGenerateContentResponseCandidate `json:"candidates"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentResponseCandidate defines the structure of Google AI generate content response candidate
|
||||
type GoogleAIGenerateContentResponseCandidate struct {
|
||||
Content *GoogleAIGenerateContentResponseContent `json:"content"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentResponseContent defines the structure of Google AI generate content response content
|
||||
type GoogleAIGenerateContentResponseContent struct {
|
||||
Part []*GoogleAIGenerateContentResponseContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
// GoogleAIGenerateContentResponseContentPart defines the structure of Google AI generate content response content part
|
||||
type GoogleAIGenerateContentResponseContentPart struct {
|
||||
Text *string `json:"text"`
|
||||
}
|
||||
|
||||
// BuildTextualRequest returns the http request by Google AI large language model adapter
|
||||
func (p *GoogleAILargeLanguageModelAdapter) BuildTextualRequest(c core.Context, uid int64, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) (*http.Request, error) {
|
||||
requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestUrl := fmt.Sprintf(googleAIGenerateContentAPIFormat, p.GoogleAIModelID)
|
||||
httpRequest, err := http.NewRequest("POST", requestUrl, bytes.NewReader(requestBody))
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpRequest.Header.Set("Content-Type", "application/json")
|
||||
httpRequest.Header.Set("X-goog-api-key", p.GoogleAIAPIKey)
|
||||
|
||||
return httpRequest, nil
|
||||
}
|
||||
|
||||
// ParseTextualResponse returns the textual response by Google AI large language model adapter
|
||||
func (p *GoogleAILargeLanguageModelAdapter) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType data.LargeLanguageModelResponseFormat) (*data.LargeLanguageModelTextualResponse, error) {
|
||||
generateContentResponse := &GoogleAIGenerateContentResponse{}
|
||||
err := json.Unmarshal(body, &generateContentResponse)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[google_ai_large_language_model_adapter.ParseTextualResponse] failed to parse generate content response for user \"uid:%d\", because %s", uid, err.Error())
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
if generateContentResponse == nil || generateContentResponse.Candidates == nil || len(generateContentResponse.Candidates) < 1 ||
|
||||
generateContentResponse.Candidates[0].Content == nil || len(generateContentResponse.Candidates[0].Content.Part) < 1 ||
|
||||
generateContentResponse.Candidates[0].Content.Part[0].Text == nil {
|
||||
log.Errorf(c, "[google_ai_large_language_model_adapter.ParseTextualResponse] generate content response is invalid for user \"uid:%d\"", uid)
|
||||
return nil, errs.ErrFailedToRequestRemoteApi
|
||||
}
|
||||
|
||||
textualResponse := &data.LargeLanguageModelTextualResponse{
|
||||
Content: *generateContentResponse.Candidates[0].Content.Part[0].Text,
|
||||
}
|
||||
|
||||
return textualResponse, nil
|
||||
}
|
||||
|
||||
func (p *GoogleAILargeLanguageModelAdapter) buildJsonRequestBody(c core.Context, uid int64, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) ([]byte, error) {
|
||||
if p.GoogleAIModelID == "" {
|
||||
return nil, errs.ErrInvalidLLMModelId
|
||||
}
|
||||
|
||||
generateContentRequest := &GoogleAIGenerateContentRequest{
|
||||
Contents: []*GoogleAIGenerateContentRequestContent{
|
||||
{
|
||||
Parts: make([]*GoogleAIGenerateContentRequestContentPart, 0, 2),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if request.SystemPrompt != "" {
|
||||
generateContentRequest.Contents[0].Parts = append(generateContentRequest.Contents[0].Parts, &GoogleAIGenerateContentRequestContentPart{
|
||||
Text: request.SystemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
if len(request.UserPrompt) > 0 {
|
||||
if request.UserPromptType == data.LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL {
|
||||
imageBase64Data := base64.StdEncoding.EncodeToString(request.UserPrompt)
|
||||
generateContentRequest.Contents[0].Parts = append(generateContentRequest.Contents[0].Parts, &GoogleAIGenerateContentRequestContentPart{
|
||||
InlineData: &GoogleAIGenerateContentRequestInlineData{
|
||||
MimeType: request.UserPromptContentType,
|
||||
Data: imageBase64Data,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
generateContentRequest.Contents[0].Parts = append(generateContentRequest.Contents[0].Parts, &GoogleAIGenerateContentRequestContentPart{
|
||||
Text: string(request.UserPrompt),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
requestBodyBytes, err := json.Marshal(generateContentRequest)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf(c, "[google_ai_large_language_model_adapter.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error())
|
||||
return nil, errs.ErrOperationFailed
|
||||
}
|
||||
|
||||
log.Debugf(c, "[google_ai_large_language_model_adapter.buildJsonRequestBody] request body is %s", requestBodyBytes)
|
||||
return requestBodyBytes, nil
|
||||
}
|
||||
|
||||
// NewGoogleAILargeLanguageModelProvider creates a new Google AI large language model provider instance
|
||||
func NewGoogleAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) provider.LargeLanguageModelProvider {
|
||||
return common.NewCommonHttpLargeLanguageModelProvider(&GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIAPIKey: llmConfig.GoogleAIAPIKey,
|
||||
GoogleAIModelID: llmConfig.GoogleAIModelID,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package googleai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||
"github.com/mayswind/ezbookkeeping/pkg/llm/data"
|
||||
)
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_buildJsonRequestBody_TextualUserPrompt(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
request := &data.LargeLanguageModelRequest{
|
||||
SystemPrompt: "You are a helpful assistant.",
|
||||
UserPrompt: []byte("Hello, how are you?"),
|
||||
}
|
||||
|
||||
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.Unmarshal(bodyBytes, &body)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "{\"contents\":[{\"parts\":[{\"text\":\"You are a helpful assistant.\"},{\"text\":\"Hello, how are you?\"}]}]}", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_buildJsonRequestBody_ImageUserPrompt(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
request := &data.LargeLanguageModelRequest{
|
||||
SystemPrompt: "What's in this image?",
|
||||
UserPrompt: []byte("fakedata"),
|
||||
UserPromptType: data.LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL,
|
||||
UserPromptContentType: "image/png",
|
||||
}
|
||||
|
||||
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.Unmarshal(bodyBytes, &body)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Equal(t, "{\"contents\":[{\"parts\":[{\"text\":\"What's in this image?\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"ZmFrZWRhdGE=\"}}]}]}", string(bodyBytes))
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_ParseTextualResponse_ValidJsonResponse(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
response := `{
|
||||
"responseId": "test-123",
|
||||
"modelVersion": "test",
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 13,
|
||||
"candidatesTokenCount": 7,
|
||||
"totalTokenCount": 20
|
||||
},
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "This is a test response"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "This is a test response", result.Content)
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_ParseTextualResponse_EmptyResponse(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
response := `{
|
||||
"responseId": "test-123",
|
||||
"modelVersion": "test",
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 13,
|
||||
"candidatesTokenCount": 7,
|
||||
"totalTokenCount": 20
|
||||
},
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": ""
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "", result.Content)
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_ParseTextualResponse_EmptyCandidates(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
response := `{
|
||||
"responseId": "test-123",
|
||||
"modelVersion": "test",
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 13,
|
||||
"candidatesTokenCount": 7,
|
||||
"totalTokenCount": 20
|
||||
},
|
||||
"candidates": []
|
||||
}`
|
||||
|
||||
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.EqualError(t, err, "failed to request third party api")
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_ParseTextualResponse_NoPartText(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
response := `{
|
||||
"responseId": "test-123",
|
||||
"modelVersion": "test",
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 13,
|
||||
"candidatesTokenCount": 7,
|
||||
"totalTokenCount": 20
|
||||
},
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
}
|
||||
]
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.EqualError(t, err, "failed to request third party api")
|
||||
}
|
||||
|
||||
func TestGoogleAILargeLanguageModelAdapter_ParseTextualResponse_InvalidJson(t *testing.T) {
|
||||
adapter := &GoogleAILargeLanguageModelAdapter{
|
||||
GoogleAIModelID: "test",
|
||||
}
|
||||
|
||||
response := "error"
|
||||
|
||||
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||
assert.EqualError(t, err, "failed to request third party api")
|
||||
}
|
||||
@@ -71,6 +71,7 @@ const (
|
||||
OpenAICompatibleLLMProvider string = "openai_compatible"
|
||||
OpenRouterLLMProvider string = "openrouter"
|
||||
OllamaLLMProvider string = "ollama"
|
||||
GoogleAILLMProvider string = "google_ai"
|
||||
)
|
||||
|
||||
// Uuid generator types
|
||||
@@ -231,6 +232,8 @@ type LLMConfig struct {
|
||||
OpenRouterModelID string
|
||||
OllamaServerURL string
|
||||
OllamaModelID string
|
||||
GoogleAIAPIKey string
|
||||
GoogleAIModelID string
|
||||
LargeLanguageModelAPIRequestTimeout uint32
|
||||
LargeLanguageModelAPIProxy string
|
||||
LargeLanguageModelAPISkipTLSVerify bool
|
||||
@@ -818,6 +821,8 @@ func loadLLMConfiguration(configFile *ini.File, sectionName string) (*LLMConfig,
|
||||
llmConfig.LLMProvider = OpenRouterLLMProvider
|
||||
} else if llmProvider == OllamaLLMProvider {
|
||||
llmConfig.LLMProvider = OllamaLLMProvider
|
||||
} else if llmProvider == GoogleAILLMProvider {
|
||||
llmConfig.LLMProvider = GoogleAILLMProvider
|
||||
} else {
|
||||
return nil, errs.ErrInvalidLLMProvider
|
||||
}
|
||||
@@ -835,6 +840,9 @@ func loadLLMConfiguration(configFile *ini.File, sectionName string) (*LLMConfig,
|
||||
llmConfig.OllamaServerURL = getConfigItemStringValue(configFile, sectionName, "ollama_server_url")
|
||||
llmConfig.OllamaModelID = getConfigItemStringValue(configFile, sectionName, "ollama_model_id")
|
||||
|
||||
llmConfig.GoogleAIAPIKey = getConfigItemStringValue(configFile, sectionName, "google_ai_api_key")
|
||||
llmConfig.GoogleAIModelID = getConfigItemStringValue(configFile, sectionName, "google_ai_model_id")
|
||||
|
||||
llmConfig.LargeLanguageModelAPIProxy = getConfigItemStringValue(configFile, sectionName, "proxy", "system")
|
||||
llmConfig.LargeLanguageModelAPIRequestTimeout = getConfigItemUint32Value(configFile, sectionName, "request_timeout", defaultLargeLanguageModelAPIRequestTimeout)
|
||||
llmConfig.LargeLanguageModelAPISkipTLSVerify = getConfigItemBoolValue(configFile, sectionName, "skip_tls_verify", false)
|
||||
|
||||
Reference in New Issue
Block a user