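// Source file: router/router.go

// Package router wraps an http.ServeMux with a route table that records which
// URL patterns are required and which server is registered for each, so that a
// missing required route can be detected before requests are served.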
package router
import (
"fmt"
"net/http"
"strings"
)
type (
	// Route is the key of a Router's route table: a URL pattern plus a flag
	// recording whether that pattern was declared as required.
	Route struct {
		pattern  string
		required bool
	}

	// Router is an http.ServeMux that additionally tracks which patterns have
	// been registered through Register and which patterns are required.
	//
	// NOTE(review): the embedded ServeMux promotes Handle/HandleFunc; handlers
	// mounted through those bypass the routes table and are invisible to
	// Validate.
	Router struct {
		*http.ServeMux
		// routes maps each known route to its server. A nil value means the
		// pattern is known (e.g. declared required) but nothing is registered.
		routes map[Route]*server
	}
)
func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) {
if mux == nil {
mux = http.NewServeMux()
}
ro = &Router{
ServeMux: mux,
routes: make(map[Route]*server),
}
2024-10-08 23:06:09 +00:00
ro.AddRequiredRoutes(requiredRoutes)
return
}
func (ro *Router) AddRequiredRoutes(requiredRoutes []string) {
if requiredRoutes == nil {
return
}
2024-10-08 22:58:08 +00:00
for _, pattern := range requiredRoutes {
2024-10-08 23:09:03 +00:00
ro.routes[Route{pattern: pattern, required: true}] = nil
2024-10-08 22:58:08 +00:00
}
}
// Register mounts server on pattern in the underlying ServeMux and records it
// in the route table so that Validate can see it.
//
// pattern must start with "/" and server must be non-nil. If pattern has been
// declared required, the server is recorded against that required entry;
// otherwise it is recorded as an optional route. Registering a pattern that
// already has a server returns an error instead of reaching ServeMux.Handle,
// which panics on duplicate patterns.
func (ro *Router) Register(pattern string, server *server) (err error) {
	if !strings.HasPrefix(pattern, "/") {
		return fmt.Errorf("missing preceding slash in pattern (%s)", pattern)
	}
	if server == nil {
		return fmt.Errorf("server must be set in register")
	}

	required := Route{pattern: pattern, required: true}
	optional := Route{pattern: pattern, required: false}
	if ro.routes[required] != nil || ro.routes[optional] != nil {
		return fmt.Errorf("too many routes for same pattern (%s)", pattern)
	}

	key := optional
	if _, ok := ro.routes[required]; ok {
		key = required
	}

	// Mount first so that a panic from the mux (e.g. a conflicting pattern)
	// leaves the route table untouched.
	ro.ServeMux.Handle(pattern, server)
	// Persist the server: previously it was only assigned to a local variable,
	// so required routes never satisfied Validate.
	ro.routes[key] = server
	return nil
}
// Validate returns an error if any required route has no server registered,
// and nil otherwise. Optional routes are ignored.
//
// When several required routes are missing, the lexically smallest pattern
// is reported. This keeps the error independent of Go's randomized map
// iteration order.
func (ro *Router) Validate() (err error) {
	missing, found := "", false
	for route, srv := range ro.routes {
		if !route.required || srv != nil {
			continue
		}
		if !found || route.pattern < missing {
			missing, found = route.pattern, true
		}
	}
	if !found {
		return nil
	}
	return fmt.Errorf("missing required route for pattern (%s)", missing)
}
// ServeHTTP implements http.Handler. It runs Validate on every request and
// panics with Validate's error if a required route is still unregistered.
// Otherwise it hands the request to the embedded ServeMux.
//
// NOTE(review): Validate scans the whole route table, so this check repeats on
// every request; validating once at startup may be preferable.
func (ro *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	validationErr := ro.Validate()
	if validationErr == nil {
		ro.ServeMux.ServeHTTP(w, r)
		return
	}
	panic(validationErr)
}