package db import ( "admin/internal/errcode" "admin/internal/global" "admin/lib/xlog" "fmt" mysqlDriver "github.com/go-sql-driver/mysql" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" "sync" "time" ) var ( globalTables []any locker sync.Mutex ) func RegisterTableModels(models ...any) { locker.Lock() defer locker.Unlock() globalTables = append(globalTables, models...) } func NewDB(dbType, dbAddr, dbName, dbUser, dbPass string) (db *gorm.DB, err error) { switch dbType { case "sqlite": db, err = gorm.Open(sqlite.Open(dbName+".db"), &gorm.Config{}) if err != nil { return nil, err } case "mysql": dsn := fmt.Sprintf("%v:%v@tcp(%v)/%v?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr, dbName) dsnWithoutDB := fmt.Sprintf("%v:%v@tcp(%v)/?charset=utf8mb4&parseTime=True&loc=Local", dbUser, dbPass, dbAddr) db, err = createDBAndGuaranteeMigrate(dsnWithoutDB, dsn, globalTables) } global.GLOB_DB = db return db, nil } func createDBAndGuaranteeMigrate(dsnWithoutDb, dsn string, tables []any) (*gorm.DB, error) { mysqlDriverConf, err := mysqlDriver.ParseDSN(dsn) if err != nil { return nil, fmt.Errorf("parse dsn:%v error:%v", dsn, err) } dbName := mysqlDriverConf.DBName _, err = tryCreateDB(dbName, dsnWithoutDb) if err != nil { xlog.Fatalf(err) return nil, err } driverConf := mysql.Config{ DSN: dsn, DontSupportRenameColumn: true, //SkipInitializeWithVersion: false, // 根据数据库版本自动配置 } dialector := mysql.New(driverConf) //slowLogger := logger.New( // syslog.New(xlog.GetGlobalWriter(), "\n", syslog.LstdFlags), // logger.Config{ // // 设定慢查询时间阈值为 默认值:200 * time.Millisecond // SlowThreshold: 200 * time.Millisecond, // // 设置日志级别 // LogLevel: logger.Warn, // Colorful: true, // }, //) db, err := gorm.Open(dialector, &gorm.Config{ Logger: &gormLogger{}, PrepareStmt: false, // 关闭缓存sql语句功能,因为后续use db会报错,这个缓存会无限存储可能导致内存泄露 //SkipDefaultTransaction: true, // 跳过默认事务 }) if err != nil { return nil, fmt.Errorf("failed to connect to mysql:%v", err) } sqlDB, _ := db.DB() sqlDB.SetMaxIdleConns(10) sqlDB.SetMaxOpenConns(50) sqlDB.SetConnMaxIdleTime(time.Minute * 5) sqlDB.SetConnMaxLifetime(time.Minute * 10) if len(tables) > 0 { err = autoMigrate(db, tables...) if err != nil { xlog.Fatalf(err) return nil, fmt.Errorf("automigrate error:%v", err) } } //addMetricsCollection(db, strconv.Itoa(int(serverId)), dbName) return db, nil } func tryCreateDB(dbName, dsn string) (string, error) { driverConf := mysql.Config{ DSN: dsn, DontSupportRenameColumn: true, //SkipInitializeWithVersion: false, // 根据数据库版本自动配置 } dialector := mysql.New(driverConf) db, err := gorm.Open(dialector, &gorm.Config{ PrepareStmt: false, // 关闭缓存sql语句功能,因为后续use db会报错,这个缓存会无限存储可能导致内存泄露 //SkipDefaultTransaction: true, // 跳过默认事务 }) if err != nil { return "", fmt.Errorf("failed to connect to mysql:%v", err) } // 检查数据库是否存在 var count int db.Raw("SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = ?", dbName).Scan(&count) if count == 0 { // 数据库不存在,创建它 sql := fmt.Sprintf(`create database if not exists %s default charset utf8mb4 collate utf8mb4_unicode_ci`, dbName) if e := db.Exec(sql).Error; e != nil { return "", fmt.Errorf("failed to create database:%v", e) } } sqlDb, _ := db.DB() sqlDb.Close() return dbName, nil } func autoMigrate(db *gorm.DB, tables ...interface{}) error { // 这个函数是在InitConn之后调用的 // 初始化表 if err := db.AutoMigrate(tables...); err != nil { return errcode.New(errcode.DBError, "failed to init tables", err) } return nil }