package repo import ( "admin/apps/game/domain/entity" "admin/apps/game/model" "admin/apps/game/model/dto" "admin/internal/consts" "admin/internal/errcode" "admin/lib/xlog" "database/sql" "errors" "fmt" "gorm.io/gorm" "gorm.io/gorm/schema" "reflect" "strings" "time" ) var createHooks = map[string]func(projectEt *entity.Project, et dto.CommonDtoValues) error{ consts.ResourcesName_CDKey: cdKeyPreCreateHook, } type ICommonResourceRepo interface { List(project *entity.Project, params *dto.CommonListReq) (int, []*dto.CommonDtoFieldDesc, []*entity.CommonResource, error) GetById(projectEt *entity.Project, id int) ([]*dto.CommonDtoFieldDesc, *entity.CommonResource, bool, error) Create(projectEt *entity.Project, resource string, et dto.CommonDtoValues) (*entity.CommonResource, error) Edit(projectEt *entity.Project, et dto.CommonDtoValues) error Delete(projectEt *entity.Project, id int) (*entity.CommonResource, bool, error) ListPagination(whereSql string, whereArgs []any, f func(po model.IModel)) error UpdateClearDelayInvokeCreateHookFieldN(id int) error } func NewCommonResourceRepo(db *gorm.DB, poTemplate model.IModel) ICommonResourceRepo { return newCommonResourceRepoImpl(db, poTemplate) } type commonResourceRepoImpl struct { db *gorm.DB poTemplate model.IModel fieldsDescInfoFun func(project *entity.Project) []*dto.CommonDtoFieldDesc } func newCommonResourceRepoImpl(db *gorm.DB, poTemplate model.IModel) *commonResourceRepoImpl { fieldsInfo := (&entity.CommonResource{}).FromPo(poTemplate).GetDtoFieldsDescInfo return &commonResourceRepoImpl{db: db, poTemplate: poTemplate, fieldsDescInfoFun: fieldsInfo} } func (repo *commonResourceRepoImpl) List(projectEt *entity.Project, params *dto.CommonListReq) (int, []*dto.CommonDtoFieldDesc, []*entity.CommonResource, error) { pageNo := params.PageNo pageLen := params.PageLen whereConditions := params.ParsedWhereConditions.Conditions if pageNo <= 0 || pageLen <= 0 { return 0, nil, nil, errcode.New(errcode.ParamsInvalid, "page no or page len invalid:%v,%v", pageNo, pageLen) } limitStart := (pageNo - 1) * pageLen limitLen := pageLen listType := reflect.New(reflect.SliceOf(reflect.TypeOf(repo.poTemplate))) var totalCount int64 var txCount, txFind *gorm.DB var err error if len(whereConditions) <= 0 { txCount = repo.db.Model(repo.poTemplate) txFind = repo.db.Offset(limitStart).Limit(limitLen).Order("created_at desc") } else { whereSql, whereArgs := repo.parseWhereConditions2Sql(whereConditions) xlog.Debugf("list resource %v where sql:%v, args:%+v", repo.poTemplate.TableName(), whereSql, whereArgs) txCount = repo.db.Model(repo.poTemplate).Where(whereSql, whereArgs...) txFind = repo.db.Where(whereSql, whereArgs...).Offset(limitStart).Limit(limitLen).Order("created_at desc") } err = txCount.Count(&totalCount).Error if err != nil { return 0, nil, nil, errcode.New(errcode.DBError, "count resource %v error:%v", repo.poTemplate.TableName(), err) } err = txFind.Find(listType.Interface()).Error if err != nil { return 0, nil, nil, errcode.New(errcode.DBError, "list resource %v error:%v", repo.poTemplate.TableName(), err) } listType1 := listType.Elem() listLen := listType1.Len() entityList := make([]*entity.CommonResource, 0, listLen) for i := 0; i < listType1.Len(); i++ { po := listType1.Index(i).Interface().(model.IModel) et := &entity.CommonResource{} et.FromPo(po) entityList = append(entityList, et) } return int(totalCount), repo.fieldsDescInfoFun(projectEt), entityList, nil } func (repo *commonResourceRepoImpl) GetById(projectEt *entity.Project, id int) ([]*dto.CommonDtoFieldDesc, *entity.CommonResource, bool, error) { po := repo.makeEmptyPo() err := repo.db.Where("id = ?", id).First(po).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return repo.fieldsDescInfoFun(projectEt), (&entity.CommonResource{}).FromPo(repo.makeEmptyPo()), false, nil } return nil, nil, false, errcode.New(errcode.DBError, "get resource:%v by id:%v error:%v", repo.poTemplate.TableName(), id, err) } return repo.fieldsDescInfoFun(projectEt), (&entity.CommonResource{}).FromPo(po), true, nil } func (repo *commonResourceRepoImpl) Create(projectEt *entity.Project, resource string, dtoObj dto.CommonDtoValues) (*entity.CommonResource, error) { et := (&entity.CommonResource{}).FromPo(repo.makeEmptyPo()).FromDto(dtoObj) if handler, find := createHooks[resource]; find { if err := handler(projectEt, dtoObj); err != nil { return et, err } } err := repo.db.Create(et.Po).Error if err != nil { if strings.Contains(err.Error(), "Duplicate entry") || strings.Contains(err.Error(), "UNIQUE constraint") { return et, errcode.New(errcode.DBInsertDuplicate, "create resource:%v obj:%+v error:%v", repo.poTemplate.TableName(), et, err) } else { return et, errcode.New(errcode.DBError, "create resource:%v obj:%+v error:%v", repo.poTemplate.TableName(), et, err) } } return et, nil } func (repo *commonResourceRepoImpl) Edit(projectEt *entity.Project, dtoObj dto.CommonDtoValues) error { et := (&entity.CommonResource{}).FromPo(repo.makeEmptyPo()).FromDto(dtoObj) err := repo.db.Where("id=?", et.Po.GetId()).Save(et.Po).Error if err != nil { return errcode.New(errcode.DBError, "edit resource:%v obj:%+v error:%v", repo.poTemplate.TableName(), et, err) } return nil } func (repo *commonResourceRepoImpl) Delete(projectEt *entity.Project, id int) (*entity.CommonResource, bool, error) { _, et, find, err := repo.GetById(projectEt, id) if err != nil { return nil, false, err } if !find { return et, false, nil } err = repo.db.Where("id=?", id).Unscoped().Delete(repo.poTemplate).Error if err != nil { return nil, false, errcode.New(errcode.DBError, "delete resource:%v obj:%+v error:%v", repo.poTemplate.TableName(), id, err) } return et, true, nil } func (repo *commonResourceRepoImpl) ListPagination(whereSql string, whereArgs []any, f func(po model.IModel)) error { pageNo := 0 pageLen := 100 for { limitStart := pageNo * pageLen limitLen := pageLen listType := reflect.New(reflect.SliceOf(reflect.TypeOf(repo.poTemplate))) var txFind *gorm.DB var err error if len(whereSql) <= 0 { txFind = repo.db.Offset(limitStart).Limit(limitLen).Order("created_at desc") } else { txFind = repo.db.Where(whereSql, whereArgs...).Offset(limitStart).Limit(limitLen) } err = txFind.Find(listType.Interface()).Error if err != nil { return err } listType1 := listType.Elem() listLen := listType1.Len() for i := 0; i < listType1.Len(); i++ { po := listType1.Index(i).Interface().(model.IModel) f(po) } if listLen < limitLen { // 遍历完了 return nil } pageNo++ } } func (repo *commonResourceRepoImpl) UpdateClearDelayInvokeCreateHookFieldN(id int) error { repo.makeEmptyPo() err := repo.db.Model(repo.makeEmptyPo()).Where("id = ?", id).UpdateColumn("delay_invoke_create_hook", sql.NullTime{}).Error if err != nil { return err } return nil } func (repo *commonResourceRepoImpl) makeEmptyPo() model.IModel { return reflect.New(reflect.TypeOf(repo.poTemplate).Elem()).Interface().(model.IModel) } func (repo *commonResourceRepoImpl) parseWhereConditions2Sql(conditions []*dto.GetWhereCondition) (whereSql string, args []any) { namer := new(schema.NamingStrategy) to := reflect.TypeOf(repo.poTemplate).Elem() whereClause := make([]string, 0, len(conditions)) whereArgs := make([]interface{}, 0, len(conditions)) for _, cond := range conditions { for i := 0; i < to.NumField(); i++ { field := to.Field(i) if field.Name != cond.Key { continue } dbFieldName := namer.ColumnName("", field.Name) if field.Type.Name() == "Time" { if cond.Value1 == nil { cond.Value1 = time.Time{} } else { cond.Value1, _ = time.ParseInLocation("2006/01/02 15:04:05", cond.Value1.(string), time.Local) } if cond.Value2 == nil { cond.Value2 = time.Time{} } else { cond.Value2, _ = time.ParseInLocation("2006/01/02 15:04:05", cond.Value2.(string), time.Local) } } switch field.Tag.Get("where") { case "eq": if field.Tag.Get("type") == "[]string" && field.Tag.Get("multi_choice") == "true" { // eq也要查出来为空的 whereClause = append(whereClause, fmt.Sprintf("JSON_CONTAINS( `%v`, '\"%v\"', '$' ) or `%v` IS NULL", dbFieldName, cond.Value1, dbFieldName)) } else { whereClause = append(whereClause, fmt.Sprintf("`%v` = ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) } case "gt": whereClause = append(whereClause, fmt.Sprintf("`%v` > ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) case "lt": whereClause = append(whereClause, fmt.Sprintf("`%v` < ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) case "le": whereClause = append(whereClause, fmt.Sprintf("`%v` <= ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) case "ge": whereClause = append(whereClause, fmt.Sprintf("`%v` >= ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) case "like": whereClause = append(whereClause, fmt.Sprintf("`%v` like ?", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) case "range": t1, ok1 := cond.Value1.(time.Time) t2, ok2 := cond.Value2.(time.Time) if ok1 || ok2 { if !t1.IsZero() && !t2.IsZero() { whereClause = append(whereClause, fmt.Sprintf("`%v` >= ? and `%v` <= ?", dbFieldName, dbFieldName)) whereArgs = append(whereArgs, cond.Value1, cond.Value2) } else if !t1.IsZero() { whereClause = append(whereClause, fmt.Sprintf("`%v` >= ? ", dbFieldName)) whereArgs = append(whereArgs, cond.Value1) } else { whereClause = append(whereClause, fmt.Sprintf("`%v` <= ? ", dbFieldName)) whereArgs = append(whereArgs, cond.Value2) } } else { whereClause = append(whereClause, fmt.Sprintf("`%v` >= ? and `%v` <= ?", dbFieldName, dbFieldName)) whereArgs = append(whereArgs, cond.Value1, cond.Value2) } case "": default: panic(fmt.Errorf("unsupport where tag %v", field.Tag)) } } } whereSql = strings.Join(whereClause, " AND ") return whereSql, whereArgs }