149 lines
3.8 KiB
Go
149 lines
3.8 KiB
Go
package datastore
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/lib/pq"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"xorm.io/xorm"
|
|
|
|
"github.com/mayswind/ezbookkeeping/pkg/errs"
|
|
"github.com/mayswind/ezbookkeeping/pkg/settings"
|
|
)
|
|
|
|
// DataStoreContainer contains all data storages
|
|
type DataStoreContainer struct {
|
|
UserStore *DataStore
|
|
TokenStore *DataStore
|
|
UserDataStore *DataStore
|
|
}
|
|
|
|
// Initialize a data storage container singleton instance
|
|
var (
|
|
Container = &DataStoreContainer{}
|
|
)
|
|
|
|
// InitializeDataStore initializes data storage according to the config
|
|
func InitializeDataStore(config *settings.Config) error {
|
|
database, err := initializeDatabase(config.DatabaseConfig)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
setDatabaseLogger(database, config)
|
|
|
|
Container.UserStore, err = NewDataStore(database)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
Container.TokenStore, err = NewDataStore(database)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
Container.UserDataStore, err = NewDataStore(database)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) {
|
|
var connStr string
|
|
var err error
|
|
|
|
if dbConfig.DatabaseType == settings.Sqlite3DbType {
|
|
if _, err = os.Stat(dbConfig.DatabasePath); err != nil {
|
|
file, err := os.Create(dbConfig.DatabasePath)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
defer file.Close()
|
|
}
|
|
}
|
|
|
|
if dbConfig.DatabaseType == settings.MySqlDbType {
|
|
connStr, err = getMysqlConnectionString(dbConfig)
|
|
} else if dbConfig.DatabaseType == settings.PostgresDbType {
|
|
connStr, err = getPostgresConnectionString(dbConfig)
|
|
} else if dbConfig.DatabaseType == settings.Sqlite3DbType {
|
|
connStr, err = getSqlite3ConnectionString(dbConfig)
|
|
} else {
|
|
return nil, errs.ErrDatabaseTypeInvalid
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
connStrs := []string{
|
|
connStr,
|
|
}
|
|
engineGroup, err := xorm.NewEngineGroup(dbConfig.DatabaseType, connStrs, xorm.RoundRobinPolicy())
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
engineGroup.SetMaxIdleConns(int(dbConfig.MaxIdleConnection))
|
|
engineGroup.SetMaxOpenConns(int(dbConfig.MaxOpenConnection))
|
|
engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second)
|
|
|
|
return &Database{
|
|
databaseType: dbConfig.DatabaseType,
|
|
engineGroup: engineGroup,
|
|
}, nil
|
|
}
|
|
|
|
func setDatabaseLogger(database *Database, config *settings.Config) {
|
|
if config.EnableQueryLog {
|
|
database.engineGroup.SetLogger(NewXOrmLoggerAdapter(config.EnableQueryLog, config.LogLevel))
|
|
database.engineGroup.ShowSQL(true)
|
|
}
|
|
}
|
|
|
|
func getMysqlConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
|
protocol := "tcp"
|
|
|
|
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
|
|
protocol = "unix"
|
|
}
|
|
|
|
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=true",
|
|
dbConfig.DatabaseUser, dbConfig.DatabasePassword, protocol, dbConfig.DatabaseHost, dbConfig.DatabaseName), nil
|
|
}
|
|
|
|
func getPostgresConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
|
host, port, err := net.SplitHostPort(dbConfig.DatabaseHost)
|
|
|
|
if err != nil {
|
|
return "", errs.ErrDatabaseHostInvalid
|
|
}
|
|
|
|
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
|
|
return fmt.Sprintf("postgres://%s:%s@:%s/%s?sslmode=%s&host=%s",
|
|
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode, host), nil
|
|
} else {
|
|
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
|
|
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), host, port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode), nil
|
|
}
|
|
}
|
|
|
|
func getSqlite3ConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
|
return fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbConfig.DatabasePath), nil
|
|
}
|