mirror of
https://github.com/mayswind/ezbookkeeping.git
synced 2026-05-20 09:44:26 +08:00
renamed structs and interfaces to reduce ambiguity
This commit is contained in:
+11
-11
@@ -13,8 +13,8 @@ import (
|
|||||||
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
"github.com/mayswind/ezbookkeeping/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HttpLargeLanguageModelProvider defines the structure of http large language model provider
|
// HttpLargeLanguageModelAdapter defines the structure of http large language model adapter
|
||||||
type HttpLargeLanguageModelProvider interface {
|
type HttpLargeLanguageModelAdapter interface {
|
||||||
// BuildTextualRequest returns the http request by the provider api definition
|
// BuildTextualRequest returns the http request by the provider api definition
|
||||||
BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error)
|
BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error)
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ type HttpLargeLanguageModelProvider interface {
|
|||||||
// CommonHttpLargeLanguageModelProvider defines the structure of common http large language model provider
|
// CommonHttpLargeLanguageModelProvider defines the structure of common http large language model provider
|
||||||
type CommonHttpLargeLanguageModelProvider struct {
|
type CommonHttpLargeLanguageModelProvider struct {
|
||||||
LargeLanguageModelProvider
|
LargeLanguageModelProvider
|
||||||
provider HttpLargeLanguageModelProvider
|
adapter HttpLargeLanguageModelAdapter
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetJsonResponse returns the json response from the OpenAI common compatible large language model provider
|
// GetJsonResponse returns the json response from the OpenAI common compatible large language model provider
|
||||||
@@ -48,10 +48,10 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
|
|||||||
Timeout: time.Duration(currentLLMConfig.LargeLanguageModelAPIRequestTimeout) * time.Millisecond,
|
Timeout: time.Duration(currentLLMConfig.LargeLanguageModelAPIRequestTimeout) * time.Millisecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRequest, err := p.provider.BuildTextualRequest(c, uid, request, responseType)
|
httpRequest, err := p.adapter.BuildTextualRequest(c, uid, request, responseType)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[http_large_language_model_provider.getTextualResponse] failed to build requests for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[common_http_large_language_model_provider.getTextualResponse] failed to build requests for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -60,25 +60,25 @@ func (p *CommonHttpLargeLanguageModelProvider) getTextualResponse(c core.Context
|
|||||||
resp, err := client.Do(httpRequest)
|
resp, err := client.Do(httpRequest)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[http_large_language_model_provider.getTextualResponse] failed to request large language model api for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[common_http_large_language_model_provider.getTextualResponse] failed to request large language model api for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := io.ReadAll(resp.Body)
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
log.Debugf(c, "[http_large_language_model_provider.getTextualResponse] response is %s", body)
|
log.Debugf(c, "[common_http_large_language_model_provider.getTextualResponse] response is %s", body)
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
log.Errorf(c, "[http_large_language_model_provider.getTextualResponse] failed to get large language model api response for user \"uid:%d\", because response code is %d", uid, resp.StatusCode)
|
log.Errorf(c, "[common_http_large_language_model_provider.getTextualResponse] failed to get large language model api response for user \"uid:%d\", because response code is %d", uid, resp.StatusCode)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.provider.ParseTextualResponse(c, uid, body, responseType)
|
return p.adapter.ParseTextualResponse(c, uid, body, responseType)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCommonHttpLargeLanguageModelProvider(provider HttpLargeLanguageModelProvider) *CommonHttpLargeLanguageModelProvider {
|
func newCommonHttpLargeLanguageModelProvider(adapter HttpLargeLanguageModelAdapter) *CommonHttpLargeLanguageModelProvider {
|
||||||
return &CommonHttpLargeLanguageModelProvider{
|
return &CommonHttpLargeLanguageModelProvider{
|
||||||
provider: provider,
|
adapter: adapter,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
+16
-21
@@ -15,15 +15,15 @@ import (
|
|||||||
|
|
||||||
const ollamaChatCompletionsPath = "api/chat"
|
const ollamaChatCompletionsPath = "api/chat"
|
||||||
|
|
||||||
// OllamaLargeLanguageModelProvider defines the structure of Ollama large language model provider
|
// OllamaLargeLanguageModelAdapter defines the structure of Ollama large language model adapter
|
||||||
type OllamaLargeLanguageModelProvider struct {
|
type OllamaLargeLanguageModelAdapter struct {
|
||||||
CommonHttpLargeLanguageModelProvider
|
HttpLargeLanguageModelAdapter
|
||||||
OllamaServerURL string
|
OllamaServerURL string
|
||||||
OllamaModelID string
|
OllamaModelID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildTextualRequest returns the http request by Ollama provider
|
// BuildTextualRequest returns the http request by Ollama large language model adapter
|
||||||
func (p *OllamaLargeLanguageModelProvider) BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error) {
|
func (p *OllamaLargeLanguageModelAdapter) BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error) {
|
||||||
requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType)
|
requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -41,27 +41,27 @@ func (p *OllamaLargeLanguageModelProvider) BuildTextualRequest(c core.Context, u
|
|||||||
return httpRequest, nil
|
return httpRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseTextualResponse returns the textual response by Ollama provider
|
// ParseTextualResponse returns the textual response by Ollama large language model adapter
|
||||||
func (p *OllamaLargeLanguageModelProvider) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType LargeLanguageModelResponseFormat) (*LargeLanguageModelTextualResponse, error) {
|
func (p *OllamaLargeLanguageModelAdapter) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType LargeLanguageModelResponseFormat) (*LargeLanguageModelTextualResponse, error) {
|
||||||
responseBody := make(map[string]any)
|
responseBody := make(map[string]any)
|
||||||
err := json.Unmarshal(body, &responseBody)
|
err := json.Unmarshal(body, &responseBody)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[ollama_large_language_model_provider.ParseTextualResponse] failed to parse response for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[ollama_large_language_model_adapter.ParseTextualResponse] failed to parse response for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
message, ok := responseBody["message"].(map[string]any)
|
message, ok := responseBody["message"].(map[string]any)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf(c, "[ollama_large_language_model_provider.ParseTextualResponse] no message found in response for user \"uid:%d\"", uid)
|
log.Errorf(c, "[ollama_large_language_model_adapter.ParseTextualResponse] no message found in response for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
content, ok := message["content"].(string)
|
content, ok := message["content"].(string)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf(c, "[ollama_large_language_model_provider.ParseTextualResponse] no content found in message for user \"uid:%d\"", uid)
|
log.Errorf(c, "[ollama_large_language_model_adapter.ParseTextualResponse] no content found in message for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -82,12 +82,7 @@ func (p *OllamaLargeLanguageModelProvider) ParseTextualResponse(c core.Context,
|
|||||||
return textualResponse, nil
|
return textualResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelID returns the model id of Ollama provider
|
func (p *OllamaLargeLanguageModelAdapter) buildJsonRequestBody(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) ([]byte, error) {
|
||||||
func (p *OllamaLargeLanguageModelProvider) GetModelID() string {
|
|
||||||
return p.OllamaModelID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OllamaLargeLanguageModelProvider) buildJsonRequestBody(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) ([]byte, error) {
|
|
||||||
if p.OllamaModelID == "" {
|
if p.OllamaModelID == "" {
|
||||||
return nil, errs.ErrInvalidLLMModelId
|
return nil, errs.ErrInvalidLLMModelId
|
||||||
}
|
}
|
||||||
@@ -102,8 +97,8 @@ func (p *OllamaLargeLanguageModelProvider) buildJsonRequestBody(c core.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(request.UserPrompt) > 0 {
|
if len(request.UserPrompt) > 0 {
|
||||||
imageBase64Data := base64.StdEncoding.EncodeToString(request.UserPrompt)
|
|
||||||
if request.UserPromptType == LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL {
|
if request.UserPromptType == LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL {
|
||||||
|
imageBase64Data := base64.StdEncoding.EncodeToString(request.UserPrompt)
|
||||||
requestMessages = append(requestMessages, map[string]any{
|
requestMessages = append(requestMessages, map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "",
|
"content": "",
|
||||||
@@ -129,15 +124,15 @@ func (p *OllamaLargeLanguageModelProvider) buildJsonRequestBody(c core.Context,
|
|||||||
requestBodyBytes, err := json.Marshal(requestBody)
|
requestBodyBytes, err := json.Marshal(requestBody)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[ollama_large_language_model_provider.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[ollama_large_language_model_adapter.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrOperationFailed
|
return nil, errs.ErrOperationFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf(c, "[ollama_large_language_model_provider.buildJsonRequestBody] request body is %s", requestBodyBytes)
|
log.Debugf(c, "[ollama_large_language_model_adapter.buildJsonRequestBody] request body is %s", requestBodyBytes)
|
||||||
return requestBodyBytes, nil
|
return requestBodyBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OllamaLargeLanguageModelProvider) getOllamaRequestUrl() string {
|
func (p *OllamaLargeLanguageModelAdapter) getOllamaRequestUrl() string {
|
||||||
url := p.OllamaServerURL
|
url := p.OllamaServerURL
|
||||||
|
|
||||||
if url[len(url)-1] != '/' {
|
if url[len(url)-1] != '/' {
|
||||||
@@ -150,7 +145,7 @@ func (p *OllamaLargeLanguageModelProvider) getOllamaRequestUrl() string {
|
|||||||
|
|
||||||
// NewOllamaLargeLanguageModelProvider creates a new Ollama large language model provider instance
|
// NewOllamaLargeLanguageModelProvider creates a new Ollama large language model provider instance
|
||||||
func NewOllamaLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
func NewOllamaLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
||||||
return newCommonHttpLargeLanguageModelProvider(&OllamaLargeLanguageModelProvider{
|
return newCommonHttpLargeLanguageModelProvider(&OllamaLargeLanguageModelAdapter{
|
||||||
OllamaServerURL: llmConfig.OllamaServerURL,
|
OllamaServerURL: llmConfig.OllamaServerURL,
|
||||||
OllamaModelID: llmConfig.OllamaModelID,
|
OllamaModelID: llmConfig.OllamaModelID,
|
||||||
})
|
})
|
||||||
+28
-28
@@ -9,8 +9,8 @@ import (
|
|||||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_TextualUserPrompt(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_buildJsonRequestBody_TextualUserPrompt(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{
|
adapter := &OllamaLargeLanguageModelAdapter{
|
||||||
OllamaModelID: "test",
|
OllamaModelID: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -19,7 +19,7 @@ func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_TextualUserPrompt
|
|||||||
UserPrompt: []byte("Hello, how are you?"),
|
UserPrompt: []byte("Hello, how are you?"),
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, err := provider.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
var body map[string]interface{}
|
var body map[string]interface{}
|
||||||
@@ -29,8 +29,8 @@ func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_TextualUserPrompt
|
|||||||
assert.Equal(t, "{\"format\":\"json\",\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Hello, how are you?\",\"role\":\"user\"}],\"model\":\"test\",\"stream\":false}", string(bodyBytes))
|
assert.Equal(t, "{\"format\":\"json\",\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Hello, how are you?\",\"role\":\"user\"}],\"model\":\"test\",\"stream\":false}", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_ImageUserPrompt(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_buildJsonRequestBody_ImageUserPrompt(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{
|
adapter := &OllamaLargeLanguageModelAdapter{
|
||||||
OllamaModelID: "test",
|
OllamaModelID: "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_ImageUserPrompt(t
|
|||||||
UserPromptType: LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL,
|
UserPromptType: LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL,
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, err := provider.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
var body map[string]interface{}
|
var body map[string]interface{}
|
||||||
@@ -50,8 +50,8 @@ func TestOllamaLargeLanguageModelProvider_buildJsonRequestBody_ImageUserPrompt(t
|
|||||||
assert.Equal(t, "{\"format\":\"json\",\"messages\":[{\"content\":\"What's in this image?\",\"role\":\"system\"},{\"content\":\"\",\"images\":[\"ZmFrZWRhdGE=\"],\"role\":\"user\"}],\"model\":\"test\",\"stream\":false}", string(bodyBytes))
|
assert.Equal(t, "{\"format\":\"json\",\"messages\":[{\"content\":\"What's in this image?\",\"role\":\"system\"},{\"content\":\"\",\"images\":[\"ZmFrZWRhdGE=\"],\"role\":\"user\"}],\"model\":\"test\",\"stream\":false}", string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_ValidJsonResponse(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_ParseTextualResponse_ValidJsonResponse(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{}
|
adapter := &OllamaLargeLanguageModelAdapter{}
|
||||||
|
|
||||||
response := `{
|
response := `{
|
||||||
"model": "test",
|
"model": "test",
|
||||||
@@ -62,13 +62,13 @@ func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_ValidJsonResponse
|
|||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
result, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "This is a test response", result.Content)
|
assert.Equal(t, "This is a test response", result.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_EmptyResponse(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_ParseTextualResponse_EmptyResponse(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{}
|
adapter := &OllamaLargeLanguageModelAdapter{}
|
||||||
|
|
||||||
response := `{
|
response := `{
|
||||||
"model": "test",
|
"model": "test",
|
||||||
@@ -79,13 +79,13 @@ func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_EmptyResponse(t *
|
|||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
result, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
assert.Equal(t, "", result.Content)
|
assert.Equal(t, "", result.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_EmptyChoices(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_ParseTextualResponse_EmptyChoices(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{}
|
adapter := &OllamaLargeLanguageModelAdapter{}
|
||||||
|
|
||||||
response := `{
|
response := `{
|
||||||
"model": "test",
|
"model": "test",
|
||||||
@@ -93,12 +93,12 @@ func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_EmptyChoices(t *t
|
|||||||
"message": {}
|
"message": {}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_NoChoiceContent(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_ParseTextualResponse_NoChoiceContent(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{}
|
adapter := &OllamaLargeLanguageModelAdapter{}
|
||||||
|
|
||||||
response := `{
|
response := `{
|
||||||
"model": "test",
|
"model": "test",
|
||||||
@@ -108,35 +108,35 @@ func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_NoChoiceContent(t
|
|||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_ParseTextualResponse_InvalidJson(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_ParseTextualResponse_InvalidJson(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{}
|
adapter := &OllamaLargeLanguageModelAdapter{}
|
||||||
|
|
||||||
response := "error"
|
response := "error"
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOllamaLargeLanguageModelProvider_GetOllamaRequestUrl(t *testing.T) {
|
func TestOllamaLargeLanguageModelAdapter_GetOllamaRequestUrl(t *testing.T) {
|
||||||
provider := &OllamaLargeLanguageModelProvider{
|
adapter := &OllamaLargeLanguageModelAdapter{
|
||||||
OllamaServerURL: "http://localhost:11434/",
|
OllamaServerURL: "http://localhost:11434/",
|
||||||
}
|
}
|
||||||
url := provider.getOllamaRequestUrl()
|
url := adapter.getOllamaRequestUrl()
|
||||||
assert.Equal(t, "http://localhost:11434/api/chat", url)
|
assert.Equal(t, "http://localhost:11434/api/chat", url)
|
||||||
|
|
||||||
provider = &OllamaLargeLanguageModelProvider{
|
adapter = &OllamaLargeLanguageModelAdapter{
|
||||||
OllamaServerURL: "http://localhost:11434",
|
OllamaServerURL: "http://localhost:11434",
|
||||||
}
|
}
|
||||||
url = provider.getOllamaRequestUrl()
|
url = adapter.getOllamaRequestUrl()
|
||||||
assert.Equal(t, "http://localhost:11434/api/chat", url)
|
assert.Equal(t, "http://localhost:11434/api/chat", url)
|
||||||
|
|
||||||
provider = &OllamaLargeLanguageModelProvider{
|
adapter = &OllamaLargeLanguageModelAdapter{
|
||||||
OllamaServerURL: "http://example.com/ollama/",
|
OllamaServerURL: "http://example.com/ollama/",
|
||||||
}
|
}
|
||||||
url = provider.getOllamaRequestUrl()
|
url = adapter.getOllamaRequestUrl()
|
||||||
assert.Equal(t, "http://example.com/ollama/api/chat", url)
|
assert.Equal(t, "http://example.com/ollama/api/chat", url)
|
||||||
}
|
}
|
||||||
+8
-8
@@ -7,17 +7,17 @@ import (
|
|||||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAILargeLanguageModelProvider defines the structure of OpenAI large language model provider
|
// OpenAIOfficialChatCompletionsAPIProvider defines the structure of OpenAI official chat completions API provider
|
||||||
type OpenAILargeLanguageModelProvider struct {
|
type OpenAIOfficialChatCompletionsAPIProvider struct {
|
||||||
OpenAIChatCompletionsLargeLanguageModelProvider
|
OpenAIChatCompletionsAPIProvider
|
||||||
OpenAIAPIKey string
|
OpenAIAPIKey string
|
||||||
OpenAIModelID string
|
OpenAIModelID string
|
||||||
}
|
}
|
||||||
|
|
||||||
const openAIChatCompletionsUrl = "https://api.openai.com/v1/chat/completions"
|
const openAIChatCompletionsUrl = "https://api.openai.com/v1/chat/completions"
|
||||||
|
|
||||||
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenAI provider
|
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenAI official chat completions API provider
|
||||||
func (p *OpenAILargeLanguageModelProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
func (p *OpenAIOfficialChatCompletionsAPIProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
||||||
req, err := http.NewRequest("POST", openAIChatCompletionsUrl, nil)
|
req, err := http.NewRequest("POST", openAIChatCompletionsUrl, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -29,14 +29,14 @@ func (p *OpenAILargeLanguageModelProvider) BuildChatCompletionsHttpRequest(c cor
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelID returns the model id of OpenAI provider
|
// GetModelID returns the model id of OpenAI official chat completions API provider
|
||||||
func (p *OpenAILargeLanguageModelProvider) GetModelID() string {
|
func (p *OpenAIOfficialChatCompletionsAPIProvider) GetModelID() string {
|
||||||
return p.OpenAIModelID
|
return p.OpenAIModelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAILargeLanguageModelProvider creates a new OpenAI large language model provider instance
|
// NewOpenAILargeLanguageModelProvider creates a new OpenAI large language model provider instance
|
||||||
func NewOpenAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
func NewOpenAILargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
||||||
return newOpenAICommonChatCompletionsHttpLargeLanguageModelProvider(&OpenAILargeLanguageModelProvider{
|
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAIOfficialChatCompletionsAPIProvider{
|
||||||
OpenAIAPIKey: llmConfig.OpenAIAPIKey,
|
OpenAIAPIKey: llmConfig.OpenAIAPIKey,
|
||||||
OpenAIModelID: llmConfig.OpenAIModelID,
|
OpenAIModelID: llmConfig.OpenAIModelID,
|
||||||
})
|
})
|
||||||
+24
-24
@@ -15,8 +15,8 @@ import (
|
|||||||
"github.com/mayswind/ezbookkeeping/pkg/log"
|
"github.com/mayswind/ezbookkeeping/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAIChatCompletionsLargeLanguageModelProvider defines the structure of OpenAI chat completions compatible large language model provider
|
// OpenAIChatCompletionsAPIProvider defines the structure of OpenAI chat completions API provider
|
||||||
type OpenAIChatCompletionsLargeLanguageModelProvider interface {
|
type OpenAIChatCompletionsAPIProvider interface {
|
||||||
// BuildChatCompletionsHttpRequest returns the chat completions http request
|
// BuildChatCompletionsHttpRequest returns the chat completions http request
|
||||||
BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error)
|
BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error)
|
||||||
|
|
||||||
@@ -24,21 +24,21 @@ type OpenAIChatCompletionsLargeLanguageModelProvider interface {
|
|||||||
GetModelID() string
|
GetModelID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenAICommonChatCompletionsHttpLargeLanguageModelProvider defines the structure of OpenAI common compatible large language model provider based on chat completions api
|
// CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter defines the structure of OpenAI common compatible large language model adapter based on chat completions api
|
||||||
type OpenAICommonChatCompletionsHttpLargeLanguageModelProvider struct {
|
type CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter struct {
|
||||||
CommonHttpLargeLanguageModelProvider
|
HttpLargeLanguageModelAdapter
|
||||||
provider OpenAIChatCompletionsLargeLanguageModelProvider
|
apiProvider OpenAIChatCompletionsAPIProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildTextualRequest returns the http request by OpenAI common compatible provider
|
// BuildTextualRequest returns the http request by OpenAI common compatible adapter
|
||||||
func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error) {
|
func (p *CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter) BuildTextualRequest(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) (*http.Request, error) {
|
||||||
requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType)
|
requestBody, err := p.buildJsonRequestBody(c, uid, request, responseType)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpRequest, err := p.provider.BuildChatCompletionsHttpRequest(c, uid)
|
httpRequest, err := p.apiProvider.BuildChatCompletionsHttpRequest(c, uid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -50,41 +50,41 @@ func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) BuildTextual
|
|||||||
return httpRequest, nil
|
return httpRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseTextualResponse returns the textual response by OpenAI common compatible provider
|
// ParseTextualResponse returns the textual response by OpenAI common compatible adapter
|
||||||
func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType LargeLanguageModelResponseFormat) (*LargeLanguageModelTextualResponse, error) {
|
func (p *CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter) ParseTextualResponse(c core.Context, uid int64, body []byte, responseType LargeLanguageModelResponseFormat) (*LargeLanguageModelTextualResponse, error) {
|
||||||
responseBody := make(map[string]any)
|
responseBody := make(map[string]any)
|
||||||
err := json.Unmarshal(body, &responseBody)
|
err := json.Unmarshal(body, &responseBody)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.ParseTextualResponse] failed to parse response for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.ParseTextualResponse] failed to parse response for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
choices, ok := responseBody["choices"].([]any)
|
choices, ok := responseBody["choices"].([]any)
|
||||||
|
|
||||||
if !ok || len(choices) < 1 {
|
if !ok || len(choices) < 1 {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.ParseTextualResponse] no choices found in response for user \"uid:%d\"", uid)
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.ParseTextualResponse] no choices found in response for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
firstChoice, ok := choices[0].(map[string]any)
|
firstChoice, ok := choices[0].(map[string]any)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.ParseTextualResponse] invalid choice format in response for user \"uid:%d\"", uid)
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.ParseTextualResponse] invalid choice format in response for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
message, ok := firstChoice["message"].(map[string]any)
|
message, ok := firstChoice["message"].(map[string]any)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.ParseTextualResponse] no message found in choice for user \"uid:%d\"", uid)
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.ParseTextualResponse] no message found in choice for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
content, ok := message["content"].(string)
|
content, ok := message["content"].(string)
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.ParseTextualResponse] no content found in message for user \"uid:%d\"", uid)
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.ParseTextualResponse] no content found in message for user \"uid:%d\"", uid)
|
||||||
return nil, errs.ErrFailedToRequestRemoteApi
|
return nil, errs.ErrFailedToRequestRemoteApi
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,8 +105,8 @@ func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) ParseTextual
|
|||||||
return textualResponse, nil
|
return textualResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) buildJsonRequestBody(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) ([]byte, error) {
|
func (p *CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter) buildJsonRequestBody(c core.Context, uid int64, request *LargeLanguageModelRequest, responseType LargeLanguageModelResponseFormat) ([]byte, error) {
|
||||||
if p.provider.GetModelID() == "" {
|
if p.apiProvider.GetModelID() == "" {
|
||||||
return nil, errs.ErrInvalidLLMModelId
|
return nil, errs.ErrInvalidLLMModelId
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) buildJsonReq
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestBody := make(map[string]any)
|
requestBody := make(map[string]any)
|
||||||
requestBody["model"] = p.provider.GetModelID()
|
requestBody["model"] = p.apiProvider.GetModelID()
|
||||||
requestBody["stream"] = request.Stream
|
requestBody["stream"] = request.Stream
|
||||||
requestBody["messages"] = requestMessages
|
requestBody["messages"] = requestMessages
|
||||||
|
|
||||||
@@ -171,16 +171,16 @@ func (p *OpenAICommonChatCompletionsHttpLargeLanguageModelProvider) buildJsonReq
|
|||||||
requestBodyBytes, err := json.Marshal(requestBody)
|
requestBodyBytes, err := json.Marshal(requestBody)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf(c, "[openai_common_compatible_large_language_model_provider.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error())
|
log.Errorf(c, "[openai_common_compatible_large_language_model_adapter.buildJsonRequestBody] failed to marshal request body for user \"uid:%d\", because %s", uid, err.Error())
|
||||||
return nil, errs.ErrOperationFailed
|
return nil, errs.ErrOperationFailed
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf(c, "[openai_common_compatible_large_language_model_provider.buildJsonRequestBody] request body is %s", requestBodyBytes)
|
log.Debugf(c, "[openai_common_compatible_large_language_model_adapter.buildJsonRequestBody] request body is %s", requestBodyBytes)
|
||||||
return requestBodyBytes, nil
|
return requestBodyBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newOpenAICommonChatCompletionsHttpLargeLanguageModelProvider(provider OpenAIChatCompletionsLargeLanguageModelProvider) LargeLanguageModelProvider {
|
func newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(apiProvider OpenAIChatCompletionsAPIProvider) LargeLanguageModelProvider {
|
||||||
return newCommonHttpLargeLanguageModelProvider(&OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
return newCommonHttpLargeLanguageModelProvider(&CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
provider: provider,
|
apiProvider: apiProvider,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package llm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/mayswind/ezbookkeeping/pkg/core"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_buildJsonRequestBody_TextualUserPrompt(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{
|
||||||
|
OpenAIModelID: "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &LargeLanguageModelRequest{
|
||||||
|
SystemPrompt: "You are a helpful assistant.",
|
||||||
|
UserPrompt: []byte("Hello, how are you?"),
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
var body map[string]interface{}
|
||||||
|
err = json.Unmarshal(bodyBytes, &body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "{\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Hello, how are you?\",\"role\":\"user\"}],\"model\":\"test\",\"response_format\":{\"type\":\"json_object\"},\"stream\":false}", string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_buildJsonRequestBody_ImageUserPrompt(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{
|
||||||
|
OpenAIModelID: "test",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &LargeLanguageModelRequest{
|
||||||
|
SystemPrompt: "What's in this image?",
|
||||||
|
UserPrompt: []byte("fakedata"),
|
||||||
|
UserPromptType: LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, err := adapter.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
var body map[string]interface{}
|
||||||
|
err = json.Unmarshal(bodyBytes, &body)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "{\"messages\":[{\"content\":\"What's in this image?\",\"role\":\"system\"},{\"content\":[{\"image_url\":{\"url\":\"data:image/png;base64,ZmFrZWRhdGE=\"},\"type\":\"image_url\"}],\"role\":\"user\"}],\"model\":\"test\",\"response_format\":{\"type\":\"json_object\"},\"stream\":false}", string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_ParseTextualResponse_ValidJsonResponse(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := `{
|
||||||
|
"id": "test-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1234567890,
|
||||||
|
"model": "test",
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 13,
|
||||||
|
"completion_tokens": 7,
|
||||||
|
"total_tokens": 20
|
||||||
|
},
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "This is a test response"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "This is a test response", result.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_ParseTextualResponse_EmptyResponse(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := `{
|
||||||
|
"id": "test-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
result, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, "", result.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_ParseTextualResponse_EmptyChoices(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := `{
|
||||||
|
"id": "test-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"choices": []
|
||||||
|
}`
|
||||||
|
|
||||||
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_ParseTextualResponse_NoChoiceContent(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := `{
|
||||||
|
"id": "chatcmpl-123",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter_ParseTextualResponse_InvalidJson(t *testing.T) {
|
||||||
|
adapter := &CommonOpenAIChatCompletionsAPILargeLanguageModelAdapter{
|
||||||
|
apiProvider: &OpenAIOfficialChatCompletionsAPIProvider{},
|
||||||
|
}
|
||||||
|
|
||||||
|
response := "error"
|
||||||
|
|
||||||
|
_, err := adapter.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
||||||
|
assert.EqualError(t, err, "failed to request third party api")
|
||||||
|
}
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
package llm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
|
|
||||||
"github.com/mayswind/ezbookkeeping/pkg/core"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_buildJsonRequestBody_TextualUserPrompt(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{
|
|
||||||
OpenAIModelID: "test",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request := &LargeLanguageModelRequest{
|
|
||||||
SystemPrompt: "You are a helpful assistant.",
|
|
||||||
UserPrompt: []byte("Hello, how are you?"),
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, err := provider.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
var body map[string]interface{}
|
|
||||||
err = json.Unmarshal(bodyBytes, &body)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "{\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Hello, how are you?\",\"role\":\"user\"}],\"model\":\"test\",\"response_format\":{\"type\":\"json_object\"},\"stream\":false}", string(bodyBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_buildJsonRequestBody_ImageUserPrompt(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{
|
|
||||||
OpenAIModelID: "test",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
request := &LargeLanguageModelRequest{
|
|
||||||
SystemPrompt: "What's in this image?",
|
|
||||||
UserPrompt: []byte("fakedata"),
|
|
||||||
UserPromptType: LARGE_LANGUAGE_MODEL_REQUEST_PROMPT_TYPE_IMAGE_URL,
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyBytes, err := provider.buildJsonRequestBody(core.NewNullContext(), 0, request, LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
var body map[string]interface{}
|
|
||||||
err = json.Unmarshal(bodyBytes, &body)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "{\"messages\":[{\"content\":\"What's in this image?\",\"role\":\"system\"},{\"content\":[{\"image_url\":{\"url\":\"data:image/png;base64,ZmFrZWRhdGE=\"},\"type\":\"image_url\"}],\"role\":\"user\"}],\"model\":\"test\",\"response_format\":{\"type\":\"json_object\"},\"stream\":false}", string(bodyBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_ParseTextualResponse_ValidJsonResponse(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{},
|
|
||||||
}
|
|
||||||
|
|
||||||
response := `{
|
|
||||||
"id": "test-123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1234567890,
|
|
||||||
"model": "test",
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 13,
|
|
||||||
"completion_tokens": 7,
|
|
||||||
"total_tokens": 20
|
|
||||||
},
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "This is a test response"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`
|
|
||||||
|
|
||||||
result, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "This is a test response", result.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_ParseTextualResponse_EmptyResponse(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{},
|
|
||||||
}
|
|
||||||
|
|
||||||
response := `{
|
|
||||||
"id": "test-123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`
|
|
||||||
|
|
||||||
result, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
assert.Equal(t, "", result.Content)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_ParseTextualResponse_EmptyChoices(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{},
|
|
||||||
}
|
|
||||||
|
|
||||||
response := `{
|
|
||||||
"id": "test-123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"choices": []
|
|
||||||
}`
|
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_ParseTextualResponse_NoChoiceContent(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{},
|
|
||||||
}
|
|
||||||
|
|
||||||
response := `{
|
|
||||||
"id": "chatcmpl-123",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`
|
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenAICommonChatCompletionsHttpLargeLanguageModelProvider_ParseTextualResponse_InvalidJson(t *testing.T) {
|
|
||||||
provider := &OpenAICommonChatCompletionsHttpLargeLanguageModelProvider{
|
|
||||||
provider: &OpenAILargeLanguageModelProvider{},
|
|
||||||
}
|
|
||||||
|
|
||||||
response := "error"
|
|
||||||
|
|
||||||
_, err := provider.ParseTextualResponse(core.NewNullContext(), 0, []byte(response), LARGE_LANGUAGE_MODEL_RESPONSE_FORMAT_JSON)
|
|
||||||
assert.EqualError(t, err, "failed to request third party api")
|
|
||||||
}
|
|
||||||
+9
-9
@@ -9,16 +9,16 @@ import (
|
|||||||
|
|
||||||
const openAICompatibleChatCompletionsPath = "chat/completions"
|
const openAICompatibleChatCompletionsPath = "chat/completions"
|
||||||
|
|
||||||
// OpenAICompatibleLargeLanguageModelProvider defines the structure of OpenAI compatible large language model provider
|
// OpenAICompatibleChatCompletionsAPIProvider defines the structure of OpenAI compatible chat completions API provider
|
||||||
type OpenAICompatibleLargeLanguageModelProvider struct {
|
type OpenAICompatibleChatCompletionsAPIProvider struct {
|
||||||
OpenAIChatCompletionsLargeLanguageModelProvider
|
OpenAIChatCompletionsAPIProvider
|
||||||
OpenAICompatibleBaseURL string
|
OpenAICompatibleBaseURL string
|
||||||
OpenAICompatibleAPIKey string
|
OpenAICompatibleAPIKey string
|
||||||
OpenAICompatibleModelID string
|
OpenAICompatibleModelID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenAI compatible provider
|
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenAI compatible chat completions API provider
|
||||||
func (p *OpenAICompatibleLargeLanguageModelProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
func (p *OpenAICompatibleChatCompletionsAPIProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
||||||
req, err := http.NewRequest("POST", p.getFinalChatCompletionsRequestUrl(), nil)
|
req, err := http.NewRequest("POST", p.getFinalChatCompletionsRequestUrl(), nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -32,12 +32,12 @@ func (p *OpenAICompatibleLargeLanguageModelProvider) BuildChatCompletionsHttpReq
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelID returns the model id of OpenAI compatible provider
|
// GetModelID returns the model id of OpenAI compatible chat completions API provider
|
||||||
func (p *OpenAICompatibleLargeLanguageModelProvider) GetModelID() string {
|
func (p *OpenAICompatibleChatCompletionsAPIProvider) GetModelID() string {
|
||||||
return p.OpenAICompatibleModelID
|
return p.OpenAICompatibleModelID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAICompatibleLargeLanguageModelProvider) getFinalChatCompletionsRequestUrl() string {
|
func (p *OpenAICompatibleChatCompletionsAPIProvider) getFinalChatCompletionsRequestUrl() string {
|
||||||
url := p.OpenAICompatibleBaseURL
|
url := p.OpenAICompatibleBaseURL
|
||||||
|
|
||||||
if url[len(url)-1] != '/' {
|
if url[len(url)-1] != '/' {
|
||||||
@@ -50,7 +50,7 @@ func (p *OpenAICompatibleLargeLanguageModelProvider) getFinalChatCompletionsRequ
|
|||||||
|
|
||||||
// NewOpenAICompatibleLargeLanguageModelProvider creates a new OpenAI compatible large language model provider instance
|
// NewOpenAICompatibleLargeLanguageModelProvider creates a new OpenAI compatible large language model provider instance
|
||||||
func NewOpenAICompatibleLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
func NewOpenAICompatibleLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
||||||
return newOpenAICommonChatCompletionsHttpLargeLanguageModelProvider(&OpenAICompatibleLargeLanguageModelProvider{
|
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenAICompatibleChatCompletionsAPIProvider{
|
||||||
OpenAICompatibleBaseURL: llmConfig.OpenAICompatibleBaseURL,
|
OpenAICompatibleBaseURL: llmConfig.OpenAICompatibleBaseURL,
|
||||||
OpenAICompatibleAPIKey: llmConfig.OpenAICompatibleAPIKey,
|
OpenAICompatibleAPIKey: llmConfig.OpenAICompatibleAPIKey,
|
||||||
OpenAICompatibleModelID: llmConfig.OpenAICompatibleModelID,
|
OpenAICompatibleModelID: llmConfig.OpenAICompatibleModelID,
|
||||||
+7
-7
@@ -6,22 +6,22 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOpenAICompatibleLargeLanguageModelProvider_GetFinalRequestUrl(t *testing.T) {
|
func TestOpenAICompatibleChatCompletionsAPIProvider_GetFinalRequestUrl(t *testing.T) {
|
||||||
provider := &OpenAICompatibleLargeLanguageModelProvider{
|
apiProvider := &OpenAICompatibleChatCompletionsAPIProvider{
|
||||||
OpenAICompatibleBaseURL: "https://api.example.com/v1/",
|
OpenAICompatibleBaseURL: "https://api.example.com/v1/",
|
||||||
}
|
}
|
||||||
url := provider.getFinalChatCompletionsRequestUrl()
|
url := apiProvider.getFinalChatCompletionsRequestUrl()
|
||||||
assert.Equal(t, "https://api.example.com/v1/chat/completions", url)
|
assert.Equal(t, "https://api.example.com/v1/chat/completions", url)
|
||||||
|
|
||||||
provider = &OpenAICompatibleLargeLanguageModelProvider{
|
apiProvider = &OpenAICompatibleChatCompletionsAPIProvider{
|
||||||
OpenAICompatibleBaseURL: "https://api.example.com/v1",
|
OpenAICompatibleBaseURL: "https://api.example.com/v1",
|
||||||
}
|
}
|
||||||
url = provider.getFinalChatCompletionsRequestUrl()
|
url = apiProvider.getFinalChatCompletionsRequestUrl()
|
||||||
assert.Equal(t, "https://api.example.com/v1/chat/completions", url)
|
assert.Equal(t, "https://api.example.com/v1/chat/completions", url)
|
||||||
|
|
||||||
provider = &OpenAICompatibleLargeLanguageModelProvider{
|
apiProvider = &OpenAICompatibleChatCompletionsAPIProvider{
|
||||||
OpenAICompatibleBaseURL: "https://example.com/api",
|
OpenAICompatibleBaseURL: "https://example.com/api",
|
||||||
}
|
}
|
||||||
url = provider.getFinalChatCompletionsRequestUrl()
|
url = apiProvider.getFinalChatCompletionsRequestUrl()
|
||||||
assert.Equal(t, "https://example.com/api/chat/completions", url)
|
assert.Equal(t, "https://example.com/api/chat/completions", url)
|
||||||
}
|
}
|
||||||
+8
-8
@@ -7,17 +7,17 @@ import (
|
|||||||
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenRouterLargeLanguageModelProvider defines the structure of OpenRouter large language model provider
|
// OpenRouterChatCompletionsAPIProvider defines the structure of OpenRouter chat completions API provider
|
||||||
type OpenRouterLargeLanguageModelProvider struct {
|
type OpenRouterChatCompletionsAPIProvider struct {
|
||||||
OpenAIChatCompletionsLargeLanguageModelProvider
|
OpenAIChatCompletionsAPIProvider
|
||||||
OpenRouterAPIKey string
|
OpenRouterAPIKey string
|
||||||
OpenRouterModelID string
|
OpenRouterModelID string
|
||||||
}
|
}
|
||||||
|
|
||||||
const openRouterChatCompletionsUrl = "https://openrouter.ai/api/v1/chat/completions"
|
const openRouterChatCompletionsUrl = "https://openrouter.ai/api/v1/chat/completions"
|
||||||
|
|
||||||
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenRouter provider
|
// BuildChatCompletionsHttpRequest returns the chat completions http request by OpenRouter chat completions API provider
|
||||||
func (p *OpenRouterLargeLanguageModelProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
func (p *OpenRouterChatCompletionsAPIProvider) BuildChatCompletionsHttpRequest(c core.Context, uid int64) (*http.Request, error) {
|
||||||
req, err := http.NewRequest("POST", openRouterChatCompletionsUrl, nil)
|
req, err := http.NewRequest("POST", openRouterChatCompletionsUrl, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -31,14 +31,14 @@ func (p *OpenRouterLargeLanguageModelProvider) BuildChatCompletionsHttpRequest(c
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelID returns the model id of OpenRouter provider
|
// GetModelID returns the model id of OpenRouter chat completions API provider
|
||||||
func (p *OpenRouterLargeLanguageModelProvider) GetModelID() string {
|
func (p *OpenRouterChatCompletionsAPIProvider) GetModelID() string {
|
||||||
return p.OpenRouterModelID
|
return p.OpenRouterModelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenRouterLargeLanguageModelProvider creates a new OpenRouter large language model provider instance
|
// NewOpenRouterLargeLanguageModelProvider creates a new OpenRouter large language model provider instance
|
||||||
func NewOpenRouterLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
func NewOpenRouterLargeLanguageModelProvider(llmConfig *settings.LLMConfig) LargeLanguageModelProvider {
|
||||||
return newOpenAICommonChatCompletionsHttpLargeLanguageModelProvider(&OpenRouterLargeLanguageModelProvider{
|
return newCommonOpenAIChatCompletionsAPILargeLanguageModelAdapter(&OpenRouterChatCompletionsAPIProvider{
|
||||||
OpenRouterAPIKey: llmConfig.OpenRouterAPIKey,
|
OpenRouterAPIKey: llmConfig.OpenRouterAPIKey,
|
||||||
OpenRouterModelID: llmConfig.OpenRouterModelID,
|
OpenRouterModelID: llmConfig.OpenRouterModelID,
|
||||||
})
|
})
|
||||||
Reference in New Issue
Block a user