Fix router requiredRoutes
This commit is contained in:
parent
d0004bba70
commit
cb88fedc53
41
router.go
41
router.go
@ -6,35 +6,30 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Route struct {
|
|
||||||
pattern string
|
|
||||||
required bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type Router struct {
|
type Router struct {
|
||||||
*http.ServeMux
|
*http.ServeMux
|
||||||
routes map[Route]http.Handler
|
routes map[string]http.Handler
|
||||||
|
requiredRoutes []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) {
|
func NewRouter(mux *http.ServeMux, requiredRoutes ...string) (ro *Router) {
|
||||||
if mux == nil {
|
if mux == nil {
|
||||||
mux = http.NewServeMux()
|
mux = http.NewServeMux()
|
||||||
}
|
}
|
||||||
ro = &Router{
|
ro = &Router{
|
||||||
ServeMux: mux,
|
ServeMux: mux,
|
||||||
routes: make(map[Route]http.Handler),
|
routes: make(map[string]http.Handler),
|
||||||
|
requiredRoutes: make([]string, 0),
|
||||||
}
|
}
|
||||||
ro.AddRequiredRoutes(requiredRoutes)
|
ro.AddRequiredRoutes(requiredRoutes...)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ro *Router) AddRequiredRoutes(requiredRoutes []string) {
|
func (ro *Router) AddRequiredRoutes(requiredRoutes ...string) {
|
||||||
if requiredRoutes == nil {
|
if requiredRoutes == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, pattern := range requiredRoutes {
|
ro.requiredRoutes = append(ro.requiredRoutes, requiredRoutes...)
|
||||||
ro.routes[Route{pattern: pattern, required: true}] = nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ro *Router) Register(pattern string, server http.Handler) (err error) {
|
func (ro *Router) Register(pattern string, server http.Handler) (err error) {
|
||||||
@ -44,22 +39,20 @@ func (ro *Router) Register(pattern string, server http.Handler) (err error) {
|
|||||||
if server == nil {
|
if server == nil {
|
||||||
return fmt.Errorf("server must be provided")
|
return fmt.Errorf("server must be provided")
|
||||||
}
|
}
|
||||||
srv, required := ro.routes[Route{pattern: pattern, required: true}]
|
srv, exists := ro.routes[pattern]
|
||||||
if !required {
|
if exists && srv != nil {
|
||||||
srv, _ = ro.routes[Route{pattern: pattern, required: false}]
|
|
||||||
}
|
|
||||||
if srv != nil {
|
|
||||||
return fmt.Errorf("too many routes for same pattern (%s)", pattern)
|
return fmt.Errorf("too many routes for same pattern (%s)", pattern)
|
||||||
}
|
}
|
||||||
srv = server
|
ro.routes[pattern] = server
|
||||||
ro.ServeMux.Handle(string(pattern), srv)
|
ro.ServeMux.Handle(string(pattern), ro.routes[pattern])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ro *Router) Validate() (err error) {
|
func (ro *Router) Validate() (err error) {
|
||||||
for route, server := range ro.routes {
|
for _, requiredRoute := range ro.requiredRoutes {
|
||||||
if route.required && server == nil {
|
_, ok := ro.routes[requiredRoute]
|
||||||
err = fmt.Errorf("missing required route for pattern (%s)", route.pattern)
|
if !ok {
|
||||||
|
err = fmt.Errorf("missing required route (%s)", requiredRoute)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user