diff --git a/default.go b/default.go index 8e31dc3..67396ad 100644 --- a/default.go +++ b/default.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "net/url" ) type DefaultError uint32 @@ -66,6 +67,7 @@ func (e DefaultError) BodyBytes() (body []byte) { type DefaultRequestBuilder struct { errorHandler ErrorHandler allowedMethods *[]string + url *url.URL header struct { *Header fields header @@ -89,8 +91,13 @@ func (rb *DefaultRequestBuilder) SetErrorHandler(errorHandler ErrorHandler) *Def return rb } -func (rb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods []string) *DefaultRequestBuilder { - rb.allowedMethods = &allowedMethods +func (rb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods *[]string) *DefaultRequestBuilder { + rb.allowedMethods = allowedMethods + return rb +} + +func (rb *DefaultRequestBuilder) SetUrl(url *url.URL) *DefaultRequestBuilder { + rb.url = url return rb } @@ -120,6 +127,10 @@ func (rb *DefaultRequestBuilder) SetDefaults() *DefaultRequestBuilder { amd := make([]string, 0) rb.allowedMethods = &amd } + if rb.url == nil { + u := url.URL{} + rb.url = &u + } if rb.header.Header == nil || rb.header.fields == nil { hd := make(Header) rb.header.Header = &hd @@ -151,7 +162,12 @@ func (rb *DefaultRequestBuilder) Allowed(method string) (errRes ErrorResponse) { if !ok { return rb.errorHandler.ErrorResponse(DefaultErrorMethodNotAllowed) } - return DefaultOk + return rb.errorHandler.ErrorResponse(DefaultOk) +} + +func (rb *DefaultRequestBuilder) Url(url url.URL) (errRes ErrorResponse) { + *rb.url = url + return rb.errorHandler.ErrorResponse(DefaultOk) } func (rb *DefaultRequestBuilder) Header(header Header) (errRes ErrorResponse) { diff --git a/prototype.go b/prototype.go index 9cdc559..8039b2e 100644 --- a/prototype.go +++ b/prototype.go @@ -1,10 +1,14 @@ package router -import "io" +import ( + "io" + "net/url" +) type RequestBuilder interface { Prototype() (req PrototypeRequestBuilder) Allowed(method string) (errRes ErrorResponse) + Url(url url.URL) (errRes ErrorResponse) Header(header Header) (errRes ErrorResponse) ReadBody(body io.ReadCloser) (errRes ErrorResponse) Values(values Values) (errRes ErrorResponse) diff --git a/router.go b/router.go index 0bee885..f8fb855 100644 --- a/router.go +++ b/router.go @@ -42,7 +42,7 @@ func (ro *Router) Register(pattern string, server *server) (err error) { return fmt.Errorf("missing preceding slash in pattern (%s)", pattern) } if server == nil { - return fmt.Errorf("server must be set in register") + return fmt.Errorf("server must be provided") } srv, required := ro.routes[Route{pattern: pattern, required: true}] if !required { diff --git a/server.go b/server.go index 92ebda7..71b56cc 100644 --- a/server.go +++ b/server.go @@ -83,6 +83,9 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if ok := srv.handleError(req.Allowed(r.Method), w); !ok { return } + if ok := srv.handleError(req.Url(*r.URL), w); !ok { + return + } if ok := srv.handleError(req.Header(Header(r.Header)), w); !ok { return }