2025-04-18 17:17:23 +08:00

143 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package db
import (
	"fmt"
	"strings"
	"sync"
	"time"

	"admin/internal/errcode"
	"admin/internal/global"
	"admin/lib/xlog"

	mysqlDriver "github.com/go-sql-driver/mysql"
	"gorm.io/driver/mysql"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
// globalTables is the package-level registry of table models, in the order
// they were registered. NewDB hands it to the MySQL setup for auto-migration.
var globalTables []any

// locker serializes access to globalTables.
var locker sync.Mutex

// RegisterTableModels appends models to the package-level table registry.
// Calls are serialized by locker, so it may be invoked from several
// goroutines (e.g. from package init functions).
func RegisterTableModels(models ...any) {
	locker.Lock()
	defer locker.Unlock()

	globalTables = append(globalTables, models...)
}
// NewDB opens a database connection of the requested type, stores it in
// global.GLOB_DB and returns it.
//
// Supported dbType values:
//   - "sqlite": opens the file "<dbName>.db"; dbAddr, dbUser and dbPass are
//     not used.
//   - "mysql": connects to dbAddr as dbUser/dbPass, creates dbName when it
//     does not exist yet, and auto-migrates every model registered through
//     RegisterTableModels.
//
// Any other dbType is rejected with an error. On error nothing is stored in
// global.GLOB_DB and a nil *gorm.DB is returned.
func NewDB(dbType, dbAddr, dbName, dbUser, dbPass string) (db *gorm.DB, err error) {
	switch dbType {
	case "sqlite":
		// NOTE(review): registered table models are not auto-migrated for
		// sqlite — confirm this is intended.
		db, err = gorm.Open(sqlite.Open(dbName+".db"), &gorm.Config{})
		if err != nil {
			return nil, fmt.Errorf("open sqlite database %q: %w", dbName, err)
		}
	case "mysql":
		dsn := fmt.Sprintf("%v:%v@tcp(%v)/%v?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr, dbName)
		dsnWithoutDB := fmt.Sprintf("%v:%v@tcp(%v)/?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr)

		// Snapshot the registry under the lock so a concurrent
		// RegisterTableModels call cannot race with the migration.
		locker.Lock()
		tables := append([]any(nil), globalTables...)
		locker.Unlock()

		db, err = createDBAndGuaranteeMigrate(dsnWithoutDB, dsn, tables)
		if err != nil {
			// Previously this error was dropped and (nil, nil) was returned.
			return nil, err
		}
	default:
		return nil, fmt.Errorf("unsupported db type %q", dbType)
	}
	global.GLOB_DB = db
	return db, nil
}
// createDBAndGuaranteeMigrate connects to MySQL and guarantees that the
// database named in dsn exists and that tables are auto-migrated into it.
//
// dsnWithoutDb must address the same server and credentials as dsn, but
// without selecting a database. It is used to create the database when it
// is missing. The returned *gorm.DB has its connection pool limits applied.
// If migration fails, the pool is closed before the error is returned.
func createDBAndGuaranteeMigrate(dsnWithoutDb, dsn string, tables []any) (*gorm.DB, error) {
	mysqlDriverConf, err := mysqlDriver.ParseDSN(dsn)
	if err != nil {
		// Do not echo the DSN itself: it contains the database password.
		return nil, fmt.Errorf("parse dsn error:%w", err)
	}
	dbName := mysqlDriverConf.DBName
	if _, err = tryCreateDB(dbName, dsnWithoutDb); err != nil {
		// NOTE(review): logs and returns the same error; xlog.Fatalf may also
		// terminate the process — confirm whether both are wanted.
		xlog.Fatalf(err)
		return nil, err
	}

	driverConf := mysql.Config{
		DSN:                     dsn,
		DontSupportRenameColumn: true,
		//SkipInitializeWithVersion: false, // auto-configure based on the server version
	}
	dialector := mysql.New(driverConf)
	//slowLogger := logger.New(
	//	syslog.New(xlog.GetGlobalWriter(), "\n", syslog.LstdFlags),
	//	logger.Config{
	//		// Slow-query threshold (default 200 * time.Millisecond).
	//		SlowThreshold: 200 * time.Millisecond,
	//		// Log level.
	//		LogLevel: logger.Warn,
	//		Colorful: true,
	//	},
	//)
	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: &gormLogger{},
		// Prepared-statement caching stays off: a later "USE db" breaks the
		// cached statements, and the cache grows without bound (possible leak).
		PrepareStmt: false,
		//SkipDefaultTransaction: true, // skip the default per-write transaction
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to mysql:%w", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		// Without the underlying *sql.DB the pool cannot be configured (and the
		// old code would have panicked on a nil pointer here).
		return nil, fmt.Errorf("get underlying sql.DB:%w", err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(50)
	sqlDB.SetConnMaxIdleTime(time.Minute * 5)
	sqlDB.SetConnMaxLifetime(time.Minute * 10)

	if len(tables) > 0 {
		if err = autoMigrate(db, tables...); err != nil {
			// The caller never receives db, so release its pool here.
			_ = sqlDB.Close()
			xlog.Fatalf(err)
			return nil, fmt.Errorf("automigrate error:%w", err)
		}
	}
	//addMetricsCollection(db, strconv.Itoa(int(serverId)), dbName)
	return db, nil
}
// tryCreateDB opens a temporary connection with dsn (which should not select
// a database) and creates dbName, with utf8mb4 / utf8mb4_unicode_ci
// defaults, if information_schema does not list it yet. It returns dbName on
// success. The temporary connection pool is always closed before returning.
func tryCreateDB(dbName, dsn string) (string, error) {
	driverConf := mysql.Config{
		DSN:                     dsn,
		DontSupportRenameColumn: true,
		//SkipInitializeWithVersion: false, // auto-configure based on the server version
	}
	dialector := mysql.New(driverConf)
	db, err := gorm.Open(dialector, &gorm.Config{
		// Prepared-statement caching stays off: a later "USE db" breaks the
		// cached statements, and the cache grows without bound (possible leak).
		PrepareStmt: false,
		//SkipDefaultTransaction: true, // skip the default per-write transaction
	})
	if err != nil {
		return "", fmt.Errorf("failed to connect to mysql:%w", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		return "", fmt.Errorf("get underlying sql.DB:%w", err)
	}
	// Close on every path, including the create-failure path that used to leak.
	defer sqlDB.Close()

	// Check whether the database already exists.
	var count int
	if err := db.Raw("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count).Error; err != nil {
		return "", fmt.Errorf("check database %q exists:%w", dbName, err)
	}
	if count == 0 {
		// The database does not exist: create it. Identifiers cannot be bound
		// as parameters, so the name is quoted instead of spliced in raw.
		stmt := fmt.Sprintf(`create database if not exists %s default charset utf8mb4 collate utf8mb4_unicode_ci`,
			quoteIdentifier(dbName))
		if e := db.Exec(stmt).Error; e != nil {
			return "", fmt.Errorf("failed to create database:%w", e)
		}
	}
	return dbName, nil
}

// quoteIdentifier wraps a MySQL identifier in backticks, doubling any
// embedded backtick. This lets names such as "my-db" be used, and keeps the
// name from breaking out of the surrounding statement.
func quoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}
// autoMigrate creates or updates the schema for every model in tables via
// GORM's AutoMigrate. A failure is reported as an errcode.DBError that wraps
// the underlying GORM error.
//
// The original note says this runs after InitConn, i.e. it initializes the
// tables once a connection exists (InitConn is not visible here).
func autoMigrate(db *gorm.DB, tables ...any) error {
	migrateErr := db.AutoMigrate(tables...)
	if migrateErr == nil {
		return nil
	}
	return errcode.New(errcode.DBError, "failed to init tables", migrateErr)
}