164 lines
4.3 KiB
Go
164 lines
4.3 KiB
Go
package db
|
||
|
||
import (
|
||
"admin/internal/errcode"
|
||
"admin/internal/global"
|
||
"admin/lib/xlog"
|
||
"fmt"
|
||
mysqlDriver "github.com/go-sql-driver/mysql"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/driver/sqlite"
|
||
"gorm.io/gorm"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// globalTables accumulates every model passed to RegisterTableModels; NewDB hands
// this list to auto-migration. RegisterTableModels mutates it while holding locker.
var globalTables []any

// locker serializes mutations of globalTables.
var locker sync.Mutex

// RegisterTableModels appends models to the package-level registry that NewDB
// auto-migrates. Concurrent calls are serialized by locker.
func RegisterTableModels(models ...any) {
	locker.Lock()
	globalTables = append(globalTables, models...)
	locker.Unlock()
}
|
||
|
||
func NewDB(dbType, dbAddr, dbName, dbUser, dbPass string) (db *gorm.DB, err error) {
|
||
dsn := fmt.Sprintf("%v:%v@tcp(%v)/%v?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr, dbName)
|
||
dsnWithoutDB := fmt.Sprintf("%v:%v@tcp(%v)/?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr)
|
||
db, err = createDBAndGuaranteeMigrate(dbType, dsnWithoutDB, dsn, globalTables)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
global.GLOB_DB = db
|
||
return db, nil
|
||
}
|
||
|
||
func createDBAndGuaranteeMigrate(dbType string, dsnWithoutDb, dsn string, tables []any) (*gorm.DB, error) {
|
||
mysqlDriverConf, err := mysqlDriver.ParseDSN(dsn)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("parse dsn:%v error:%v", dsn, err)
|
||
}
|
||
dbName := mysqlDriverConf.DBName
|
||
|
||
var dialector gorm.Dialector
|
||
switch dbType {
|
||
case "sqlite":
|
||
dialector = sqlite.Open(dbName)
|
||
case "mysql":
|
||
dialector = mysql.Open(dsnWithoutDb)
|
||
default:
|
||
panic(fmt.Errorf("unsupported db type: %v", dbType))
|
||
}
|
||
|
||
_, err = tryCreateDB(dbType, dialector, dbName)
|
||
if err != nil {
|
||
xlog.Fatalf(err)
|
||
return nil, err
|
||
}
|
||
|
||
switch dbType {
|
||
case "sqlite":
|
||
dialector = sqlite.Open(dbName)
|
||
case "mysql":
|
||
dialector = mysql.Open(dsn)
|
||
default:
|
||
panic(fmt.Errorf("unsupported db type: %v", dbType))
|
||
}
|
||
|
||
//driverConf := mysql.Config{
|
||
// DSN: dsn,
|
||
// DontSupportRenameColumn: true,
|
||
// //SkipInitializeWithVersion: false, // 根据数据库版本自动配置
|
||
//}
|
||
//dialector = mysql.New(driverConf)
|
||
|
||
//slowLogger := logger.New(
|
||
// syslog.New(xlog.GetGlobalWriter(), "\n", syslog.LstdFlags),
|
||
// logger.Config{
|
||
// // 设定慢查询时间阈值为 默认值:200 * time.Millisecond
|
||
// SlowThreshold: 200 * time.Millisecond,
|
||
// // 设置日志级别
|
||
// LogLevel: logger.Warn,
|
||
// Colorful: true,
|
||
// },
|
||
//)
|
||
|
||
db, err := gorm.Open(dialector, &gorm.Config{
|
||
Logger: &gormLogger{},
|
||
PrepareStmt: false, // 关闭缓存sql语句功能,因为后续use db会报错,这个缓存会无限存储可能导致内存泄露
|
||
//SkipDefaultTransaction: true, // 跳过默认事务
|
||
})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to connect to mysql:%v", err)
|
||
}
|
||
|
||
sqlDB, _ := db.DB()
|
||
sqlDB.SetMaxIdleConns(10)
|
||
sqlDB.SetMaxOpenConns(50)
|
||
sqlDB.SetConnMaxIdleTime(time.Minute * 5)
|
||
sqlDB.SetConnMaxLifetime(time.Minute * 10)
|
||
|
||
if len(tables) > 0 {
|
||
err = autoMigrate(db, tables...)
|
||
if err != nil {
|
||
xlog.Fatalf(err)
|
||
return nil, fmt.Errorf("automigrate error:%v", err)
|
||
}
|
||
}
|
||
|
||
//addMetricsCollection(db, strconv.Itoa(int(serverId)), dbName)
|
||
return db, nil
|
||
}
|
||
|
||
func tryCreateDB(dbType string, dialector gorm.Dialector, dbName string) (string, error) {
|
||
|
||
db, err := gorm.Open(dialector, &gorm.Config{
|
||
PrepareStmt: false, // 关闭缓存sql语句功能,因为后续use db会报错,这个缓存会无限存储可能导致内存泄露
|
||
//SkipDefaultTransaction: true, // 跳过默认事务
|
||
})
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to connect to mysql:%v", err)
|
||
}
|
||
|
||
// 检查数据库是否存在
|
||
if dbType != "sqlite" {
|
||
var count int
|
||
db.Raw("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count)
|
||
if count == 0 {
|
||
// 数据库不存在,创建它
|
||
sql := fmt.Sprintf("create database if not exists `%s` default charset utf8mb4 collate utf8mb4_unicode_ci",
|
||
dbName)
|
||
if e := db.Exec(sql).Error; e != nil {
|
||
return "", fmt.Errorf("failed to create database:%v", e)
|
||
}
|
||
}
|
||
}
|
||
|
||
sqlDb, _ := db.DB()
|
||
sqlDb.Close()
|
||
|
||
return dbName, nil
|
||
}
|
||
|
||
func autoMigrate(db *gorm.DB, tables ...interface{}) error {
|
||
// 这个函数是在InitConn之后调用的
|
||
|
||
// 初始化表
|
||
for _, table := range tables {
|
||
if err := db.AutoMigrate(table); err != nil {
|
||
if strings.Contains(err.Error(), "there is already a table named") {
|
||
continue
|
||
}
|
||
if strings.Contains(err.Error(), "already exists") {
|
||
continue
|
||
}
|
||
return errcode.New(errcode.DBError, "failed to init tables:%v", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|