add dao files
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
package datastore
|
||||
|
||||
import "xorm.io/xorm"
|
||||
|
||||
type Database struct {
|
||||
*xorm.EngineGroup
|
||||
}
|
||||
|
||||
func (db *Database) DoTranscation(fn func(sess *xorm.Session) error) (err error) {
|
||||
sess := db.NewSession()
|
||||
defer sess.Close()
|
||||
|
||||
if err = sess.Begin(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = fn(sess); err != nil {
|
||||
_ = sess.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if err = sess.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/mayswind/lab/pkg/errs"
|
||||
)
|
||||
|
||||
type DataStore struct {
|
||||
databases []*Database
|
||||
}
|
||||
|
||||
func (s *DataStore) Choose(key int64) *Database {
|
||||
return s.databases[0]
|
||||
}
|
||||
|
||||
func (s *DataStore) Query(key int64) *xorm.Session {
|
||||
return s.Choose(key).NewSession()
|
||||
}
|
||||
|
||||
func (s *DataStore) DoTranscation(key int64, fn func(sess *xorm.Session) error) (err error) {
|
||||
return s.Choose(key).DoTranscation(fn)
|
||||
}
|
||||
|
||||
func (s *DataStore) SyncStructs(beans... interface{}) error {
|
||||
var err error
|
||||
|
||||
for i := 0; i < len(s.databases); i++ {
|
||||
err = s.databases[i].Sync2(beans...)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func NewDataStore(databases... *Database) (*DataStore, error) {
|
||||
if len(databases) < 1 {
|
||||
return nil, errs.ErrDatabaseIsNull
|
||||
}
|
||||
|
||||
return &DataStore{
|
||||
databases: databases,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/mayswind/lab/pkg/errs"
|
||||
"github.com/mayswind/lab/pkg/settings"
|
||||
)
|
||||
|
||||
type DataStoreContainer struct {
|
||||
UserStore *DataStore
|
||||
TokenStore *DataStore
|
||||
UserDataStore *DataStore
|
||||
}
|
||||
|
||||
var (
|
||||
Container = &DataStoreContainer{}
|
||||
)
|
||||
|
||||
func InitializeDataStore(config *settings.Config) error {
|
||||
database, err := initializeDatabase(config.DatabaseConfig)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
setDatabaseLogger(database, config)
|
||||
|
||||
Container.UserStore, err = NewDataStore(database)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Container.TokenStore, err = NewDataStore(database)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Container.UserDataStore, err = NewDataStore(database)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) {
|
||||
var connStr string
|
||||
var err error
|
||||
|
||||
if dbConfig.DatabaseType == settings.DBTYPE_MYSQL {
|
||||
connStr, err = getMysqlConnectionString(dbConfig)
|
||||
} else if dbConfig.DatabaseType == settings.DBTYPE_POSTGRES {
|
||||
connStr, err = getPostgresConnectionString(dbConfig)
|
||||
} else if dbConfig.DatabaseType == settings.DBTYPE_SQLITE3 {
|
||||
connStr, err = getSqlite3ConnectionString(dbConfig)
|
||||
} else {
|
||||
return nil, errs.ErrDatabaseTypeInvalid
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connStrs := []string{
|
||||
connStr,
|
||||
}
|
||||
engineGroup, err := xorm.NewEngineGroup(dbConfig.DatabaseType, connStrs, xorm.RoundRobinPolicy())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
engineGroup.SetMaxIdleConns(dbConfig.MaxIdleConnection)
|
||||
engineGroup.SetMaxOpenConns(dbConfig.MaxOpenConnection)
|
||||
engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second)
|
||||
|
||||
return &Database{
|
||||
EngineGroup: engineGroup,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setDatabaseLogger(database *Database, config *settings.Config) {
|
||||
if config.EnableQueryLog {
|
||||
database.SetLogger(NewXOrmLoggerAdapter(config.EnableQueryLog, config.LogLevel))
|
||||
database.ShowSQL(true)
|
||||
}
|
||||
}
|
||||
|
||||
func getMysqlConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
||||
protocol := "tcp"
|
||||
|
||||
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
|
||||
protocol = "unix"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=true",
|
||||
dbConfig.DatabaseUser, dbConfig.DatabasePassword, protocol, dbConfig.DatabaseHost, dbConfig.DatabaseName), nil
|
||||
}
|
||||
|
||||
func getPostgresConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
||||
host, port := "", ""
|
||||
fields := strings.Split(dbConfig.DatabaseHost, ":")
|
||||
|
||||
if len(fields) != 2 {
|
||||
return "", errs.ErrDatabaseHostInvalid
|
||||
}
|
||||
|
||||
host = strings.TrimSpace(fields[0])
|
||||
port = strings.TrimSpace(fields[1])
|
||||
|
||||
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
|
||||
return fmt.Sprintf("postgres://%s:%s@:%s/%s?sslmode=%s&host=%s",
|
||||
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode, host), nil
|
||||
} else {
|
||||
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
|
||||
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), host, port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode), nil
|
||||
}
|
||||
}
|
||||
|
||||
func getSqlite3ConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
|
||||
return fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbConfig.DatabasePath), nil
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
xorm "xorm.io/xorm/log"
|
||||
|
||||
"github.com/mayswind/lab/pkg/log"
|
||||
"github.com/mayswind/lab/pkg/settings"
|
||||
)
|
||||
|
||||
type XOrmLoggerAdapter struct {
|
||||
enable bool
|
||||
logLevel settings.Level
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Debug(v ...interface{}) {
|
||||
log.SqlQuery(v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Debugf(format string, v ...interface{}) {
|
||||
log.SqlQueryf(format, v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Info(v ...interface{}) {
|
||||
log.SqlQuery(v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Infof(format string, v ...interface{}) {
|
||||
log.SqlQueryf(format, v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Warn(v ...interface{}) {
|
||||
log.SqlQuery(v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Warnf(format string, v ...interface{}) {
|
||||
log.SqlQueryf(format, v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Error(v ...interface{}) {
|
||||
log.SqlQuery(v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Errorf(format string, v ...interface{}) {
|
||||
log.SqlQueryf(format, v...)
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) Level() xorm.LogLevel {
|
||||
if logger.logLevel == settings.LOGLEVEL_DEBUG {
|
||||
return xorm.LOG_DEBUG
|
||||
} else if logger.logLevel == settings.LOGLEVEL_INFO {
|
||||
return xorm.LOG_INFO
|
||||
} else if logger.logLevel == settings.LOGLEVEL_WARN {
|
||||
return xorm.LOG_WARNING
|
||||
} else if logger.logLevel == settings.LOGLEVEL_ERROR {
|
||||
return xorm.LOG_ERR
|
||||
}
|
||||
|
||||
return xorm.LOG_INFO
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) SetLevel(l xorm.LogLevel) {
|
||||
if l == xorm.LOG_DEBUG {
|
||||
logger.logLevel = settings.LOGLEVEL_DEBUG
|
||||
} else if l == xorm.LOG_INFO {
|
||||
logger.logLevel = settings.LOGLEVEL_INFO
|
||||
} else if l == xorm.LOG_WARNING {
|
||||
logger.logLevel = settings.LOGLEVEL_WARN
|
||||
} else if l == xorm.LOG_ERR {
|
||||
logger.logLevel = settings.LOGLEVEL_ERROR
|
||||
}
|
||||
|
||||
logger.logLevel = settings.LOGLEVEL_INFO
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) ShowSQL(show ...bool) {
|
||||
logger.enable = len(show) > 0 && show[0]
|
||||
}
|
||||
|
||||
func (logger XOrmLoggerAdapter) IsShowSQL() bool {
|
||||
return logger.enable
|
||||
}
|
||||
|
||||
func NewXOrmLoggerAdapter(showSql bool, logLevel settings.Level) xorm.Logger {
|
||||
return XOrmLoggerAdapter{
|
||||
enable: showSql,
|
||||
logLevel: logLevel,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user