diff --git a/cmd/initializer.go b/cmd/initializer.go index aa4d3035..ad4a8414 100644 --- a/cmd/initializer.go +++ b/cmd/initializer.go @@ -198,6 +198,10 @@ func getConfigWithoutSensitiveData(config *settings.Config) *settings.Config { if clonedConfig.ReceiptImageRecognitionLLMConfig.OpenRouterAPIKey != "" { clonedConfig.ReceiptImageRecognitionLLMConfig.OpenRouterAPIKey = "****" } + + if clonedConfig.ReceiptImageRecognitionLLMConfig.LMStudioToken != "" { + clonedConfig.ReceiptImageRecognitionLLMConfig.LMStudioToken = "****" + } } if clonedConfig.OAuth2ClientSecret != "" { diff --git a/conf/ezbookkeeping.ini b/conf/ezbookkeeping.ini index 2db8600f..b8c5892c 100644 --- a/conf/ezbookkeeping.ini +++ b/conf/ezbookkeeping.ini @@ -169,7 +169,7 @@ transaction_from_ai_image_recognition = false max_ai_recognition_picture_size = 10485760 [llm_image_recognition] -# Large Language Model (LLM) provider for receipt image recognition, supports the following types: "openai", "openai_compatible", "openrouter", "ollama", "google_ai" +# Large Language Model (LLM) provider for receipt image recognition, supports the following types: "openai", "openai_compatible", "openrouter", "ollama", "lm_studio", "google_ai" llm_provider = # For "openai" llm provider only, OpenAI API secret key, please visit https://platform.openai.com/api-keys for more information @@ -199,6 +199,15 @@ ollama_server_url = # For "ollama" llm provider only, receipt image recognition model for creating transactions from images ollama_model_id = +# For "lm_studio" llm provider only, LM Studio server url, e.g. 
"http://127.0.0.1:1234/" +lm_studio_server_url = + +# For "lm_studio" llm provider only, LM Studio API token, if "require authentication" is not enabled in LM Studio, leave it blank +lm_studio_token = + +# For "lm_studio" llm provider only, receipt image recognition model for creating transactions from images +lm_studio_model_id = + # For "google_ai" llm provider only, Google AI Studio API key, please visit https://aistudio.google.com/apikey for more information google_ai_api_key = diff --git a/pkg/llm/large_language_model_provider_container.go b/pkg/llm/large_language_model_provider_container.go index 89d7ac0d..a4cf4080 100644 --- a/pkg/llm/large_language_model_provider_container.go +++ b/pkg/llm/large_language_model_provider_container.go @@ -6,6 +6,7 @@ import ( "github.com/mayswind/ezbookkeeping/pkg/llm/data" "github.com/mayswind/ezbookkeeping/pkg/llm/provider" "github.com/mayswind/ezbookkeeping/pkg/llm/provider/googleai" + "github.com/mayswind/ezbookkeeping/pkg/llm/provider/lmstudio" "github.com/mayswind/ezbookkeeping/pkg/llm/provider/ollama" "github.com/mayswind/ezbookkeeping/pkg/llm/provider/openai" "github.com/mayswind/ezbookkeeping/pkg/settings" @@ -45,6 +46,8 @@ func initializeLargeLanguageModelProvider(llmConfig *settings.LLMConfig, enableR return openai.NewOpenRouterLargeLanguageModelProvider(llmConfig, enableResponseLog), nil } else if llmConfig.LLMProvider == settings.OllamaLLMProvider { return ollama.NewOllamaLargeLanguageModelProvider(llmConfig, enableResponseLog), nil + } else if llmConfig.LLMProvider == settings.LMStudioLLMProvider { + return lmstudio.NewLMStudioLargeLanguageModelProvider(llmConfig, enableResponseLog), nil } else if llmConfig.LLMProvider == settings.GoogleAILLMProvider { return googleai.NewGoogleAILargeLanguageModelProvider(llmConfig, enableResponseLog), nil } else if llmConfig.LLMProvider == "" { diff --git a/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter.go 
b/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter.go new file mode 100644 index 00000000..b3aa19fb --- /dev/null +++ b/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter.go @@ -0,0 +1,155 @@ +package lmstudio + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "net/http" + + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/errs" + "github.com/mayswind/ezbookkeeping/pkg/llm/data" + "github.com/mayswind/ezbookkeeping/pkg/llm/provider" + "github.com/mayswind/ezbookkeeping/pkg/llm/provider/common" + "github.com/mayswind/ezbookkeeping/pkg/log" + "github.com/mayswind/ezbookkeeping/pkg/settings" +) + +const lmStudioChatPath = "api/v1/chat" + +// LMStudioLargeLanguageModelAdapter defines the structure of LM Studio large language model adapter +type LMStudioLargeLanguageModelAdapter struct { + common.HttpLargeLanguageModelAdapter + LMStudioServerURL string + LMStudioToken string + LMStudioModelID string +} + +// LMStudioChatRequest defines the structure of LM Studio chat request +type LMStudioChatRequest struct { + Model string `json:"model"` + SystemPrompt string `json:"system_prompt,omitempty"` + Input []*LMStudioChatRequestInput `json:"input"` +} + +// LMStudioChatRequestInput defines the structure of LM Studio chat request message +type LMStudioChatRequestInput struct { + Type string `json:"type"` + Content string `json:"content,omitempty"` + DataUrl string `json:"data_url,omitempty"` +} + +// LMStudioChatResponse defines the structure of LM Studio chat response +type LMStudioChatResponse struct { + Output []*LMStudioChatResponseOutput `json:"output"` +} + +// LMStudioChatResponseOutput defines the structure of LM Studio chat response message +type LMStudioChatResponseOutput struct { + Content *string `json:"content"` +} + +// BuildTextualRequest returns the http request by LM Studio large language model adapter +func (p *LMStudioLargeLanguageModelAdapter) BuildTextualRequest(c core.Context, 
uid int64, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) (*http.Request, error) { + requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType) + + if err != nil { + return nil, err + } + + httpRequest, err := http.NewRequest("POST", p.getLMStudioRequestUrl(), bytes.NewReader(requestBody)) + + if err != nil { + return nil, err + } + + if p.LMStudioToken != "" { + httpRequest.Header.Set("Authorization", "Bearer "+p.LMStudioToken) + } + + httpRequest.Header.Set("Content-Type", "application/json") + + return httpRequest, nil +} + +// ParseTextualResponse returns the textual response by LM Studio large language model adapter +func (p *LMStudioLargeLanguageModelAdapter) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType data.LargeLanguageModelResponseFormat) (*data.LargeLanguageModelTextualResponse, error) { + chatResponse := &LMStudioChatResponse{} + err := json.Unmarshal(body, &chatResponse) + + if err != nil { + log.Errorf(c, "[lm_studio_large_language_model_adapter.ParseTextualResponse] failed to parse chat response for user \"uid:%d\", because %s", uid, err.Error()) + return nil, errs.ErrFailedToRequestRemoteApi + } + + if chatResponse == nil || len(chatResponse.Output) < 1 || chatResponse.Output[0].Content == nil { + log.Errorf(c, "[lm_studio_large_language_model_adapter.ParseTextualResponse] chat response is invalid for user \"uid:%d\"", uid) + return nil, errs.ErrFailedToRequestRemoteApi + } + + textualResponse := &data.LargeLanguageModelTextualResponse{ + Content: *chatResponse.Output[0].Content, + } + + return textualResponse, nil +} + +func (p *LMStudioLargeLanguageModelAdapter) buildJsonRequestBody(c core.Context, uid int64, request *data.LargeLanguageModelRequest, responseType data.LargeLanguageModelResponseFormat) ([]byte, error) { + if p.LMStudioModelID == "" { + return nil, errs.ErrInvalidLLMModelId + } + + chatRequest := &LMStudioChatRequest{ + Model: p.LMStudioModelID, 
+ Input: make([]*LMStudioChatRequestInput, 0, 1), + } + + if request.SystemPrompt != "" { + chatRequest.SystemPrompt = request.SystemPrompt + } + + if len(request.UserPrompt) > 0 { + if request.UserPromptType == data.LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL { + imageBase64Data := "data:" + request.UserPromptContentType + ";base64," + base64.StdEncoding.EncodeToString(request.UserPrompt) + chatRequest.Input = append(chatRequest.Input, &LMStudioChatRequestInput{ + Type: "image", + DataUrl: imageBase64Data, + }) + } else { + chatRequest.Input = append(chatRequest.Input, &LMStudioChatRequestInput{ + Type: "text", + Content: string(request.UserPrompt), + }) + } + } + + requestBodyBytes, err := json.Marshal(chatRequest) + + if err != nil { + log.Errorf(c, "[lm_studio_large_language_model_adapter.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error()) + return nil, errs.ErrOperationFailed + } + + log.Debugf(c, "[lm_studio_large_language_model_adapter.buildJsonRequestBody] request body is %s", requestBodyBytes) + return requestBodyBytes, nil +} + +func (p *LMStudioLargeLanguageModelAdapter) getLMStudioRequestUrl() string { + url := p.LMStudioServerURL + + if url[len(url)-1] != '/' { + url += "/" + } + + url += lmStudioChatPath + return url +} + +// NewLMStudioLargeLanguageModelProvider creates a new LM Studio large language model provider instance +func NewLMStudioLargeLanguageModelProvider(llmConfig *settings.LLMConfig, enableResponseLog bool) provider.LargeLanguageModelProvider { + return common.NewCommonHttpLargeLanguageModelProvider(llmConfig, enableResponseLog, &LMStudioLargeLanguageModelAdapter{ + LMStudioServerURL: llmConfig.LMStudioServerURL, + LMStudioToken: llmConfig.LMStudioToken, + LMStudioModelID: llmConfig.LMStudioModelID, + }) +} diff --git a/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter_test.go b/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter_test.go new file 
mode 100644 index 00000000..8b3cab50 --- /dev/null +++ b/pkg/llm/provider/lmstudio/lm_studio_large_language_model_adapter_test.go @@ -0,0 +1,146 @@ +package lmstudio + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mayswind/ezbookkeeping/pkg/core" + "github.com/mayswind/ezbookkeeping/pkg/llm/data" +) + +func TestLMStudioLargeLanguageModelAdapter_buildJsonRequestBody_TextualUserPrompt(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{ + LMStudioModelID: "test", + } + + request := &data.LargeLanguageModelRequest{ + SystemPrompt: "You are a helpful assistant.", + UserPrompt: []byte("Hello, how are you?"), + } + + bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.Nil(t, err) + + var body map[string]interface{} + err = json.Unmarshal(bodyBytes, &body) + assert.Nil(t, err) + + assert.Equal(t, "{\"model\":\"test\",\"system_prompt\":\"You are a helpful assistant.\",\"input\":[{\"type\":\"text\",\"content\":\"Hello, how are you?\"}]}", string(bodyBytes)) +} + +func TestLMStudioLargeLanguageModelAdapter_buildJsonRequestBody_ImageUserPrompt(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{ + LMStudioModelID: "test", + } + + request := &data.LargeLanguageModelRequest{ + SystemPrompt: "What's in this image?", + UserPrompt: []byte("fakedata"), + UserPromptType: data.LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL, + UserPromptContentType: "image/png", + } + + bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.Nil(t, err) + + var body map[string]interface{} + err = json.Unmarshal(bodyBytes, &body) + assert.Nil(t, err) + + assert.Equal(t, "{\"model\":\"test\",\"system_prompt\":\"What's in this image?\",\"input\":[{\"type\":\"image\",\"data_url\":\"data:image/png;base64,ZmFrZWRhdGE=\"}]}", string(bodyBytes)) +} + +func 
TestLMStudioLargeLanguageModelAdapter_ParseTextualResponse_ValidJsonResponse(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{} + + response := `{ + "model_instance_id": "test", + "output": [ + { + "type": "message", + "content": "This is a test response" + } + ] + }` + + result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.Nil(t, err) + assert.Equal(t, "This is a test response", result.Content) +} + +func TestLMStudioLargeLanguageModelAdapter_ParseTextualResponse_EmptyOutputContent(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{} + + response := `{ + "model_instance_id": "test", + "output": [ + { + "type": "message", + "content": "" + } + ] + }` + + result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.Nil(t, err) + assert.Equal(t, "", result.Content) +} + +func TestLMStudioLargeLanguageModelAdapter_ParseTextualResponse_EmptyOutput(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{} + + response := `{ + "model_instance_id": "test", + "output": [] + }` + + _, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.EqualError(t, err, "failed to request third party api") +} + +func TestLMStudioLargeLanguageModelAdapter_ParseTextualResponse_NoContentFieldInOutput(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{} + + response := `{ + "model_instance_id": "test", + "output": [ + { + "type": "message" + } + ] + }` + + _, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.EqualError(t, err, "failed to request third party api") +} + +func TestLMStudioLargeLanguageModelAdapter_ParseTextualResponse_InvalidJson(t *testing.T) { + adapter := 
&LMStudioLargeLanguageModelAdapter{} + + response := "error" + + _, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), data.LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON) + assert.EqualError(t, err, "failed to request third party api") +} + +func TestLMStudioLargeLanguageModelAdapter_GetLMStudioRequestUrl(t *testing.T) { + adapter := &LMStudioLargeLanguageModelAdapter{ + LMStudioServerURL: "http://localhost:1234/", + } + url := adapter.getLMStudioRequestUrl() + assert.Equal(t, "http://localhost:1234/api/v1/chat", url) + + adapter = &LMStudioLargeLanguageModelAdapter{ + LMStudioServerURL: "http://localhost:1234", + } + url = adapter.getLMStudioRequestUrl() + assert.Equal(t, "http://localhost:1234/api/v1/chat", url) + + adapter = &LMStudioLargeLanguageModelAdapter{ + LMStudioServerURL: "http://example.com/lmstudio/", + } + url = adapter.getLMStudioRequestUrl() + assert.Equal(t, "http://example.com/lmstudio/api/v1/chat", url) +} diff --git a/pkg/settings/setting.go b/pkg/settings/setting.go index 67a31f1a..bf6035e2 100644 --- a/pkg/settings/setting.go +++ b/pkg/settings/setting.go @@ -73,6 +73,7 @@ const ( OpenAICompatibleLLMProvider string = "openai_compatible" OpenRouterLLMProvider string = "openrouter" OllamaLLMProvider string = "ollama" + LMStudioLLMProvider string = "lm_studio" GoogleAILLMProvider string = "google_ai" ) @@ -248,6 +249,9 @@ type LLMConfig struct { OpenRouterModelID string OllamaServerURL string OllamaModelID string + LMStudioServerURL string + LMStudioToken string + LMStudioModelID string GoogleAIAPIKey string GoogleAIModelID string LargeLanguageModelAPIRequestTimeout uint32 @@ -864,6 +868,8 @@ func loadLLMConfiguration(configFile *ini.File, sectionName string) (*LLMConfig, llmConfig.LLMProvider = OpenRouterLLMProvider } else if llmProvider == OllamaLLMProvider { llmConfig.LLMProvider = OllamaLLMProvider + } else if llmProvider == LMStudioLLMProvider { + llmConfig.LLMProvider = LMStudioLLMProvider } else if llmProvider ==
GoogleAILLMProvider { llmConfig.LLMProvider = GoogleAILLMProvider } else { @@ -883,6 +889,10 @@ func loadLLMConfiguration(configFile *ini.File, sectionName string) (*LLMConfig, llmConfig.OllamaServerURL = getConfigItemStringValue(configFile, sectionName, "ollama_server_url") llmConfig.OllamaModelID = getConfigItemStringValue(configFile, sectionName, "ollama_model_id") + llmConfig.LMStudioServerURL = getConfigItemStringValue(configFile, sectionName, "lm_studio_server_url") + llmConfig.LMStudioToken = getConfigItemStringValue(configFile, sectionName, "lm_studio_token") + llmConfig.LMStudioModelID = getConfigItemStringValue(configFile, sectionName, "lm_studio_model_id") + llmConfig.GoogleAIAPIKey = getConfigItemStringValue(configFile, sectionName, "google_ai_api_key") llmConfig.GoogleAIModelID = getConfigItemStringValue(configFile, sectionName, "google_ai_model_id")