package db

import (
	"fmt"
	"strings"
	"sync"
	"time"

	"admin/internal/errcode"
	"admin/internal/global"
	"admin/lib/xlog"

	mysqlDriver "github.com/go-sql-driver/mysql"
	"gorm.io/driver/mysql"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

var (
	// globalTables holds the models registered through RegisterTableModels;
	// NewDB auto-migrates them.
	globalTables []any
	// locker guards globalTables.
	locker sync.Mutex
)

// RegisterTableModels adds models to the set of tables that NewDB auto-migrates.
// Registration is serialized by locker.
func RegisterTableModels(models ...any) {
	locker.Lock()
	defer locker.Unlock()
	globalTables = append(globalTables, models...)
}

// registeredTables returns a copy of the registered models taken under locker,
// so NewDB does not race with a concurrent RegisterTableModels call.
func registeredTables() []any {
	locker.Lock()
	defer locker.Unlock()
	return append([]any(nil), globalTables...)
}

// NewDB makes sure database dbName exists (creating it if needed), opens a
// connection pool to it and auto-migrates every model registered through
// RegisterTableModels. On success the handle is also stored in global.GLOB_DB.
//
// dbType must be "mysql" or "sqlite". For "sqlite", dbName is used as the
// database file path.
func NewDB(dbType, dbAddr, dbName, dbUser, dbPass string) (db *gorm.DB, err error) {
	const params = "charset=utf8mb4&parseTime=True&loc=Local"
	dsn := fmt.Sprintf("%v:%v@tcp(%v)/%v?%s", dbUser, dbPass, dbAddr, dbName, params)
	dsnWithoutDB := fmt.Sprintf("%v:%v@tcp(%v)/?%s", dbUser, dbPass, dbAddr, params)
	db, err = createDBAndGuaranteeMigrate(dbType, dsnWithoutDB, dsn, registeredTables())
	if err != nil {
		return nil, err
	}
	global.GLOB_DB = db
	return db, nil
}

// openDialector returns the gorm dialector for dbType: the file dbName for
// "sqlite", or the given dsn for "mysql". Any other dbType is an error.
func openDialector(dbType, dsn, dbName string) (gorm.Dialector, error) {
	switch dbType {
	case "sqlite":
		return sqlite.Open(dbName), nil
	case "mysql":
		return mysql.Open(dsn), nil
	default:
		return nil, fmt.Errorf("unsupported db type: %v", dbType)
	}
}

// createDBAndGuaranteeMigrate creates the database named in dsn if it does not
// exist, opens a pooled connection to it and, when tables is non-empty,
// auto-migrates them.
//
// dsnWithoutDb is the same DSN with no database selected; it is used for the
// existence check and creation. The database name is parsed out of dsn.
func createDBAndGuaranteeMigrate(dbType string, dsnWithoutDb, dsn string, tables []any) (*gorm.DB, error) {
	mysqlDriverConf, err := mysqlDriver.ParseDSN(dsn)
	if err != nil {
		// The DSN embeds the password, so it is deliberately not echoed here.
		return nil, fmt.Errorf("parse dsn error:%w", err)
	}
	dbName := mysqlDriverConf.DBName

	// Phase 1: connect without selecting a database and make sure it exists.
	dialector, err := openDialector(dbType, dsnWithoutDb, dbName)
	if err != nil {
		return nil, err
	}
	if _, err = tryCreateDB(dbType, dialector, dbName); err != nil {
		xlog.Fatalf(err)
		return nil, err
	}

	// Phase 2: connect to the database itself.
	dialector, err = openDialector(dbType, dsn, dbName)
	if err != nil {
		return nil, err
	}

	//driverConf := mysql.Config{
	//	DSN:                     dsn,
	//	DontSupportRenameColumn: true,
	//	//SkipInitializeWithVersion: false, // auto-configure based on the server version
	//}
	//dialector = mysql.New(driverConf)
	//slowLogger := logger.New(
	//	syslog.New(xlog.GetGlobalWriter(), "\n", syslog.LstdFlags),
	//	logger.Config{
	//		// slow-query threshold (default: 200 * time.Millisecond)
	//		SlowThreshold: 200 * time.Millisecond,
	//		// log level
	//		LogLevel: logger.Warn,
	//		Colorful: true,
	//	},
	//)

	db, err := gorm.Open(dialector, &gorm.Config{
		Logger: &gormLogger{},
		// Prepared-statement caching is disabled: a later `USE db` breaks the
		// cached statements, and the cache grows without bound (possible leak).
		PrepareStmt: false,
		//SkipDefaultTransaction: true, // skip the default per-write transaction
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to %v:%w", dbType, err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB:%w", err)
	}
	sqlDB.SetMaxIdleConns(10)
	sqlDB.SetMaxOpenConns(50)
	sqlDB.SetConnMaxIdleTime(time.Minute * 5)
	sqlDB.SetConnMaxLifetime(time.Minute * 10)

	if len(tables) > 0 {
		if err = autoMigrate(db, tables...); err != nil {
			xlog.Fatalf(err)
			// Don't leak the freshly opened pool when migration fails.
			_ = sqlDB.Close()
			return nil, fmt.Errorf("automigrate error:%w", err)
		}
	}

	//addMetricsCollection(db, strconv.Itoa(int(serverId)), dbName)
	return db, nil
}

// tryCreateDB opens a short-lived connection through dialector and, for
// non-sqlite databases, creates dbName when it does not already exist. The
// connection is always closed before returning. It returns dbName on success.
func tryCreateDB(dbType string, dialector gorm.Dialector, dbName string) (string, error) {
	db, err := gorm.Open(dialector, &gorm.Config{
		// Prepared-statement caching is disabled: a later `USE db` breaks the
		// cached statements, and the cache grows without bound (possible leak).
		PrepareStmt: false,
		//SkipDefaultTransaction: true, // skip the default per-write transaction
	})
	if err != nil {
		return "", fmt.Errorf("failed to connect to %v:%w", dbType, err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		return "", fmt.Errorf("get underlying sql.DB:%w", err)
	}
	// This connection is only needed for the check below; release it on every path.
	defer sqlDB.Close()

	// Check whether the database exists (sqlite creates its file on open).
	if dbType != "sqlite" {
		var count int
		// If the lookup itself fails, count stays 0 and we fall through to
		// CREATE ... IF NOT EXISTS, which is a no-op for an existing database.
		lookupErr := db.Raw("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count).Error
		if lookupErr != nil || count == 0 {
			// Escape backticks so dbName cannot break out of the quoted identifier.
			quoted := strings.ReplaceAll(dbName, "`", "``")
			stmt := fmt.Sprintf("create database if not exists `%s` default charset utf8mb4 collate utf8mb4_unicode_ci", quoted)
			if e := db.Exec(stmt).Error; e != nil {
				return "", fmt.Errorf("failed to create database:%w", e)
			}
		}
	}
	return dbName, nil
}

// autoMigrate runs gorm AutoMigrate for each table in order. Called after the
// connection has been opened. "Table already exists" errors are ignored; any
// other error stops migration and is returned as an errcode.DBError.
func autoMigrate(db *gorm.DB, tables ...any) error {
	for _, table := range tables {
		err := db.AutoMigrate(table)
		if err == nil {
			continue
		}
		msg := err.Error()
		if strings.Contains(msg, "there is already a table named") || strings.Contains(msg, "already exists") {
			continue
		}
		return errcode.New(errcode.DBError, "failed to init tables:%v", err)
	}
	return nil
}