diff --git a/pkg/datastore/datastore_container.go b/pkg/datastore/datastore_container.go index e025ac86..dbb96c66 100644 --- a/pkg/datastore/datastore_container.go +++ b/pkg/datastore/datastore_container.go @@ -128,16 +128,16 @@ func getMysqlConnectionString(dbConfig *settings.DatabaseConfig) (string, error) } func getPostgresConnectionString(dbConfig *settings.DatabaseConfig) (string, error) { - host, port, err := net.SplitHostPort(dbConfig.DatabaseHost) - - if err != nil { - return "", errs.ErrDatabaseHostInvalid - } - if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path - return fmt.Sprintf("postgres://%s:%s@:%s/%s?sslmode=%s&host=%s", - url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode, host), nil + return fmt.Sprintf("postgres:///%s?sslmode=%s&host=%s&user=%s&password=%s", + dbConfig.DatabaseName, dbConfig.DatabaseSSLMode, dbConfig.DatabaseHost, url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword)), nil } else { + host, port, err := net.SplitHostPort(dbConfig.DatabaseHost) + + if err != nil { + return "", errs.ErrDatabaseHostInvalid + } + return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), host, port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode), nil } diff --git a/pkg/datastore/datastore_container_test.go b/pkg/datastore/datastore_container_test.go new file mode 100644 index 00000000..d39e2628 --- /dev/null +++ b/pkg/datastore/datastore_container_test.go @@ -0,0 +1,67 @@ +package datastore + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/mayswind/ezbookkeeping/pkg/settings" +) + +func TestGetMysqlConnectionString_TCP(t *testing.T) { + expectedValue := "username:password@tcp(1.2.3.4:3306)/dbname?charset=utf8mb4&parseTime=true" + actualValue, err := getMysqlConnectionString(&settings.DatabaseConfig{ + DatabaseType: "mysql", + DatabaseHost: "1.2.3.4:3306", + DatabaseName: "dbname", + DatabaseUser: "username", + DatabasePassword: "password", + }) + + assert.Nil(t, err) + assert.Equal(t, expectedValue, actualValue) +} + +func TestGetMysqlConnectionString_UnixSocket(t *testing.T) { + expectedValue := "username:password@unix(/path/to/mysql.sock)/dbname?charset=utf8mb4&parseTime=true" + actualValue, err := getMysqlConnectionString(&settings.DatabaseConfig{ + DatabaseType: "mysql", + DatabaseHost: "/path/to/mysql.sock", + DatabaseName: "dbname", + DatabaseUser: "username", + DatabasePassword: "password", + }) + + assert.Nil(t, err) + assert.Equal(t, expectedValue, actualValue) +} + +func TestGetPostgreSQLConnectionString_TCP(t *testing.T) { + expectedValue := "postgres://username:password@1.2.3.4:5432/dbname?sslmode=disable" + actualValue, err := getPostgresConnectionString(&settings.DatabaseConfig{ + DatabaseType: "postgres", + DatabaseHost: "1.2.3.4:5432", + DatabaseName: "dbname", + DatabaseUser: "username", + DatabasePassword: "password", + DatabaseSSLMode: "disable", + }) + + assert.Nil(t, err) + assert.Equal(t, expectedValue, actualValue) +} + +func TestGetPostgreSQLConnectionString_UnixSocket(t *testing.T) { + expectedValue := "postgres:///dbname?sslmode=disable&host=/path/to/postgres.sock&user=username&password=password" + actualValue, err := getPostgresConnectionString(&settings.DatabaseConfig{ + DatabaseType: "postgres", + DatabaseHost: "/path/to/postgres.sock", + DatabaseName: "dbname", + DatabaseUser: "username", + DatabasePassword: "password", + DatabaseSSLMode: "disable", + }) + + assert.Nil(t, err) + assert.Equal(t, expectedValue, actualValue) +}