add dao files

This commit is contained in:
MaysWind
2020-10-17 19:34:51 +08:00
parent 09e98bfc29
commit a7df339f47
6 changed files with 356 additions and 0 deletions
+27
View File
@@ -0,0 +1,27 @@
package datastore
import "xorm.io/xorm"
type Database struct {
*xorm.EngineGroup
}
func (db *Database) DoTranscation(fn func(sess *xorm.Session) error) (err error) {
sess := db.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
}
if err = fn(sess); err != nil {
_ = sess.Rollback()
return err
}
if err = sess.Commit(); err != nil {
return err
}
return nil
}
+47
View File
@@ -0,0 +1,47 @@
package datastore
import (
"xorm.io/xorm"
"github.com/mayswind/lab/pkg/errs"
)
type DataStore struct {
databases []*Database
}
func (s *DataStore) Choose(key int64) *Database {
return s.databases[0]
}
func (s *DataStore) Query(key int64) *xorm.Session {
return s.Choose(key).NewSession()
}
func (s *DataStore) DoTranscation(key int64, fn func(sess *xorm.Session) error) (err error) {
return s.Choose(key).DoTranscation(fn)
}
func (s *DataStore) SyncStructs(beans... interface{}) error {
var err error
for i := 0; i < len(s.databases); i++ {
err = s.databases[i].Sync2(beans...)
if err != nil {
return err
}
}
return err
}
func NewDataStore(databases... *Database) (*DataStore, error) {
if len(databases) < 1 {
return nil, errs.ErrDatabaseIsNull
}
return &DataStore{
databases: databases,
}, nil
}
+134
View File
@@ -0,0 +1,134 @@
package datastore
import (
"fmt"
"net/url"
"strings"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"xorm.io/xorm"
"github.com/mayswind/lab/pkg/errs"
"github.com/mayswind/lab/pkg/settings"
)
type DataStoreContainer struct {
UserStore *DataStore
TokenStore *DataStore
UserDataStore *DataStore
}
var (
Container = &DataStoreContainer{}
)
func InitializeDataStore(config *settings.Config) error {
database, err := initializeDatabase(config.DatabaseConfig)
if err != nil {
return err
}
setDatabaseLogger(database, config)
Container.UserStore, err = NewDataStore(database)
if err != nil {
return err
}
Container.TokenStore, err = NewDataStore(database)
if err != nil {
return err
}
Container.UserDataStore, err = NewDataStore(database)
if err != nil {
return err
}
return nil
}
func initializeDatabase(dbConfig *settings.DatabaseConfig) (*Database, error) {
var connStr string
var err error
if dbConfig.DatabaseType == settings.DBTYPE_MYSQL {
connStr, err = getMysqlConnectionString(dbConfig)
} else if dbConfig.DatabaseType == settings.DBTYPE_POSTGRES {
connStr, err = getPostgresConnectionString(dbConfig)
} else if dbConfig.DatabaseType == settings.DBTYPE_SQLITE3 {
connStr, err = getSqlite3ConnectionString(dbConfig)
} else {
return nil, errs.ErrDatabaseTypeInvalid
}
if err != nil {
return nil, err
}
connStrs := []string{
connStr,
}
engineGroup, err := xorm.NewEngineGroup(dbConfig.DatabaseType, connStrs, xorm.RoundRobinPolicy())
if err != nil {
return nil, err
}
engineGroup.SetMaxIdleConns(dbConfig.MaxIdleConnection)
engineGroup.SetMaxOpenConns(dbConfig.MaxOpenConnection)
engineGroup.SetConnMaxLifetime(time.Duration(dbConfig.ConnectionMaxLifeTime) * time.Second)
return &Database{
EngineGroup: engineGroup,
}, nil
}
func setDatabaseLogger(database *Database, config *settings.Config) {
if config.EnableQueryLog {
database.SetLogger(NewXOrmLoggerAdapter(config.EnableQueryLog, config.LogLevel))
database.ShowSQL(true)
}
}
func getMysqlConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
protocol := "tcp"
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
protocol = "unix"
}
return fmt.Sprintf("%s:%s@%s(%s)/%s?charset=utf8mb4&parseTime=true",
dbConfig.DatabaseUser, dbConfig.DatabasePassword, protocol, dbConfig.DatabaseHost, dbConfig.DatabaseName), nil
}
func getPostgresConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
host, port := "", ""
fields := strings.Split(dbConfig.DatabaseHost, ":")
if len(fields) != 2 {
return "", errs.ErrDatabaseHostInvalid
}
host = strings.TrimSpace(fields[0])
port = strings.TrimSpace(fields[1])
if strings.HasPrefix(dbConfig.DatabaseHost, "/") { // unix socket path
return fmt.Sprintf("postgres://%s:%s@:%s/%s?sslmode=%s&host=%s",
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode, host), nil
} else {
return fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s",
url.QueryEscape(dbConfig.DatabaseUser), url.QueryEscape(dbConfig.DatabasePassword), host, port, dbConfig.DatabaseName, dbConfig.DatabaseSSLMode), nil
}
}
func getSqlite3ConnectionString(dbConfig *settings.DatabaseConfig) (string, error) {
return fmt.Sprintf("file:%s?cache=shared&mode=rwc", dbConfig.DatabasePath), nil
}
+88
View File
@@ -0,0 +1,88 @@
package datastore
import (
xorm "xorm.io/xorm/log"
"github.com/mayswind/lab/pkg/log"
"github.com/mayswind/lab/pkg/settings"
)
type XOrmLoggerAdapter struct {
enable bool
logLevel settings.Level
}
func (logger XOrmLoggerAdapter) Debug(v ...interface{}) {
log.SqlQuery(v...)
}
func (logger XOrmLoggerAdapter) Debugf(format string, v ...interface{}) {
log.SqlQueryf(format, v...)
}
func (logger XOrmLoggerAdapter) Info(v ...interface{}) {
log.SqlQuery(v...)
}
func (logger XOrmLoggerAdapter) Infof(format string, v ...interface{}) {
log.SqlQueryf(format, v...)
}
func (logger XOrmLoggerAdapter) Warn(v ...interface{}) {
log.SqlQuery(v...)
}
func (logger XOrmLoggerAdapter) Warnf(format string, v ...interface{}) {
log.SqlQueryf(format, v...)
}
func (logger XOrmLoggerAdapter) Error(v ...interface{}) {
log.SqlQuery(v...)
}
func (logger XOrmLoggerAdapter) Errorf(format string, v ...interface{}) {
log.SqlQueryf(format, v...)
}
func (logger XOrmLoggerAdapter) Level() xorm.LogLevel {
if logger.logLevel == settings.LOGLEVEL_DEBUG {
return xorm.LOG_DEBUG
} else if logger.logLevel == settings.LOGLEVEL_INFO {
return xorm.LOG_INFO
} else if logger.logLevel == settings.LOGLEVEL_WARN {
return xorm.LOG_WARNING
} else if logger.logLevel == settings.LOGLEVEL_ERROR {
return xorm.LOG_ERR
}
return xorm.LOG_INFO
}
func (logger XOrmLoggerAdapter) SetLevel(l xorm.LogLevel) {
if l == xorm.LOG_DEBUG {
logger.logLevel = settings.LOGLEVEL_DEBUG
} else if l == xorm.LOG_INFO {
logger.logLevel = settings.LOGLEVEL_INFO
} else if l == xorm.LOG_WARNING {
logger.logLevel = settings.LOGLEVEL_WARN
} else if l == xorm.LOG_ERR {
logger.logLevel = settings.LOGLEVEL_ERROR
}
logger.logLevel = settings.LOGLEVEL_INFO
}
func (logger XOrmLoggerAdapter) ShowSQL(show ...bool) {
logger.enable = len(show) > 0 && show[0]
}
func (logger XOrmLoggerAdapter) IsShowSQL() bool {
return logger.enable
}
func NewXOrmLoggerAdapter(showSql bool, logLevel settings.Level) xorm.Logger {
return XOrmLoggerAdapter{
enable: showSql,
logLevel: logLevel,
}
}