// Captured from a repository web view (Go source, 143 lines, 3.9 KiB);
// revision dated 2025-04-18 17:17:23 +08:00.
package db
import (
"admin/internal/errcode"
"admin/internal/global"
"admin/lib/xlog"
"fmt"
mysqlDriver "github.com/go-sql-driver/mysql"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"sync"
"time"
)
// globalTables accumulates every model passed to RegisterTableModels. NewDB
// hands these to the MySQL setup path for auto-migration.
var globalTables []any

// locker serializes writes to globalTables.
var locker sync.Mutex

// RegisterTableModels adds models to the package-wide list of table models.
// It is safe to call from multiple goroutines; calling it with no arguments
// leaves the list unchanged.
func RegisterTableModels(models ...any) {
	locker.Lock()
	defer locker.Unlock()

	globalTables = append(globalTables, models...)
}
// NewDB opens a database connection of the given type, stores it in
// global.GLOB_DB and returns it.
//
// Supported dbType values:
//   - "sqlite": opens the file dbName+".db"; dbAddr, dbUser and dbPass are
//     not used and no auto-migration is performed.
//   - "mysql": connects to dbAddr as dbUser/dbPass, creates dbName if it is
//     missing, and auto-migrates every model registered through
//     RegisterTableModels.
//
// Any other dbType yields an error. On any error global.GLOB_DB is left
// unchanged and a nil *gorm.DB is returned.
func NewDB(dbType, dbAddr, dbName, dbUser, dbPass string) (db *gorm.DB, err error) {
	switch dbType {
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(dbName+".db"), &gorm.Config{})
		if err != nil {
			return nil, err
		}
	case "mysql":
		dsn := fmt.Sprintf("%v:%v@tcp(%v)/%v?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr, dbName)
		// Same server, no database selected: used to create dbName before connecting to it.
		dsnWithoutDB := fmt.Sprintf("%v:%v@tcp(%v)/?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr)
		db, err = createDBAndGuaranteeMigrate(dsnWithoutDB, dsn, registeredTables())
		if err != nil {
			// Previously this error was dropped and (nil, nil) could be returned.
			return nil, err
		}
	default:
		return nil, fmt.Errorf("unsupported db type %q", dbType)
	}
	global.GLOB_DB = db
	return db, nil
}

// registeredTables returns a copy of the models registered so far. The copy is
// taken under locker so it cannot race with a concurrent RegisterTableModels.
func registeredTables() []any {
	locker.Lock()
	defer locker.Unlock()
	return append([]any(nil), globalTables...)
}
// createDBAndGuaranteeMigrate returns a MySQL-backed *gorm.DB for dsn. Before
// returning, it makes sure the target database exists and its tables are
// migrated.
//
// Steps:
//  1. Parse dsn to learn the database name.
//  2. Call tryCreateDB with dsnWithoutDb. This is expected to point at the
//     same server without selecting a database, as NewDB builds it, so the
//     database can be created before anything connects to it.
//  3. Open the real connection using gormLogger and configure the pool:
//     10 idle / 50 open connections, 5 min idle time, 10 min lifetime.
//  4. Auto-migrate tables when at least one is given.
//
// If auto-migration fails, the connection pool is closed before returning so
// its connections are not leaked.
func createDBAndGuaranteeMigrate(dsnWithoutDb, dsn string, tables []any) (*gorm.DB, error) {
	mysqlDriverConf, err := mysqlDriver.ParseDSN(dsn)
	if err != nil {
		// dsn is deliberately not echoed: it embeds the database password.
		return nil, fmt.Errorf("parse dsn error:%w", err)
	}
	dbName := mysqlDriverConf.DBName
	if _, err = tryCreateDB(dbName, dsnWithoutDb); err != nil {
		// NOTE(review): if xlog.Fatalf terminates the process, the return
		// below is unreachable; confirm which behaviour is intended.
		xlog.Fatalf(err)
		return nil, err
	}
	driverConf := mysql.Config{
		DSN:                     dsn,
		DontSupportRenameColumn: true,
		//SkipInitializeWithVersion: false, // auto-configure based on the server version
	}
	dialector := mysql.New(driverConf)
	//slowLogger := logger.New(
	//	syslog.New(xlog.GetGlobalWriter(), "\n", syslog.LstdFlags),
	//	logger.Config{
	//		// slow-query threshold, default 200 * time.Millisecond
	//		SlowThreshold: 200 * time.Millisecond,
	//		// log level
	//		LogLevel: logger.Warn,
	//		Colorful: true,
	//	},
	//)
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: &gormLogger{},
		// Prepared-statement cache disabled: it breaks after a later `USE db`,
		// and the cache grows without bound, which may leak memory.
		PrepareStmt: false,
		//SkipDefaultTransaction: true, // skip the default transaction
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to mysql:%w", err)
	}
	// Previously this error was ignored; a nil *sql.DB would panic below.
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB error:%w", err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(50)
	sqlDB.SetConnMaxIdleTime(time.Minute * 5)
	sqlDB.SetConnMaxLifetime(time.Minute * 10)
	if len(tables) > 0 {
		if err = autoMigrate(db, tables...); err != nil {
			// Release the pool; the migration error is the one worth reporting.
			_ = sqlDB.Close()
			xlog.Fatalf(err)
			return nil, fmt.Errorf("automigrate error:%w", err)
		}
	}
	//addMetricsCollection(db, strconv.Itoa(int(serverId)), dbName)
	return db, nil
}
func tryCreateDB(dbName, dsn string) (string, error) {
driverConf := mysql.Config{
DSN: dsn,
DontSupportRenameColumn: true,
//SkipInitializeWithVersion: false, // 根据数据库版本自动配置
}
dialector := mysql.New(driverConf)
db, err := gorm.Open(dialector, &gorm.Config{
PrepareStmt: false, // 关闭缓存sql语句功能因为后续use db会报错这个缓存会无限存储可能导致内存泄露
//SkipDefaultTransaction: true, // 跳过默认事务
})
if err != nil {
return "", fmt.Errorf("failed to connect to mysql:%v", err)
}
// 检查数据库是否存在
var count int
db.Raw("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count)
if count == 0 {
// 数据库不存在,创建它
2025-04-22 15:46:48 +08:00
sql := fmt.Sprintf("create database if not exists `%s` default charset utf8mb4 collate utf8mb4_unicode_ci",
2025-04-18 17:17:23 +08:00
dbName)
if e := db.Exec(sql).Error; e != nil {
return "", fmt.Errorf("failed to create database:%v", e)
}
}
sqlDb, _ := db.DB()
sqlDb.Close()
return dbName, nil
}
func autoMigrate(db *gorm.DB, tables ...interface{}) error {
// 这个函数是在InitConn之后调用的
// 初始化表
if err := db.AutoMigrate(tables...); err != nil {
return errcode.New(errcode.DBError, "failed to init tables", err)
}
return nil
}