package web

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"reflect"
	"strconv"
	"strings"

	"github.com/gin-gonic/gin"
)

// routerGroupGin adapts a *gin.RouterGroup to routeGroupInterface.
// newContextFun builds the framework Context handed to every handler.
type routerGroupGin struct {
	group         *gin.RouterGroup
	newContextFun func() Context
}

// Use appends middlewares to this group.
func (group *routerGroupGin) Use(middlewares ...HandlerFunc) {
	group.group.Use(getGinHandlerFun(group.newContextFun, middlewares...)...)
}

// Group creates a child group under path; handlers are attached to the child
// group and it shares this group's context factory.
func (group *routerGroupGin) Group(path string, handlers ...HandlerFunc) routeGroupInterface {
	ginGroup := group.group.Group(path, getGinHandlerFun(group.newContextFun, handlers...)...)
	return &routerGroupGin{group: ginGroup, newContextFun: group.newContextFun}
}

// Get registers a GET route. When req is non-nil every handler is invoked as
// handler(ctx, request), where request is a freshly bound value of req's type;
// otherwise as handler(ctx). desc is not used by the gin backend. The group is
// returned so registrations can be chained.
func (group *routerGroupGin) Get(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	group.group.GET(path, ginHandlersFor(group.newContextFun, req, handlers)...)
	return group
}

// Post registers a POST route; req and handlers behave as in Get.
func (group *routerGroupGin) Post(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	group.group.POST(path, ginHandlersFor(group.newContextFun, req, handlers)...)
	return group
}

// Put registers a PUT route; req and handlers behave as in Get.
func (group *routerGroupGin) Put(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	group.group.PUT(path, ginHandlersFor(group.newContextFun, req, handlers)...)
	return group
}

// Delete registers a DELETE route; req and handlers behave as in Get.
func (group *routerGroupGin) Delete(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	group.group.DELETE(path, ginHandlersFor(group.newContextFun, req, handlers)...)
	return group
}

// routerGin is the gin-backed routerInterface implementation; it also serves
// as a routeGroupInterface for routes registered at the root.
type routerGin struct {
	engine        *gin.Engine
	newContextFun func() Context
}

// newRouterGin creates a router on top of gin.Default() (gin's engine with its
// Logger and Recovery middleware attached). newContextFun is called once per
// handler invocation to build the Context passed to handlers.
func newRouterGin(newContextFun func() Context) routerInterface {
	return &routerGin{
		engine:        gin.Default(),
		newContextFun: newContextFun,
	}
}

// Use appends global middlewares to the engine.
func (router *routerGin) Use(middlewares ...HandlerFunc) {
	router.engine.Use(getGinHandlerFun(router.newContextFun, middlewares...)...)
}

// Group creates a top-level route group under path with the given handlers.
func (router *routerGin) Group(path string, handlers ...HandlerFunc) routeGroupInterface {
	ginGroup := router.engine.Group(path, getGinHandlerFun(router.newContextFun, handlers...)...)
	return &routerGroupGin{
		group:         ginGroup,
		newContextFun: router.newContextFun,
	}
}

// Get registers a GET route. When req is non-nil every handler is invoked as
// handler(ctx, request), where request is a freshly bound value of req's type;
// otherwise as handler(ctx). desc is not used by the gin backend. The router is
// returned so registrations can be chained.
func (router *routerGin) Get(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	router.engine.GET(path, ginHandlersFor(router.newContextFun, req, handlers)...)
	return router
}

// Post registers a POST route; req and handlers behave as in Get.
func (router *routerGin) Post(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	router.engine.POST(path, ginHandlersFor(router.newContextFun, req, handlers)...)
	return router
}

// Put registers a PUT route; req and handlers behave as in Get.
func (router *routerGin) Put(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	router.engine.PUT(path, ginHandlersFor(router.newContextFun, req, handlers)...)
	return router
}

// Delete registers a DELETE route; req and handlers behave as in Get.
func (router *routerGin) Delete(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
	router.engine.DELETE(path, ginHandlersFor(router.newContextFun, req, handlers)...)
	return router
}

// ginHandlersFor picks the adapter for a route: handlers receive only the
// context when req is nil, and the context plus a bound request otherwise.
func ginHandlersFor(newContextFun func() Context, req any, handlers []HandlerFunc) []gin.HandlerFunc {
	if req == nil {
		return getGinHandlerFun(newContextFun, handlers...)
	}
	return getGinHandlerFunWithRequest(newContextFun, req, handlers...)
}

// getGinHandlerFun wraps each handler in a gin.HandlerFunc that builds a new
// Context per call and invokes handler(ctx) via reflection.
func getGinHandlerFun(newContextFun func() Context, handlers ...HandlerFunc) []gin.HandlerFunc {
	list := make([]gin.HandlerFunc, 0, len(handlers))
	for _, handler := range handlers {
		// fn is declared inside the loop body, so each closure gets its own copy.
		// With a go.mod older than 1.22 the range variable is shared by every
		// closure, and capturing `handler` directly would make all wrappers call
		// the last handler. Resolving the reflect.Value here also keeps that work
		// out of the per-request path.
		fn := reflect.ValueOf(handler)
		list = append(list, func(rawCtx *gin.Context) {
			ctx := newContextFun()
			ctx.SetRawContext(&ginCtx{ctx: rawCtx})
			fn.Call([]reflect.Value{reflect.ValueOf(ctx)})
		})
	}
	return list
}

// getGinHandlerFunWithRequest wraps each handler in a gin.HandlerFunc that
// builds a new Context, binds a fresh value of requestTemplate's type from the
// HTTP request (see parseRequest) and invokes handler(ctx, request) via
// reflection. On a binding error it reports through ParseParamsErrorHandler
// and aborts the gin chain so later handlers do not run.
func getGinHandlerFunWithRequest(newContextFun func() Context, requestTemplate any, handlers ...HandlerFunc) []gin.HandlerFunc {
	list := make([]gin.HandlerFunc, 0, len(handlers))
	for _, handler := range handlers {
		fn := reflect.ValueOf(handler) // per-iteration copy; see getGinHandlerFun
		list = append(list, func(rawCtx *gin.Context) {
			raw := &ginCtx{ctx: rawCtx}
			ctx := newContextFun()
			ctx.SetRawContext(raw)
			request, err := raw.parseRequest(requestTemplate)
			if err != nil {
				ParseParamsErrorHandler(ctx, rawCtx.Request.RequestURI, err)
				// Without Abort gin would continue to the next wrapper, which would
				// fail the same way and report the error a second time.
				rawCtx.Abort()
				return
			}
			fn.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(request)})
		})
	}
	return list
}

// Run starts the HTTP server on addr (delegates to gin.Engine.Run).
func (router *routerGin) Run(addr string) error {
	return router.engine.Run(addr)
}

// ginCtx is the raw-context wrapper around *gin.Context given to Context
// implementations via SetRawContext.
type ginCtx struct {
	ctx *gin.Context
}

// Json writes v as a JSON response with the given status code.
func (ctx *ginCtx) Json(code int, v any) {
	ctx.ctx.JSON(code, v)
}

// Writer returns the underlying response writer.
func (ctx *ginCtx) Writer() io.Writer {
	return ctx.ctx.Writer
}

// body reads the whole request body and then puts the bytes back, so the next
// wrapper in the chain (each one parses the request independently) or any later
// reader sees the same content instead of an exhausted stream.
func (ctx *ginCtx) body() ([]byte, error) {
	if ctx.ctx.Request.Body == nil {
		return nil, nil
	}
	buf, err := io.ReadAll(ctx.ctx.Request.Body)
	if err != nil {
		return nil, fmt.Errorf("read body error:%w", err)
	}
	ctx.ctx.Request.Body = io.NopCloser(bytes.NewReader(buf))
	return buf, nil
}

// parseRequest creates a new value of request's type and fills it from the
// HTTP request. A pointer template is unwrapped to its element type, so T{},
// &T{} and (*T)(nil) all yield a *T. A non-empty body is decoded as JSON; an
// empty body falls back to binding URL query parameters (see bindQuery). It
// returns a pointer to the new value, or nil when request is nil.
func (ctx *ginCtx) parseRequest(request any) (any, error) {
	if request == nil {
		return nil, nil
	}
	requestType := reflect.TypeOf(request)
	if requestType.Kind() == reflect.Ptr {
		requestType = requestType.Elem()
	}
	newRequest := reflect.New(requestType)

	body, err := ctx.body()
	if err != nil {
		return nil, err
	}
	if len(body) > 0 {
		if err := json.Unmarshal(body, newRequest.Interface()); err != nil {
			return nil, fmt.Errorf("json unmarshal body error:%w", err)
		}
		return newRequest.Interface(), nil
	}
	if err := ctx.bindQuery(newRequest.Elem()); err != nil {
		return nil, err
	}
	return newRequest.Interface(), nil
}

// bindQuery fills the settable (exported) fields of the struct value v from URL
// query parameters. The parameter name is the name part of the field's `json`
// tag (options such as ",omitempty" are ignored), falling back to the Go field
// name when that part is empty; fields tagged `json:"-"` are skipped. The
// `default` and `required` tags are applied by setValue.
func (ctx *ginCtx) bindQuery(v reflect.Value) error {
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("query params can only be bound to a struct, got %v", v.Type())
	}
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		f := t.Field(i)
		field := v.Field(i)
		if !field.CanSet() {
			continue
		}
		name, _, _ := strings.Cut(f.Tag.Get("json"), ",")
		if name == "-" {
			continue
		}
		if name == "" {
			name = f.Name
		}
		fieldStr := ctx.ctx.Query(name)
		if err := setValue(field, fieldStr, f.Tag.Get("default"), f.Tag.Get("required")); err != nil {
			return fmt.Errorf("parse uri params field(%v) set value(%v) error:%w", f.Name, fieldStr, err)
		}
	}
	return nil
}

// setValue sets the value of a single struct field (or slice element) from its
// string form.
//
// An empty value is replaced by defaultValue; if it is still empty and required
// is "true", an error is returned. A pointer field is allocated only when there
// is a non-empty value to store. Slice fields are filled from a comma-separated
// list. Struct fields and kinds not listed below are rejected with an error.
//
// NOTE(review): integers are parsed with base 0, so "0x"/"0o"/"0b" prefixes are
// accepted and a leading "0" selects octal ("010" == 8, "08" fails). Confirm
// this is intended for query parameters.
func setValue(field reflect.Value, value string, defaultValue string, required string) error {
	if value == "" {
		value = defaultValue
	}
	if value == "" && required == "true" {
		return fmt.Errorf("field is required, please give a valid value")
	}
	if field.Kind() == reflect.Ptr {
		if value == "" {
			return nil
		}
		if field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		field = field.Elem()
	}
	switch field.Kind() {
	case reflect.String:
		field.SetString(value)
	case reflect.Bool:
		if value == "" {
			field.SetBool(false)
			break
		}
		b, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		field.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if value == "" {
			field.SetInt(0)
			break
		}
		i, err := strconv.ParseInt(value, 0, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetInt(i)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if value == "" {
			field.SetUint(0)
			break
		}
		ui, err := strconv.ParseUint(value, 0, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetUint(ui)
	case reflect.Float32, reflect.Float64:
		if value == "" {
			field.SetFloat(0)
			break
		}
		f, err := strconv.ParseFloat(value, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetFloat(f)
	case reflect.Struct:
		return fmt.Errorf("unsupport struct field:%v", field.Type())
	case reflect.Slice:
		values := strings.Split(value, ",")
		// strings.Split("") yields [""]; treat that as an empty list.
		if len(values) == 1 && values[0] == "" {
			values = []string{}
		}
		field.Set(reflect.MakeSlice(field.Type(), len(values), len(values)))
		for i := 0; i < len(values); i++ {
			if err := setValue(field.Index(i), values[i], "", ""); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("no support type %s", field.Type())
	}
	return nil
}