2024-10-08 22:58:08 +00:00
|
|
|
package router
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Router wraps an http.ServeMux and additionally remembers which patterns were
// registered through Register, so that Validate can check them against a list
// of required patterns.
type Router struct {
	// ServeMux is the underlying multiplexer that actually dispatches requests.
	// Because it is embedded, its Handle/HandleFunc methods are promoted onto
	// Router; registering through those directly does not update routes and is
	// therefore invisible to Validate.
	*http.ServeMux

	// routes maps each pattern accepted by Register to its handler.
	routes map[string]http.Handler

	// requiredRoutes holds the patterns Validate expects to find in routes,
	// in the order they were added.
	requiredRoutes []string
}
|
|
|
|
|
2024-10-10 20:04:40 +00:00
|
|
|
func NewRouter(mux *http.ServeMux, requiredRoutes ...string) (ro *Router) {
|
2024-10-08 22:58:08 +00:00
|
|
|
if mux == nil {
|
|
|
|
mux = http.NewServeMux()
|
|
|
|
}
|
|
|
|
ro = &Router{
|
2024-10-10 20:04:40 +00:00
|
|
|
ServeMux: mux,
|
|
|
|
routes: make(map[string]http.Handler),
|
|
|
|
requiredRoutes: make([]string, 0),
|
2024-10-08 22:58:08 +00:00
|
|
|
}
|
2024-10-10 20:04:40 +00:00
|
|
|
ro.AddRequiredRoutes(requiredRoutes...)
|
2024-10-08 23:06:09 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-10-10 20:04:40 +00:00
|
|
|
func (ro *Router) AddRequiredRoutes(requiredRoutes ...string) {
|
2024-10-08 23:06:09 +00:00
|
|
|
if requiredRoutes == nil {
|
|
|
|
return
|
|
|
|
}
|
2024-10-10 20:04:40 +00:00
|
|
|
ro.requiredRoutes = append(ro.requiredRoutes, requiredRoutes...)
|
2024-10-08 22:58:08 +00:00
|
|
|
}
|
|
|
|
|
2024-10-10 18:26:07 +00:00
|
|
|
func (ro *Router) Register(pattern string, server http.Handler) (err error) {
|
2024-10-08 22:58:08 +00:00
|
|
|
if !strings.HasPrefix(pattern, "/") {
|
|
|
|
return fmt.Errorf("missing preceding slash in pattern (%s)", pattern)
|
|
|
|
}
|
|
|
|
if server == nil {
|
2024-10-09 21:09:12 +00:00
|
|
|
return fmt.Errorf("server must be provided")
|
2024-10-08 22:58:08 +00:00
|
|
|
}
|
2024-10-10 20:04:40 +00:00
|
|
|
srv, exists := ro.routes[pattern]
|
|
|
|
if exists && srv != nil {
|
2024-10-08 22:58:08 +00:00
|
|
|
return fmt.Errorf("too many routes for same pattern (%s)", pattern)
|
|
|
|
}
|
2024-10-10 20:04:40 +00:00
|
|
|
ro.routes[pattern] = server
|
|
|
|
ro.ServeMux.Handle(string(pattern), ro.routes[pattern])
|
2024-10-08 22:58:08 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ro *Router) Validate() (err error) {
|
2024-10-10 20:04:40 +00:00
|
|
|
for _, requiredRoute := range ro.requiredRoutes {
|
|
|
|
_, ok := ro.routes[requiredRoute]
|
|
|
|
if !ok {
|
|
|
|
err = fmt.Errorf("missing required route (%s)", requiredRoute)
|
2024-10-08 22:58:08 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func (ro *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
|
|
if err := ro.Validate(); err != nil {
|
|
|
|
panic(err)
|
|
|
|
}
|
|
|
|
ro.ServeMux.ServeHTTP(w, r)
|
|
|
|
}
|