diff --git a/common.go b/common.go index 1e53492..6b21840 100644 --- a/common.go +++ b/common.go @@ -3,11 +3,12 @@ package router import ( "fmt" "reflect" + "strings" ) type genericValues map[string][]string -func (gv genericValues) Parse(data any, tag string) (err error) { +func (gv genericValues) Marshal(data any, tag string) (err error) { if gv == nil { gv = make(genericValues) } @@ -44,3 +45,50 @@ func (gv genericValues) Parse(data any, tag string) (err error) { } return } + +func (gv genericValues) Unmarshal(data any, tag string) (err error) { + if gv == nil { + err = fmt.Errorf("expected map to exist") + return + } + d := reflect.ValueOf(data) + if d.Kind() != reflect.Pointer || d.Kind() != reflect.Interface { + err = fmt.Errorf("expected pointer for data") + return + } + d = d.Elem() + if d.Kind() != reflect.Struct { + err = fmt.Errorf("expected struct input for data") + return + } + for key, value := range gv { + var v reflect.Value + if len(value) == 1 { + v = reflect.ValueOf(value[0]) + } else { + v = reflect.ValueOf(value) + } + for i := 0; i < d.NumField(); i++ { + k, ok := d.Type().Field(i).Tag.Lookup(tag) + if !ok || k != key { + continue + } + val := d.Field(i) + if val.Kind() == reflect.Pointer || val.Kind() == reflect.Interface { + val = val.Elem() + } + if !val.CanSet() { + err = fmt.Errorf("could not set value for %s", d.Type().Field(i).Name) + } + switch val.Kind() { + case reflect.String: + val.Set(reflect.ValueOf(strings.Join(value, "; "))) + default: + if v.CanConvert(val.Type()) { + val.Set(v.Convert(val.Type())) + } + } + } + } + return +} diff --git a/default.go b/default.go index 67396ad..d8797d0 100644 --- a/default.go +++ b/default.go @@ -2,7 +2,6 @@ package router import ( "encoding/json" - "fmt" "io" "net/http" "net/url" @@ -10,62 +9,62 @@ import ( type DefaultError uint32 -type DefaultErrorHandler struct { - DefaultError -} - -func (*DefaultErrorHandler) ErrorResponse(errorResponse ErrorResponse) ErrorResponse { - return DefaultErrorHandler{errorResponse.(DefaultError)} -} - const ( - DefaultOk DefaultError = http.StatusOK - DefaultErrorNotImplemented DefaultError = http.StatusNotImplemented - DefaultErrorMethodNotAllowed DefaultError = http.StatusMethodNotAllowed - DefaultErrorBadRequest DefaultError = http.StatusBadRequest - DefaultErrorUnauthorized DefaultError = http.StatusUnauthorized - DefaultErrorServerError DefaultError = http.StatusInternalServerError + DefaultErrorBadRequest DefaultError = http.StatusBadRequest + DefaultErrorUnauthorized DefaultError = http.StatusUnauthorized + DefaultErrorMethodNotAllowed DefaultError = http.StatusMethodNotAllowed + DefaultErrorInternalServerError DefaultError = http.StatusInternalServerError + DefaultErrorNotImplemented DefaultError = http.StatusNotImplemented ) -func (e DefaultError) Ok() (ok bool) { - return e == DefaultOk +func (e DefaultError) Error(Error) Error { + return e } -func (e DefaultError) HttpStatus() (code int) { +func (e DefaultError) Status() (code int) { return int(e) } func (e DefaultError) String() (out string) { switch e { - case DefaultOk: - out = "ok" - case DefaultErrorNotImplemented: - out = "server not implemented" - case DefaultErrorMethodNotAllowed: - out = "method not allowed" case DefaultErrorBadRequest: out = "bad request" case DefaultErrorUnauthorized: out = "user unauthorized" - case DefaultErrorServerError: + case DefaultErrorMethodNotAllowed: + out = "method not allowed" + case DefaultErrorInternalServerError: out = "internal server error" + case DefaultErrorNotImplemented: + out = "server not implemented" default: out = "unhandled error" } return } -func (e DefaultError) Error() (out string) { - return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus()) +func (e DefaultError) Response() *Response { + body, err := json.Marshal(struct{ Error string }{Error: e.String()}) + if err != nil { + panic(err) + } + return &Response{ + Status: e.Status(), + Header: Header{"Content-Type": []string{"text/javascript", "charset=utf-8"}}, + Body: body, + } } -func (e DefaultError) BodyBytes() (body []byte) { - body, _ = json.Marshal(struct{ Error string }{Error: e.String()}) - return +type DefaultErrorHandler struct { + DefaultError +} + +func (eh *DefaultErrorHandler) Error(e Error) Error { + return &DefaultErrorHandler{DefaultError: e.(DefaultError)} } type DefaultRequestBuilder struct { - errorHandler ErrorHandler + errorHandler Error allowedMethods *[]string url *url.URL header struct { @@ -82,167 +81,132 @@ type DefaultRequestBuilder struct { } } +// /* Leave commented to require services to create their own New method */ +// var _ RequestBuilder = (*DefaultRequestBuilder)(nil) +// func (*DefaultRequestBuilder) New() RequestBuilder { +// return NewDefaultRequestBuilder() +// } + +func (*DefaultRequestBuilder) mustEmbedDefaultRequestBuilder() {} + func NewDefaultRequestBuilder() *DefaultRequestBuilder { return &DefaultRequestBuilder{} } -func (rb *DefaultRequestBuilder) SetErrorHandler(errorHandler ErrorHandler) *DefaultRequestBuilder { - rb.errorHandler = errorHandler - return rb +func (rqb *DefaultRequestBuilder) SetErrorHandler(errorHandler Error) *DefaultRequestBuilder { + rqb.errorHandler = errorHandler + return rqb } -func (rb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods *[]string) *DefaultRequestBuilder { - rb.allowedMethods = allowedMethods - return rb +func (rqb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods *[]string) *DefaultRequestBuilder { + rqb.allowedMethods = allowedMethods + return rqb } -func (rb *DefaultRequestBuilder) SetUrl(url *url.URL) *DefaultRequestBuilder { - rb.url = url - return rb +func (rqb *DefaultRequestBuilder) SetUrl(url *url.URL) *DefaultRequestBuilder { + rqb.url = url + return rqb } -func (rb *DefaultRequestBuilder) SetHeader(parsed *Header, fields header) *DefaultRequestBuilder { - rb.header.Header = parsed - rb.header.fields = fields - return rb +func (rqb *DefaultRequestBuilder) SetHeader(parsed *Header, fields header) *DefaultRequestBuilder { + rqb.header.Header = parsed + rqb.header.fields = fields + return rqb } -func (rb *DefaultRequestBuilder) SetValues(parsed *Values, fields values) *DefaultRequestBuilder { - rb.values.Values = parsed - rb.values.fields = fields - return rb +func (rqb *DefaultRequestBuilder) SetValues(parsed *Values, fields values) *DefaultRequestBuilder { + rqb.values.Values = parsed + rqb.values.fields = fields + return rqb } -func (rb *DefaultRequestBuilder) SetBody(parsed *Body, fields body) *DefaultRequestBuilder { - rb.body.Body = parsed - rb.body.fields = fields - return rb +func (rqb *DefaultRequestBuilder) SetBody(parsed *Body, fields body) *DefaultRequestBuilder { + rqb.body.Body = parsed + rqb.body.fields = fields + return rqb } -func (rb *DefaultRequestBuilder) SetDefaults() *DefaultRequestBuilder { - if rb.errorHandler == nil { - rb.errorHandler = &DefaultErrorHandler{} +func (rqb *DefaultRequestBuilder) SetDefaults() *DefaultRequestBuilder { + if rqb.errorHandler == nil { + rqb.errorHandler = &DefaultErrorHandler{} } - if rb.allowedMethods == nil { + if rqb.allowedMethods == nil { amd := make([]string, 0) - rb.allowedMethods = &amd + rqb.allowedMethods = &amd } - if rb.url == nil { + if rqb.url == nil { u := url.URL{} - rb.url = &u + rqb.url = &u } - if rb.header.Header == nil || rb.header.fields == nil { - hd := make(Header) - rb.header.Header = &hd + if rqb.header.Header == nil || rqb.header.fields == nil { + hd := Header{"Content-Type": []string{"text/plain", "charset=utf-8"}} + rqb.header.Header = &hd hfd := struct{ Header }{Header: hd} - rb.header.fields = &hfd + rqb.header.fields = &hfd } - if rb.values.Values == nil || rb.values.fields == nil { + if rqb.values.Values == nil || rqb.values.fields == nil { vd := make(Values) - rb.values.Values = &vd + rqb.values.Values = &vd vfd := struct{ Values }{Values: vd} - rb.values.fields = &vfd + rqb.values.fields = &vfd } - if rb.body.Body == nil || rb.body.fields == nil { + if rqb.body.Body == nil || rqb.body.fields == nil { bd := make(Body, 0) - rb.body.Body = &bd + rqb.body.Body = &bd bfd := struct{ Body }{Body: bd} - rb.body.fields = &bfd + rqb.body.fields = &bfd } - return rb + return rqb } -func (rb *DefaultRequestBuilder) Allowed(method string) (errRes ErrorResponse) { +func (rqb *DefaultRequestBuilder) Allowed(method string) (e Error) { var ok bool - for _, m := range *rb.allowedMethods { + for _, m := range *rqb.allowedMethods { if m == method { ok = true } } if !ok { - return rb.errorHandler.ErrorResponse(DefaultErrorMethodNotAllowed) + return rqb.errorHandler.Error(DefaultErrorMethodNotAllowed) + } + return +} + +func (rqb *DefaultRequestBuilder) Url(url url.URL) (e Error) { + *rqb.url = url + return +} + +func (rqb *DefaultRequestBuilder) Header(header Header) (e Error) { + *rqb.header.Header = header + err := rqb.header.Header.Unmarshal(rqb.header.fields) + if err != nil { + return rqb.errorHandler.Error(DefaultErrorBadRequest) + } + return +} + +func (rqb *DefaultRequestBuilder) Body(body io.ReadCloser) (e Error) { + defer body.Close() + json.NewDecoder(body).Decode(rqb.body.fields) + return +} + +func (rqb *DefaultRequestBuilder) Values(values Values) (e Error) { + *rqb.values.Values = values + err := rqb.values.Values.Unmarshal(rqb.values.fields) + if err != nil { + return rqb.errorHandler.Error(DefaultErrorBadRequest) + } + return +} + +func (rqb *DefaultRequestBuilder) Request() (req *request) { + req = &request{ + Url: *rqb.url, + Header: *rqb.header.Header, + Values: *rqb.values.Values, + Body: *rqb.body.Body, } - return rb.errorHandler.ErrorResponse(DefaultOk) -} - -func (rb *DefaultRequestBuilder) Url(url url.URL) (errRes ErrorResponse) { - *rb.url = url - return rb.errorHandler.ErrorResponse(DefaultOk) -} - -func (rb *DefaultRequestBuilder) Header(header Header) (errRes ErrorResponse) { - err := rb.header.Header.Parse(header) - if err != nil { - return rb.errorHandler.ErrorResponse(DefaultErrorBadRequest) - } - return rb.errorHandler.ErrorResponse(DefaultOk) -} - -func (rb *DefaultRequestBuilder) ReadBody(body io.ReadCloser) (errRes ErrorResponse) { - defer body.Close() - json.NewDecoder(body).Decode(rb.body.fields) - return rb.errorHandler.ErrorResponse(DefaultOk) -} - -func (rb *DefaultRequestBuilder) Values(values Values) (errRes ErrorResponse) { - err := rb.values.Values.Parse(values) - if err != nil { - return rb.errorHandler.ErrorResponse(DefaultErrorBadRequest) - } - return rb.errorHandler.ErrorResponse(DefaultOk) -} - -type DefaultResponse struct { - header struct { - *Header - fields header - } - body struct { - *Body - fields body - } -} - -func NewDefaultResponse() *DefaultResponse { - return &DefaultResponse{} -} - -func (r *DefaultResponse) SetHeader(parsed *Header, fields header) *DefaultResponse { - r.header.Header = parsed - r.header.fields = fields - return r -} - -func (r *DefaultResponse) SetBody(parsed *Body, fields body) *DefaultResponse { - r.body.Body = parsed - r.body.fields = fields - return r -} - -func (r *DefaultResponse) SetDefaults() *DefaultResponse { - if r.header.Header == nil || r.header.fields == nil { - hd := make(Header) - r.header.Header = &hd - hfd := struct{ Header }{Header: hd} - r.header.fields = &hfd - } - if r.body.Body == nil || r.body.fields == nil { - bd := make(Body, 0) - r.body.Body = &bd - bfd := struct{ Body }{Body: bd} - r.body.fields = &bfd - } - return r -} - -func (r *DefaultResponse) Header() (header Header) { - if r.header.Header == nil { - r.header.Header.Parse(r.header.fields) - } - return *r.header.Header -} - -func (r *DefaultResponse) BodyBytes() (body []byte) { - body, _ = json.Marshal(r.body.fields) return } diff --git a/header.go b/header.go index fbe12c5..953d442 100644 --- a/header.go +++ b/header.go @@ -12,8 +12,12 @@ type Header map[string][]string func (Header) mustEmbedHeader() {} -func (h Header) Parse(data any) error { - return genericValues(h).Parse(data, "header") +func (h Header) Marshal(data any) error { + return genericValues(h).Marshal(data, "header") +} + +func (h Header) Unmarshal(data any) error { + return genericValues(h).Unmarshal(data, "header") } func (h Header) Get(key string) (value string) { diff --git a/prototype.go b/prototype.go deleted file mode 100644 index 8039b2e..0000000 --- a/prototype.go +++ /dev/null @@ -1,47 +0,0 @@ -package router - -import ( - "io" - "net/url" -) - -type RequestBuilder interface { - Prototype() (req PrototypeRequestBuilder) - Allowed(method string) (errRes ErrorResponse) - Url(url url.URL) (errRes ErrorResponse) - Header(header Header) (errRes ErrorResponse) - ReadBody(body io.ReadCloser) (errRes ErrorResponse) - Values(values Values) (errRes ErrorResponse) -} - -type PrototypeRequestBuilder interface { - RequestBuilder() (req RequestBuilder) -} - -type Response interface { - Prototype() (req PrototypeResponse) - Header() (header Header) - BodyBytes() (body []byte) -} - -type PrototypeResponse interface { - Response() (res Response) -} - -type ErrorResponse interface { - Ok() (ok bool) - HttpStatus() (code int) - BodyBytes() (body []byte) - String() (out string) - Error() (out string) -} - -type ErrorHandler interface { - ErrorResponse - ErrorResponse(errorResponse ErrorResponse) ErrorResponse -} - -type Prototype struct { - PrototypeRequestBuilder - PrototypeResponse -} diff --git a/router.go b/router.go index f8fb855..c0125ef 100644 --- a/router.go +++ b/router.go @@ -13,7 +13,7 @@ type Route struct { type Router struct { *http.ServeMux - routes map[Route]*server + routes map[Route]http.Handler } func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) { @@ -22,7 +22,7 @@ func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) { } ro = &Router{ ServeMux: mux, - routes: make(map[Route]*server), + routes: make(map[Route]http.Handler), } ro.AddRequiredRoutes(requiredRoutes) return @@ -37,7 +37,7 @@ func (ro *Router) AddRequiredRoutes(requiredRoutes []string) { } } -func (ro *Router) Register(pattern string, server *server) (err error) { +func (ro *Router) Register(pattern string, server http.Handler) (err error) { if !strings.HasPrefix(pattern, "/") { return fmt.Errorf("missing preceding slash in pattern (%s)", pattern) } diff --git a/server.go b/server.go index 71b56cc..1e059ae 100644 --- a/server.go +++ b/server.go @@ -1,116 +1,163 @@ package router import ( + "io" "net/http" + "net/url" "somehole.com/common/log" ) -type stage uint8 - -const ( - pre stage = iota - main - post - numStages -) - -type ServeFunc func(req RequestBuilder, res Response) (errRes ErrorResponse) - -type server struct { - prototype Prototype - logger log.Logger - serve [numStages][]ServeFunc +type Response struct { + Status int + Header Header + Body Body } -func NewServer[RB RequestBuilder, R Response](serve func(RB) (R, ErrorResponse)) (srv *server) { - var ( - requestBuilder RB - response R - ) - srv = &server{ - prototype: Prototype{ - PrototypeRequestBuilder: requestBuilder.Prototype(), - PrototypeResponse: response.Prototype(), - }, - serve: [numStages][]ServeFunc{}, - } - srv.serve[main] = append(srv.serve[main], func(req RequestBuilder, res Response) (errRes ErrorResponse) { - res, errRes = serve(req.(RB)) - return - }) - return srv +type Error interface { + Error(e Error) (err Error) + Status() (code int) + String() (out string) + Response() (res *Response) } -func (srv *server) SetLogger(logger log.Logger) *server { - srv.logger = logger - return srv +type request struct { + Url url.URL + Header Header + Values Values + Body Body } -func (srv *server) addServeFunc(when stage, serve ServeFunc) *server { - if srv.serve[when] == nil { - srv.serve[when] = make([]ServeFunc, 0) - } - srv.serve[when] = append(srv.serve[when], serve) - return srv +type Request[RQB RequestBuilder] struct { + request } -func (srv *server) PreServeFunc(serve ServeFunc) *server { - return srv.addServeFunc(pre, serve) +func NewRequest[RQB RequestBuilder](rqb RQB) *Request[RQB] { + return &Request[RQB]{*rqb.Request()} } -func (srv *server) AddServeFunc(serve ServeFunc) *server { - return srv.addServeFunc(main, serve) +func (*Request[RQB]) RequestBuilder() RequestBuilder { + var rqb RQB + return rqb.New() } -func (srv *server) PostServeFunc(serve ServeFunc) *server { - return srv.addServeFunc(post, serve) +type RequestBuilder interface { + mustEmbedDefaultRequestBuilder() + New() (rqb RequestBuilder) + Allowed(method string) (err Error) + Url(url url.URL) (err Error) + Header(header Header) (err Error) + Values(values Values) (err Error) + Body(body io.ReadCloser) (err Error) + Request() *request } -func (srv *server) handleError(errRes ErrorResponse, w http.ResponseWriter) (ok bool) { - if !errRes.Ok() { - w.WriteHeader(errRes.HttpStatus()) - w.Write(errRes.BodyBytes()) - srv.logger.Logf(log.LevelError, errRes.String()) +type writer struct { + http.ResponseWriter + log.Logger +} + +func (w *writer) handleError(err Error) (ok bool) { + if err != nil { + res := err.Response() + for key, value := range res.Header { + for _, v := range value { + w.Header().Add(key, v) + } + } + w.WriteHeader(res.Status) + w.Write(res.Body) + w.Logf(log.LevelError, err.String()) return false } return true } -func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - p := srv.prototype - req := p.PrototypeRequestBuilder.RequestBuilder() - if ok := srv.handleError(req.Allowed(r.Method), w); !ok { +type serveStage uint8 + +const ( + servePre serveStage = iota + serveMain + servePost + numServeStages +) + +type ServeFunc[RQB RequestBuilder] func(req *Request[RQB]) (res *Response, err Error) + +type server[RQB RequestBuilder] struct { + logger log.Logger + serve [numServeStages][]ServeFunc[RQB] +} + +func NewServer[RQB RequestBuilder](serve ServeFunc[RQB]) (srv *server[RQB]) { + srv = &server[RQB]{ + serve: [numServeStages][]ServeFunc[RQB]{}, + } + srv.serve[serveMain] = append(srv.serve[serveMain], serve) + return srv +} + +func (srv *server[RQB]) SetLogger(logger log.Logger) *server[RQB] { + srv.logger = logger + return srv +} + +func (srv *server[RQB]) addServeFunc(when serveStage, serve ServeFunc[RQB]) *server[RQB] { + if srv.serve[when] == nil { + srv.serve[when] = make([]ServeFunc[RQB], 0) + } + srv.serve[when] = append(srv.serve[when], serve) + return srv +} + +func (srv *server[RQB]) PreServeFunc(serve ServeFunc[RQB]) *server[RQB] { + return srv.addServeFunc(servePre, serve) +} + +func (srv *server[RQB]) AddServeFunc(serve ServeFunc[RQB]) *server[RQB] { + return srv.addServeFunc(serveMain, serve) +} + +func (srv *server[RQB]) PostServeFunc(serve ServeFunc[RQB]) *server[RQB] { + return srv.addServeFunc(servePost, serve) +} + +func (srv *server[RQB]) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var rqb RQB + rqb = rqb.New().(RQB) + wr := writer{ResponseWriter: w, Logger: srv.logger} + if ok := wr.handleError(rqb.Allowed(r.Method)); !ok { return } - if ok := srv.handleError(req.Url(*r.URL), w); !ok { + if ok := wr.handleError(rqb.Url(*r.URL)); !ok { return } - if ok := srv.handleError(req.Header(Header(r.Header)), w); !ok { + if ok := wr.handleError(rqb.Header(Header(r.Header))); !ok { return } - if ok := srv.handleError(req.ReadBody(r.Body), w); !ok { + if ok := wr.handleError(rqb.Body(r.Body)); !ok { return } r.ParseForm() - if ok := srv.handleError(req.Values(Values(r.Form)), w); !ok { + if ok := wr.handleError(rqb.Values(Values(r.Form))); !ok { return } - res := p.PrototypeResponse.Response() for _, stage := range srv.serve { for _, s := range stage { - if ok := srv.handleError(s(req, res), w); !ok { + res, err := s(NewRequest(rqb)) + if ok := wr.handleError(err); !ok { return } - header := res.Header() - if header != nil { - for key, value := range res.Header() { + if res.Header != nil { + for key, value := range res.Header { for _, v := range value { w.Header().Add(key, v) } } } + if len(res.Body) > 0 { + wr.Write(res.Body) + } } } - w.Write(res.BodyBytes()) } diff --git a/values.go b/values.go index cb39a5c..810297d 100644 --- a/values.go +++ b/values.go @@ -8,6 +8,10 @@ type Values map[string][]string func (Values) mustEmbedValues() {} -func (v Values) Parse(data any) error { - return genericValues(v).Parse(data, "form") +func (v Values) Marshal(data any) error { + return genericValues(v).Marshal(data, "form") +} + +func (v Values) Unmarshal(data any) error { + return genericValues(v).Unmarshal(data, "form") }