304 lines
8.4 KiB
Go
304 lines
8.4 KiB
Go
package web
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
type routerGroupGin struct {
|
|
group *gin.RouterGroup
|
|
newContextFun func() Context
|
|
}
|
|
|
|
func (group *routerGroupGin) Use(middlewares ...HandlerFunc) {
|
|
group.group.Use(getGinHandlerFun(group.newContextFun, middlewares...)...)
|
|
}
|
|
func (group *routerGroupGin) Group(path string, handlers ...HandlerFunc) routeGroupInterface {
|
|
ginGroup := group.group.Group(path, getGinHandlerFun(group.newContextFun, handlers...)...)
|
|
group1 := &routerGroupGin{group: ginGroup, newContextFun: group.newContextFun}
|
|
return group1
|
|
}
|
|
func (group *routerGroupGin) Get(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
group.group.GET(path, getGinHandlerFun(group.newContextFun, handlers...)...)
|
|
} else {
|
|
group.group.GET(path, getGinHandlerFunWithRequest(group.newContextFun, req, handlers...)...)
|
|
}
|
|
return group
|
|
}
|
|
func (group *routerGroupGin) Post(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
group.group.POST(path, getGinHandlerFun(group.newContextFun, handlers...)...)
|
|
} else {
|
|
group.group.POST(path, getGinHandlerFunWithRequest(group.newContextFun, req, handlers...)...)
|
|
}
|
|
return group
|
|
}
|
|
|
|
func (group *routerGroupGin) Put(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
group.group.PUT(path, getGinHandlerFun(group.newContextFun, handlers...)...)
|
|
} else {
|
|
group.group.PUT(path, getGinHandlerFunWithRequest(group.newContextFun, req, handlers...)...)
|
|
}
|
|
return group
|
|
}
|
|
|
|
func (group *routerGroupGin) Delete(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
group.group.DELETE(path, getGinHandlerFun(group.newContextFun, handlers...)...)
|
|
} else {
|
|
group.group.DELETE(path, getGinHandlerFunWithRequest(group.newContextFun, req, handlers...)...)
|
|
}
|
|
return group
|
|
}
|
|
|
|
type routerGin struct {
|
|
engine *gin.Engine
|
|
newContextFun func() Context
|
|
}
|
|
|
|
func newRouterGin(newContextFun func() Context) routerInterface {
|
|
engine := gin.Default()
|
|
router := &routerGin{
|
|
engine: engine,
|
|
newContextFun: newContextFun,
|
|
}
|
|
return router
|
|
}
|
|
|
|
func (router *routerGin) Use(middlewares ...HandlerFunc) {
|
|
router.engine.Use(getGinHandlerFun(router.newContextFun, middlewares...)...)
|
|
}
|
|
func (router *routerGin) Group(path string, handlers ...HandlerFunc) routeGroupInterface {
|
|
ginGroup := router.engine.Group(path, getGinHandlerFun(router.newContextFun, handlers...)...)
|
|
group := &routerGroupGin{
|
|
group: ginGroup,
|
|
newContextFun: router.newContextFun,
|
|
}
|
|
return group
|
|
}
|
|
func (router *routerGin) Get(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
router.engine.GET(path, getGinHandlerFun(router.newContextFun, handlers...)...)
|
|
} else {
|
|
router.engine.GET(path, getGinHandlerFunWithRequest(router.newContextFun, req, handlers...)...)
|
|
}
|
|
return router
|
|
}
|
|
func (router *routerGin) Post(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
router.engine.POST(path, getGinHandlerFun(router.newContextFun, handlers...)...)
|
|
} else {
|
|
router.engine.POST(path, getGinHandlerFunWithRequest(router.newContextFun, req, handlers...)...)
|
|
}
|
|
return router
|
|
}
|
|
|
|
func (router *routerGin) Put(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
router.engine.PUT(path, getGinHandlerFun(router.newContextFun, handlers...)...)
|
|
} else {
|
|
router.engine.PUT(path, getGinHandlerFunWithRequest(router.newContextFun, req, handlers...)...)
|
|
}
|
|
return router
|
|
}
|
|
|
|
func (router *routerGin) Delete(path string, desc string, req any, handlers ...HandlerFunc) routeGroupInterface {
|
|
if req == nil {
|
|
router.engine.DELETE(path, getGinHandlerFun(router.newContextFun, handlers...)...)
|
|
} else {
|
|
router.engine.DELETE(path, getGinHandlerFunWithRequest(router.newContextFun, req, handlers...)...)
|
|
}
|
|
return router
|
|
}
|
|
|
|
func getGinHandlerFun(newContextFun func() Context, handlers ...HandlerFunc) []gin.HandlerFunc {
|
|
list := make([]gin.HandlerFunc, 0, len(handlers))
|
|
for _, handler := range handlers {
|
|
list = append(list, func(rawCtx *gin.Context) {
|
|
rawCtx1 := &ginCtx{ctx: rawCtx}
|
|
ctx := newContextFun()
|
|
ctx.SetRawContext(rawCtx1)
|
|
reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(ctx)})
|
|
if rawCtx.IsAborted() {
|
|
return
|
|
}
|
|
})
|
|
}
|
|
return list
|
|
}
|
|
|
|
func getGinHandlerFunWithRequest(newContextFun func() Context, requestTemplate any, handlers ...HandlerFunc) []gin.HandlerFunc {
|
|
list := make([]gin.HandlerFunc, 0, len(handlers))
|
|
for _, handler := range handlers {
|
|
list = append(list, func(rawCtx *gin.Context) {
|
|
rawCtx1 := &ginCtx{ctx: rawCtx}
|
|
ctx := newContextFun()
|
|
ctx.SetRawContext(rawCtx1)
|
|
request, err := rawCtx1.parseRequest(requestTemplate)
|
|
if err != nil {
|
|
ParseParamsErrorHandler(ctx, rawCtx.Request.RequestURI, err)
|
|
return
|
|
}
|
|
reflect.ValueOf(handler).Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(request)})
|
|
})
|
|
}
|
|
return list
|
|
}
|
|
|
|
func (router *routerGin) Run(addr string) error {
|
|
return router.engine.Run(addr)
|
|
}
|
|
|
|
type ginCtx struct {
|
|
ctx *gin.Context
|
|
}
|
|
|
|
func (ctx *ginCtx) Json(code int, v any) {
|
|
ctx.ctx.JSON(code, v)
|
|
}
|
|
|
|
func (ctx *ginCtx) Writer() io.Writer {
|
|
return ctx.ctx.Writer
|
|
}
|
|
|
|
func (ctx *ginCtx) body() ([]byte, error) {
|
|
buf, err := io.ReadAll(ctx.ctx.Request.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read body error:%v", err)
|
|
}
|
|
|
|
return buf, nil
|
|
}
|
|
|
|
func (ctx *ginCtx) parseRequest(request any) (any, error) {
|
|
if request == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
requestType := reflect.TypeOf(request)
|
|
newRequest := reflect.New(requestType).Interface()
|
|
|
|
body, err := ctx.body()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(body) == 0 {
|
|
newRequestValue := reflect.ValueOf(newRequest).Elem()
|
|
newRequestValueType := newRequestValue.Type()
|
|
for i := 0; i < newRequestValue.NumField(); i++ {
|
|
f := newRequestValueType.Field(i)
|
|
fieldTagName := f.Tag.Get("json")
|
|
if fieldTagName == "" {
|
|
fieldTagName = f.Name
|
|
}
|
|
|
|
field := newRequestValue.Field(i)
|
|
if !field.CanSet() {
|
|
continue
|
|
}
|
|
fieldStr := ctx.ctx.Query(fieldTagName)
|
|
err := setValue(field, fieldStr, f.Tag.Get("default"), f.Tag.Get("required"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse uri params field(%v) set value(%v) error:%v", f.Name, fieldStr, err)
|
|
}
|
|
}
|
|
} else {
|
|
err = json.Unmarshal(body, newRequest)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("json unmarshal body error:%v", err)
|
|
}
|
|
}
|
|
return newRequest, nil
|
|
}
|
|
|
|
// setValue assigns a textual value to a single struct field.
//
// value is the raw input string (for example a query parameter). If it is
// empty, defaultValue is used instead. If the result is still empty and
// required is "true", an error is returned.
//
// Supported field kinds:
//   - string, bool
//   - signed and unsigned integers, parsed as base-10 and range-checked
//     against the field's bit size
//   - float32/float64
//   - slices of the above, given as a comma-separated list
//   - pointers to any of the above
//
// Empty values behave as follows:
//   - A pointer field is left untouched (nil stays nil) when there is no
//     value; otherwise it is allocated on demand.
//   - A bool, integer or float field with no value is set to its zero value.
//   - A slice with no value becomes an empty, non-nil slice.
//
// Struct fields and every other kind are rejected with an error.
func setValue(field reflect.Value, value string, defaultValue string, required string) error {
	if value == "" {
		value = defaultValue
	}
	if value == "" && required == "true" {
		return fmt.Errorf("field is required, please give a valid value")
	}

	if field.Kind() == reflect.Ptr {
		if value == "" {
			return nil
		}
		if field.IsNil() {
			field.Set(reflect.New(field.Type().Elem()))
		}
		field = field.Elem()
	}

	switch field.Kind() {
	case reflect.String:
		field.SetString(value)

	case reflect.Bool:
		if value == "" {
			field.SetBool(false)
			return nil
		}
		b, err := strconv.ParseBool(value)
		if err != nil {
			return err
		}
		field.SetBool(b)

	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if value == "" {
			field.SetInt(0)
			return nil
		}
		// Base 10 on purpose: with base 0 a leading zero means octal, so
		// "010" would become 8 and "08" would be rejected.
		i, err := strconv.ParseInt(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetInt(i)

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if value == "" {
			field.SetUint(0)
			return nil
		}
		ui, err := strconv.ParseUint(value, 10, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetUint(ui)

	case reflect.Float32, reflect.Float64:
		if value == "" {
			field.SetFloat(0)
			return nil
		}
		f, err := strconv.ParseFloat(value, field.Type().Bits())
		if err != nil {
			return err
		}
		field.SetFloat(f)

	case reflect.Struct:
		return fmt.Errorf("unsupport struct field:%v", field.Type())

	case reflect.Slice:
		// strings.Split("", ",") yields [""]; treat that as "no elements".
		var values []string
		if value != "" {
			values = strings.Split(value, ",")
		}
		field.Set(reflect.MakeSlice(field.Type(), len(values), len(values)))
		for i, v := range values {
			if err := setValue(field.Index(i), v, "", ""); err != nil {
				return err
			}
		}

	default:
		return fmt.Errorf("no support type %s", field.Type())
	}
	return nil
}
|