Fix router requiredRoutes

This commit is contained in:
some 2024-10-10 16:04:40 -04:00
parent d0004bba70
commit cb88fedc53
Signed by: some
GPG Key ID: 65D0589220B9BFC8

View File

@ -6,35 +6,30 @@ import (
"strings" "strings"
) )
type Route struct {
pattern string
required bool
}
type Router struct { type Router struct {
*http.ServeMux *http.ServeMux
routes map[Route]http.Handler routes map[string]http.Handler
requiredRoutes []string
} }
func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) { func NewRouter(mux *http.ServeMux, requiredRoutes ...string) (ro *Router) {
if mux == nil { if mux == nil {
mux = http.NewServeMux() mux = http.NewServeMux()
} }
ro = &Router{ ro = &Router{
ServeMux: mux, ServeMux: mux,
routes: make(map[Route]http.Handler), routes: make(map[string]http.Handler),
requiredRoutes: make([]string, 0),
} }
ro.AddRequiredRoutes(requiredRoutes) ro.AddRequiredRoutes(requiredRoutes...)
return return
} }
func (ro *Router) AddRequiredRoutes(requiredRoutes []string) { func (ro *Router) AddRequiredRoutes(requiredRoutes ...string) {
if requiredRoutes == nil { if requiredRoutes == nil {
return return
} }
for _, pattern := range requiredRoutes { ro.requiredRoutes = append(ro.requiredRoutes, requiredRoutes...)
ro.routes[Route{pattern: pattern, required: true}] = nil
}
} }
func (ro *Router) Register(pattern string, server http.Handler) (err error) { func (ro *Router) Register(pattern string, server http.Handler) (err error) {
@ -44,22 +39,20 @@ func (ro *Router) Register(pattern string, server http.Handler) (err error) {
if server == nil { if server == nil {
return fmt.Errorf("server must be provided") return fmt.Errorf("server must be provided")
} }
srv, required := ro.routes[Route{pattern: pattern, required: true}] srv, exists := ro.routes[pattern]
if !required { if exists && srv != nil {
srv, _ = ro.routes[Route{pattern: pattern, required: false}]
}
if srv != nil {
return fmt.Errorf("too many routes for same pattern (%s)", pattern) return fmt.Errorf("too many routes for same pattern (%s)", pattern)
} }
srv = server ro.routes[pattern] = server
ro.ServeMux.Handle(string(pattern), srv) ro.ServeMux.Handle(string(pattern), ro.routes[pattern])
return return
} }
func (ro *Router) Validate() (err error) { func (ro *Router) Validate() (err error) {
for route, server := range ro.routes { for _, requiredRoute := range ro.requiredRoutes {
if route.required && server == nil { _, ok := ro.routes[requiredRoute]
err = fmt.Errorf("missing required route for pattern (%s)", route.pattern) if !ok {
err = fmt.Errorf("missing required route (%s)", requiredRoute)
return return
} }
} }