uniugm/admin/lib/web/routes_group.go

229 lines
6.4 KiB
Go
Raw Normal View History

2025-04-22 15:46:48 +08:00
package web
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"reflect"
"strconv"
"strings"
)
// IContext is the per-request context through which handlers adapted by
// getGinHandlerFunWithRequest report their outcome.
type IContext interface {
	// HandleSuccess receives the handler's response value after the handler
	// returned a nil first result.
	HandleSuccess(rspData any)
	// HandleError receives the request URI together with the error raised
	// while parsing the request or running the handler.
	HandleError(path string, err error)
}
type (
	// HandlerFunc is any handler value accepted by RoutesGroup. It is not typed
	// at compile time; getGinHandlerFunWithRequest inspects it via reflection
	// and expects a three-argument function.
	HandlerFunc any

	// RoutesGroup wraps a gin router group and records the groups and routes
	// registered through it (Group/Get/Post/Put/Delete) in a routesNode tree.
	RoutesGroup struct {
		// newContextFun builds the IContext handed to every adapted handler.
		newContextFun func(ctx *gin.Context) IContext
		// rawGroup is the underlying gin group routes are registered on.
		rawGroup *gin.RouterGroup
		// node is this group's position in the recorded routes tree.
		node *routesNode
	}
)
// newRoutesGroup wraps rawGroup in a RoutesGroup that starts with an empty
// routesNode; newContextFun is used to build the IContext for each request.
func newRoutesGroup(rawGroup *gin.RouterGroup, newContextFun func(ctx *gin.Context) IContext) *RoutesGroup {
	return &RoutesGroup{
		node:          &routesNode{},
		newContextFun: newContextFun,
		rawGroup:      rawGroup,
	}
}
// Group creates a child group at path. The adapted handlers are attached to the
// gin sub-group, and the group is recorded as a child node with an empty method
// and permit 0; the returned RoutesGroup uses that child node.
func (group *RoutesGroup) Group(path string, desc string, handlers ...HandlerFunc) *RoutesGroup {
	ginHandlers := getGinHandlerFunWithRequest(group.newContextFun, handlers...)
	subGroup := group.rawGroup.Group(path, ginHandlers...)
	child := newRoutesGroup(subGroup, group.newContextFun)
	child.node = group.addChildren("", path, desc, 0, handlers...)
	return child
}
// Use appends middlewares to the underlying gin group. They are adapted like
// route handlers but, unlike routes, are not recorded in the node tree.
func (group *RoutesGroup) Use(middlewares ...HandlerFunc) {
	adapted := getGinHandlerFunWithRequest(group.newContextFun, middlewares...)
	group.rawGroup.Use(adapted...)
}
// registerRoute registers handlers on the gin group for method+path and records
// the route (with its description and permit value) in the node tree.
func (group *RoutesGroup) registerRoute(method, path, desc string, permit int, handlers []HandlerFunc) {
	group.rawGroup.Handle(method, path, getGinHandlerFunWithRequest(group.newContextFun, handlers...)...)
	group.addChildren(method, path, desc, permit, handlers...)
}

// Get registers a GET route and records it in the node tree.
func (group *RoutesGroup) Get(path string, desc string, permit int, handlers ...HandlerFunc) {
	group.registerRoute("GET", path, desc, permit, handlers)
}

// Post registers a POST route and records it in the node tree.
func (group *RoutesGroup) Post(path string, desc string, permit int, handlers ...HandlerFunc) {
	group.registerRoute("POST", path, desc, permit, handlers)
}

// Put registers a PUT route and records it in the node tree.
func (group *RoutesGroup) Put(path string, desc string, permit int, handlers ...HandlerFunc) {
	group.registerRoute("PUT", path, desc, permit, handlers)
}

// Delete registers a DELETE route and records it in the node tree.
func (group *RoutesGroup) Delete(path string, desc string, permit int, handlers ...HandlerFunc) {
	group.registerRoute("DELETE", path, desc, permit, handlers)
}
// addChildren records a route (or, with an empty method, a sub-group) under
// this group's node by delegating to routesNode.addChildren, and returns the
// node that call produces (Group uses it as the child group's node).
func (group *RoutesGroup) addChildren(method, path string, desc string, permit int, handlers ...HandlerFunc) *routesNode {
	parent := group.node
	return parent.addChildren(method, path, desc, permit, handlers...)
}
// getGinHandlerFunWithRequest adapts reflective handlers into gin handlers.
//
// Each handler is expected to be a function shaped like
//
//	func(ctx C, req *Req, rsp *Rsp) error
//
// where C accepts the IContext built by newContextFun. For every request the
// adapter builds ctx with newContextFun, allocates fresh *Req and *Rsp values,
// fills *Req with parseRequest, calls the handler, and then reports rsp via
// HandleSuccess when the first return value is nil, or the returned error via
// HandleError.
//
// The handler's signature is checked once, when its gin handler is built. A
// handler that does not fit answers every request through HandleError instead
// of panicking inside reflect. The adapter never calls rawCtx.Abort, so whether
// a failing handler stops the rest of the gin chain is up to the IContext
// implementation (NOTE(review): confirm this is intended).
func getGinHandlerFunWithRequest(newContextFun func(ctx *gin.Context) IContext, handlers ...HandlerFunc) []gin.HandlerFunc {
	list := make([]gin.HandlerFunc, 0, len(handlers))
	for _, handler := range handlers {
		// Passing handler as an argument binds it per iteration; a closure over
		// the range variable would see only the last handler before Go 1.22.
		list = append(list, wrapReflectHandler(newContextFun, handler))
	}
	return list
}

// wrapReflectHandler builds the gin.HandlerFunc that drives a single reflective handler.
func wrapReflectHandler(newContextFun func(ctx *gin.Context) IContext, handler HandlerFunc) gin.HandlerFunc {
	handlerValue := reflect.ValueOf(handler)
	handlerType := reflect.TypeOf(handler)
	signatureErr := validateReflectHandler(handlerType)

	// Resolve the request/response element types once instead of per request.
	var reqType, rspType reflect.Type
	if signatureErr == nil {
		reqType = handlerType.In(1).Elem()
		rspType = handlerType.In(2).Elem()
	}

	return func(rawCtx *gin.Context) {
		ctx := newContextFun(rawCtx)
		if signatureErr != nil {
			ctx.HandleError(rawCtx.Request.RequestURI, signatureErr)
			return
		}

		params := reflect.New(reqType).Interface()
		if err := parseRequest(rawCtx, params); err != nil {
			ctx.HandleError(rawCtx.Request.RequestURI, err)
			return
		}

		rsp := reflect.New(rspType).Interface()
		rets := handlerValue.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(params), reflect.ValueOf(rsp)})
		ret := rets[0].Interface()
		if ret == nil {
			ctx.HandleSuccess(rsp)
			return
		}
		if err, ok := ret.(error); ok {
			ctx.HandleError(rawCtx.Request.RequestURI, err)
			return
		}
		// The first result is non-nil but not an error: report the actual value
		// (the original formatted a shadowed, always-nil err here).
		ctx.HandleError(rawCtx.Request.RequestURI, fmt.Errorf("handle request ret error, but parse return error:%+v not ok", ret))
	}
}

// validateReflectHandler returns why handlerType cannot be driven by wrapReflectHandler, or nil.
func validateReflectHandler(handlerType reflect.Type) error {
	if handlerType == nil || handlerType.Kind() != reflect.Func {
		return fmt.Errorf("register callback handler type(%v) is not a func", handlerType)
	}
	if numParams := handlerType.NumIn(); numParams != 3 {
		return fmt.Errorf("register callback handler params len(%v) invalid", numParams)
	}
	// req and rsp are allocated with reflect.New(T.Elem()), so both must be pointers.
	for i := 1; i < 3; i++ {
		if handlerType.In(i).Kind() != reflect.Ptr {
			return fmt.Errorf("register callback handler param %d type(%v) must be a pointer", i, handlerType.In(i))
		}
	}
	if handlerType.NumOut() < 1 {
		return fmt.Errorf("register callback handler type(%v) must return an error", handlerType)
	}
	return nil
}
// parseRequest fills params from the incoming request.
//
// params must be a pointer (getGinHandlerFunWithRequest passes the result of
// reflect.New). When the request body is non-empty it is decoded as JSON into
// params. When the body is empty, params must point to a struct, and each
// settable field is filled from the URL query. The query key is the name part
// of the field's `json` tag (options such as ",omitempty" are ignored), or the
// Go field name when the tag has no name. Fields tagged `json:"-"` are skipped.
// The optional `default` and `required` struct tags are forwarded to setValue.
//
// After reading, the body stream is replaced with an in-memory copy so later
// handlers in the same gin chain can read it again.
func parseRequest(rawCtx *gin.Context, params any) error {
	var bodyBuf []byte
	if body := rawCtx.Request.Body; body != nil {
		buf, err := io.ReadAll(body)
		if err != nil {
			return fmt.Errorf("read body error:%w", err)
		}
		bodyBuf = buf
		// ReadAll drained the stream, and every handler produced by
		// getGinHandlerFunWithRequest calls parseRequest. Without this, a
		// middleware in front of the route would leave the route an empty body.
		rawCtx.Request.Body = io.NopCloser(strings.NewReader(string(bodyBuf)))
	}

	if len(bodyBuf) > 0 {
		if err := json.Unmarshal(bodyBuf, params); err != nil {
			return fmt.Errorf("json unmarshal body error:%w", err)
		}
		return nil
	}

	reqValue := reflect.ValueOf(params).Elem()
	if reqValue.Kind() != reflect.Struct {
		return fmt.Errorf("parse uri params: request type %v is not a struct", reqValue.Type())
	}
	reqType := reqValue.Type()
	for i := 0; i < reqValue.NumField(); i++ {
		f := reqType.Field(i)
		field := reqValue.Field(i)
		if !field.CanSet() {
			continue
		}
		tag := f.Tag.Get("json")
		if tag == "-" {
			// Mirror encoding/json: the field is excluded from (de)serialization.
			continue
		}
		fieldTagName, _, _ := strings.Cut(tag, ",")
		if fieldTagName == "" {
			fieldTagName = f.Name
		}
		fieldStr := rawCtx.Query(fieldTagName)
		if err := setValue(field, fieldStr, f.Tag.Get("default"), f.Tag.Get("required")); err != nil {
			return fmt.Errorf("parse uri params field(%v) set value(%v) error:%w", f.Name, fieldStr, err)
		}
	}
	return nil
}
// setValue sets one struct field (a settable reflect.Value) from its textual
// value, converting the text to the field's kind.
//
// value falls back to defaultValue when empty. If it is still empty and
// required is "true", an error is returned.
//
// Supported kinds:
//   - string, bool, signed/unsigned integers and floats. Numbers are parsed in
//     base 10 with the field's bit size; an empty value yields the zero value.
//   - pointers to those. An empty value leaves the pointer untouched; otherwise
//     a nil pointer is allocated first.
//   - slices of those. The value is split on "," and each element is set
//     recursively with no default and not required.
//
// Struct fields and any other kind are rejected with an error.
func setValue(field reflect.Value, value string, defaultValue string, required string) error {
	if value == "" {
		value = defaultValue
	}
	if value == "" && required == "true" {
		return fmt.Errorf("field is required, please give a valid value")
	}
	if field.Kind() == reflect.Ptr {
		if value == "" {
			return nil
		}
		if field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		field = field.Elem()
	}
	switch field.Kind() {
	case reflect.String:
		field.SetString(value)
	case reflect.Bool:
		if value == "" {
			field.SetBool(false)
			break
		}
		b, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		field.SetBool(b)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if value == "" {
			field.SetInt(0)
			break
		}
		// Base 10, not 0: with base 0 a zero-padded query value such as "08"
		// is rejected as invalid octal and "010" silently becomes 8.
		i, err := strconv.ParseInt(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetInt(i)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if value == "" {
			field.SetUint(0)
			break
		}
		ui, err := strconv.ParseUint(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetUint(ui)
	case reflect.Float32, reflect.Float64:
		if value == "" {
			field.SetFloat(0)
			break
		}
		f, err := strconv.ParseFloat(value, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetFloat(f)
	case reflect.Struct:
		return fmt.Errorf("unsupport struct field:%v", field.Type())
	case reflect.Slice:
		values := strings.Split(value, ",")
		// strings.Split("", ",") yields [""]; treat that as an empty list.
		if len(values) == 1 && values[0] == "" {
			values = []string{}
		}
		field.Set(reflect.MakeSlice(field.Type(), len(values), len(values)))
		for i := 0; i < len(values); i++ {
			if err := setValue(field.Index(i), values[i], "", ""); err != nil {
				return err
			}
		}
	default:
		return fmt.Errorf("no support type %s", field.Type())
	}
	return nil
}