diff --git a/router.go b/router.go index c0125ef..3604ff2 100644 --- a/router.go +++ b/router.go @@ -6,35 +6,30 @@ import ( "strings" ) -type Route struct { - pattern string - required bool -} - type Router struct { *http.ServeMux - routes map[Route]http.Handler + routes map[string]http.Handler + requiredRoutes []string } -func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) { +func NewRouter(mux *http.ServeMux, requiredRoutes ...string) (ro *Router) { if mux == nil { mux = http.NewServeMux() } ro = &Router{ - ServeMux: mux, - routes: make(map[Route]http.Handler), + ServeMux: mux, + routes: make(map[string]http.Handler), + requiredRoutes: make([]string, 0), } - ro.AddRequiredRoutes(requiredRoutes) + ro.AddRequiredRoutes(requiredRoutes...) return } -func (ro *Router) AddRequiredRoutes(requiredRoutes []string) { +func (ro *Router) AddRequiredRoutes(requiredRoutes ...string) { if requiredRoutes == nil { return } - for _, pattern := range requiredRoutes { - ro.routes[Route{pattern: pattern, required: true}] = nil - } + ro.requiredRoutes = append(ro.requiredRoutes, requiredRoutes...) } func (ro *Router) Register(pattern string, server http.Handler) (err error) { @@ -44,22 +39,20 @@ func (ro *Router) Register(pattern string, server http.Handler) (err error) { if server == nil { return fmt.Errorf("server must be provided") } - srv, required := ro.routes[Route{pattern: pattern, required: true}] - if !required { - srv, _ = ro.routes[Route{pattern: pattern, required: false}] - } - if srv != nil { + srv, exists := ro.routes[pattern] + if exists && srv != nil { return fmt.Errorf("too many routes for same pattern (%s)", pattern) } - srv = server - ro.ServeMux.Handle(string(pattern), srv) + ro.routes[pattern] = server + ro.ServeMux.Handle(string(pattern), ro.routes[pattern]) return } func (ro *Router) Validate() (err error) { - for route, server := range ro.routes { - if route.required && server == nil { - err = fmt.Errorf("missing required route for pattern (%s)", route.pattern) + for _, requiredRoute := range ro.requiredRoutes { + _, ok := ro.routes[requiredRoute] + if !ok { + err = fmt.Errorf("missing required route (%s)", requiredRoute) return } }