From fa115c269101ee52384e663e5c4a0cf44272d902 Mon Sep 17 00:00:00 2001 From: yujular <22210720287@m.fudan.edu.cn> Date: Thu, 19 Oct 2023 15:51:55 +0800 Subject: [PATCH] feat: update validator, add more message --- errors.go | 146 ++++++++++++++++++++++++++++++++++++++++--------- go.mod | 4 +- go.sum | 6 ++ logger.go | 66 +++++++++++++++++++++- logger_test.go | 2 +- validate.go | 86 ++++++++++++++++++++++++++++- 6 files changed, 279 insertions(+), 31 deletions(-) diff --git a/errors.go b/errors.go index 85a1a9f..5cfe4fd 100644 --- a/errors.go +++ b/errors.go @@ -2,85 +2,173 @@ package common import ( "errors" + "reflect" "strconv" + "strings" "github.com/gofiber/fiber/v2" "gorm.io/gorm" ) -type HttpError struct { - Code int `json:"code,omitempty"` - Message string `json:"message,omitempty"` - Detail *ErrorDetail `json:"detail,omitempty"` +// ValidateFieldError is the error detail for validation errors +// +// see validator.FieldError for more information +type ValidateFieldError struct { + // Tag is the validation tag that failed. + // use alias if defined + // + // e.g. "required", "min", "max", etc. + Tag string `json:"tag"` + + // Field is the field name that failed validation + // use registered tag name if registered + Field string `json:"field"` + + // Kind is the kind of the field type + Kind reflect.Kind `json:"-"` + + // Param is the parameter for the validation + Param string `json:"param"` + + // Value is the actual value that failed validation + Value any `json:"value"` + + // Message is the error message + Message string `json:"message"` +} + +func (e *ValidateFieldError) Error() string { + if e.Message != "" { + return e.Message + } + + // construct error message + // if you create a custom validation tag, you may need to switch case here + switch e.Tag { + case "min": + if e.Kind == reflect.String { + e.Message = e.Field + "至少" + e.Param + "字符" + } else { + e.Message = e.Field + "至少为" + e.Param + } + case "max": + if e.Kind == reflect.String { + e.Message = e.Field + "限长" + e.Param + "字符" + } else { + e.Message = e.Field + "至多为" + e.Param + } + case "required": + e.Message = e.Field + "不能为空" + case "email": + e.Message = e.Field + "格式不正确" + } + + return e.Message +} + +// ValidationErrors is a list of ValidateFieldError +// for use in custom error messages post validation +// +// see validator.ValidationErrors for more information +type ValidationErrors []ValidateFieldError + +func (e ValidationErrors) Error() string { + if len(e) == 0 { + return "Validation Error" + } + + if len(e) == 1 { + return e[0].Error() + } + + var stringBuilder strings.Builder + stringBuilder.WriteString(e[0].Error()) + for _, err := range e[1:] { + stringBuilder.WriteString(", ") + stringBuilder.WriteString(err.Error()) + } + return stringBuilder.String() +} + +type HttpBaseError struct { + Code int `json:"code"` + Message string `json:"message"` } -func (e *HttpError) Error() string { +func (e *HttpBaseError) Error() string { return e.Message } -func BadRequest(messages ...string) *HttpError { +type HttpError struct { + HttpBaseError + ValidationDetail ValidationErrors `json:"validation_detail,omitempty"` +} + +func BadRequest(messages ...string) *HttpBaseError { message := "Bad Request" if len(messages) > 0 { message = messages[0] } - return &HttpError{ + return &HttpBaseError{ Code: 400, Message: message, } } -func Unauthorized(messages ...string) *HttpError { +func Unauthorized(messages ...string) *HttpBaseError { message := "Invalid JWT Token" if len(messages) > 0 { message = messages[0] } - return &HttpError{ + return &HttpBaseError{ Code: 401, Message: message, } } -func Forbidden(messages ...string) *HttpError { +func Forbidden(messages ...string) *HttpBaseError { message := "Forbidden" if len(messages) > 0 { message = messages[0] } - return &HttpError{ + return &HttpBaseError{ Code: 403, Message: message, } } -func NotFound(messages ...string) *HttpError { +func NotFound(messages ...string) *HttpBaseError { message := "Not Found" if len(messages) > 0 { message = messages[0] } - return &HttpError{ + return &HttpBaseError{ Code: 404, Message: message, } } -func InternalServerError(messages ...string) *HttpError { +func InternalServerError(messages ...string) *HttpBaseError { message := "Internal Server Error" if len(messages) > 0 { message = messages[0] } - return &HttpError{ + return &HttpBaseError{ Code: 500, Message: message, } } -func ErrorHandler(ctx *fiber.Ctx, err error) error { +func ErrorHandler(c *fiber.Ctx, err error) error { if err == nil { return nil } - httpError := HttpError{ - Code: 500, - Message: err.Error(), + var httpError = HttpError{ + HttpBaseError: HttpBaseError{ + Code: 500, + Message: err.Error(), + }, } if errors.Is(err, gorm.ErrRecordNotFound) { @@ -89,17 +177,25 @@ func ErrorHandler(ctx *fiber.Ctx, err error) error { switch e := err.(type) { case *HttpError: httpError = *e + case *HttpBaseError: + httpError.HttpBaseError = *e case *fiber.Error: httpError.Code = e.Code - case *ErrorDetail: + case *ValidationErrors: httpError.Code = 400 - httpError.Detail = e + httpError.ValidationDetail = *e case fiber.MultiError: httpError.Code = 400 - httpError.Message = "" - for _, err = range e { - httpError.Message += err.Error() + "\n" + + var stringBuilder strings.Builder + for _, err := range e { + stringBuilder.WriteString(err.Error()) + stringBuilder.WriteString("\n") } + httpError.Message = stringBuilder.String() + default: + httpError.Code = 500 + httpError.Message = err.Error() } } @@ -118,5 +214,5 @@ func ErrorHandler(ctx *fiber.Ctx, err error) error { } } - return ctx.Status(statusCode).JSON(&httpError) + return c.Status(statusCode).JSON(&httpError) } diff --git a/go.mod b/go.mod index c983ae8..5d44a84 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,13 @@ go 1.20 require ( github.com/creasty/defaults v1.7.0 - github.com/go-playground/validator/v10 v10.15.1 + github.com/go-playground/validator/v10 v10.15.4 github.com/goccy/go-json v0.10.2 github.com/gofiber/fiber/v2 v2.48.0 github.com/hetiansu5/urlquery v1.2.7 github.com/rs/zerolog v1.30.0 github.com/stretchr/testify v1.8.4 + go.uber.org/zap v1.26.0 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 gorm.io/gorm v1.25.4 ) @@ -33,6 +34,7 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.48.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect + go.uber.org/multierr v1.10.0 // indirect golang.org/x/crypto v0.12.0 // indirect golang.org/x/net v0.14.0 // indirect golang.org/x/sys v0.11.0 // indirect diff --git a/go.sum b/go.sum index 5244ad4..e27ffb9 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,7 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.15.1 h1:BSe8uhN+xQ4r5guV/ywQI4gO59C2raYcGffYWZEjZzM= github.com/go-playground/validator/v10 v10.15.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-playground/validator/v10 v10.15.4/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -64,6 +65,11 @@ github.com/valyala/fasthttp v1.48.0 h1:oJWvHb9BIZToTQS3MuQ2R3bJZiNSa2KiNdeI8A+79 github.com/valyala/fasthttp v1.48.0/go.mod h1:k2zXd82h/7UZc3VOdJ2WaUqt1uZ/XpXAfE9i+HBC3lA= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= +go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= diff --git a/logger.go b/logger.go index e61303d..f02feb9 100644 --- a/logger.go +++ b/logger.go @@ -4,13 +4,77 @@ import ( "os" "time" + "github.com/gofiber/fiber/v2" "github.com/rs/zerolog" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) -var Logger = zerolog.New(os.Stdout).With().Timestamp().Logger() +var LoggerOld = zerolog.New(os.Stdout).With().Timestamp().Logger() func init() { // compatible with zap and old spec zerolog.MessageFieldName = "msg" zerolog.TimeFieldFormat = time.RFC3339Nano } + +const KEY = "zapLogger" + +type Logger struct { + *zap.Logger +} + +func NewLogger() (*Logger, func()) { + // Info Level + logger, err := initZap() + if err != nil { + panic(err) + } + return &Logger{Logger: logger}, func() { + _ = logger.Sync() + } +} + +func initZap() (*zap.Logger, error) { + var atomicLevel zapcore.Level + // Info Level, production env + atomicLevel = zapcore.InfoLevel + + logConfig := zap.Config{ + Level: zap.NewAtomicLevelAt(atomicLevel), + Development: false, + Encoding: "json", + EncoderConfig: zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + MessageKey: "msg", + StacktraceKey: "stacktrace", + EncodeLevel: zapcore.LowercaseLevelEncoder, + EncodeTime: zapcore.RFC3339NanoTimeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeName: zapcore.FullNameEncoder, + }, + OutputPaths: []string{"stdout"}, + ErrorOutputPaths: []string{"stderr"}, + } + + return logConfig.Build(zap.AddStacktrace(zap.ErrorLevel), zap.AddCaller()) +} + +// NewContext Adds a field to the specified context +func (l *Logger) NewContext(c *fiber.Ctx, fields ...zapcore.Field) { + c.Locals(KEY, &Logger{l.WithContext(c).With(fields...)}) +} + +// WithContext Returns a zap instance from the specified context +func (l *Logger) WithContext(c *fiber.Ctx) *Logger { + if c == nil { + return l + } + ctxLogger, ok := c.Locals(KEY).(*Logger) + if ok { + return ctxLogger + } + return l +} diff --git a/logger_test.go b/logger_test.go index a7725b6..e12011c 100644 --- a/logger_test.go +++ b/logger_test.go @@ -11,5 +11,5 @@ func TestLog(t *testing.T) { log.Info().Msg("hello world") - Logger.Info().Msg("hello world") + LoggerOld.Info().Msg("hello world") } diff --git a/validate.go b/validate.go index 4c6ad6c..4d07863 100644 --- a/validate.go +++ b/validate.go @@ -1,12 +1,15 @@ package common import ( + "context" + "errors" "reflect" "strings" "github.com/creasty/defaults" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" + "go.uber.org/zap" ) type ErrorDetailElement struct { @@ -23,6 +26,8 @@ func (e *ErrorDetail) Error() string { var Validate = validator.New() +var logger *Logger + func init() { Validate.RegisterTagNameFunc(func(fld reflect.StructField) string { name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] @@ -35,7 +40,7 @@ func init() { }) } -func ValidateStruct(model any) error { +func ValidateStructOld(model any) error { errors := Validate.Struct(model) if errors != nil { var errorDetail ErrorDetail @@ -52,6 +57,40 @@ func ValidateStruct(model any) error { return nil } +func ValidateStruct(ctx context.Context, model any) error { + err := Validate.StructCtx(ctx, model) + + if err == nil { + return nil + } + + var rawValidationErrors validator.ValidationErrors + if ok := errors.As(err, &rawValidationErrors); ok { + var validationErrors ValidationErrors + for _, fe := range rawValidationErrors { + validationErrors = append(validationErrors, + ValidateFieldError{ + Tag: fe.Tag(), + Field: fe.Field(), + Kind: fe.Kind(), + Param: fe.Param(), + Value: fe.Value(), + }, + ) + } + return &validationErrors + } + + var invalidValidationError *validator.InvalidValidationError + if ok := errors.As(err, &invalidValidationError); ok { + logger.Error("invalid validation error", zap.Error(err)) + return err + } + + logger.Error("unknown validation error", zap.Error(err)) + return err +} + // ValidateQuery parse, set default and validate query into model func ValidateQuery(c *fiber.Ctx, model any) error { // parse query into struct @@ -68,7 +107,25 @@ func ValidateQuery(c *fiber.Ctx, model any) error { } // Validate - return ValidateStruct(model) + return ValidateStruct(c.Context(), model) +} + +func ValidateQueryOld(c *fiber.Ctx, model any) error { + // parse query into struct + // see https://docs.gofiber.io/api/ctx/#queryparser + err := c.QueryParser(model) + if err != nil { + return BadRequest(err.Error()) + } + + // set default value + err = defaults.Set(model) + if err != nil { + return err + } + + // Validate + return ValidateStructOld(model) } // ValidateBody parse, set default and validate body based on Content-Type. @@ -76,6 +133,29 @@ func ValidateQuery(c *fiber.Ctx, model any) error { func ValidateBody(c *fiber.Ctx, model any) error { body := c.Body() + // empty request body, return default value + if len(body) > 0 { + // parse json, xml and form by fiber.BodyParser into struct + // see https://docs.gofiber.io/api/ctx/#bodyparser + err := c.BodyParser(model) + if err != nil { + return BadRequest(err.Error()) + } + } + + // set default value + err := defaults.Set(model) + if err != nil { + return err + } + + // Validate + return ValidateStruct(c.Context(), model) +} + +func ValidateBodyOld(c *fiber.Ctx, model any) error { + body := c.Body() + // empty request body, return default value if len(body) == 0 { return defaults.Set(model) @@ -95,5 +175,5 @@ func ValidateBody(c *fiber.Ctx, model any) error { } // Validate - return ValidateStruct(model) + return ValidateStructOld(model) }