Major rework

This commit is contained in:
some 2024-10-10 14:26:07 -04:00
parent 71b64c7cc8
commit fb6b5eed4b
Signed by: some
GPG Key ID: 65D0589220B9BFC8
7 changed files with 296 additions and 276 deletions

View File

@ -3,11 +3,12 @@ package router
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"strings"
) )
type genericValues map[string][]string type genericValues map[string][]string
func (gv genericValues) Parse(data any, tag string) (err error) { func (gv genericValues) Marshal(data any, tag string) (err error) {
if gv == nil { if gv == nil {
gv = make(genericValues) gv = make(genericValues)
} }
@ -44,3 +45,50 @@ func (gv genericValues) Parse(data any, tag string) (err error) {
} }
return return
} }
func (gv genericValues) Unmarshal(data any, tag string) (err error) {
if gv == nil {
err = fmt.Errorf("expected map to exist")
return
}
d := reflect.ValueOf(data)
if d.Kind() != reflect.Pointer || d.Kind() != reflect.Interface {
err = fmt.Errorf("expected pointer for data")
return
}
d = d.Elem()
if d.Kind() != reflect.Struct {
err = fmt.Errorf("expected struct input for data")
return
}
for key, value := range gv {
var v reflect.Value
if len(value) == 1 {
v = reflect.ValueOf(value[0])
} else {
v = reflect.ValueOf(value)
}
for i := 0; i < d.NumField(); i++ {
k, ok := d.Type().Field(i).Tag.Lookup(tag)
if !ok || k != key {
continue
}
val := d.Field(i)
if val.Kind() == reflect.Pointer || val.Kind() == reflect.Interface {
val = val.Elem()
}
if !val.CanSet() {
err = fmt.Errorf("could not set value for %s", d.Type().Field(i).Name)
}
switch val.Kind() {
case reflect.String:
val.Set(reflect.ValueOf(strings.Join(value, "; ")))
default:
if v.CanConvert(val.Type()) {
val.Set(v.Convert(val.Type()))
}
}
}
}
return
}

View File

@ -2,7 +2,6 @@ package router
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -10,62 +9,62 @@ import (
type DefaultError uint32 type DefaultError uint32
type DefaultErrorHandler struct {
DefaultError
}
func (*DefaultErrorHandler) ErrorResponse(errorResponse ErrorResponse) ErrorResponse {
return DefaultErrorHandler{errorResponse.(DefaultError)}
}
const ( const (
DefaultOk DefaultError = http.StatusOK DefaultErrorBadRequest DefaultError = http.StatusBadRequest
DefaultErrorNotImplemented DefaultError = http.StatusNotImplemented DefaultErrorUnauthorized DefaultError = http.StatusUnauthorized
DefaultErrorMethodNotAllowed DefaultError = http.StatusMethodNotAllowed DefaultErrorMethodNotAllowed DefaultError = http.StatusMethodNotAllowed
DefaultErrorBadRequest DefaultError = http.StatusBadRequest DefaultErrorInternalServerError DefaultError = http.StatusInternalServerError
DefaultErrorUnauthorized DefaultError = http.StatusUnauthorized DefaultErrorNotImplemented DefaultError = http.StatusNotImplemented
DefaultErrorServerError DefaultError = http.StatusInternalServerError
) )
func (e DefaultError) Ok() (ok bool) { func (e DefaultError) Error(Error) Error {
return e == DefaultOk return e
} }
func (e DefaultError) HttpStatus() (code int) { func (e DefaultError) Status() (code int) {
return int(e) return int(e)
} }
func (e DefaultError) String() (out string) { func (e DefaultError) String() (out string) {
switch e { switch e {
case DefaultOk:
out = "ok"
case DefaultErrorNotImplemented:
out = "server not implemented"
case DefaultErrorMethodNotAllowed:
out = "method not allowed"
case DefaultErrorBadRequest: case DefaultErrorBadRequest:
out = "bad request" out = "bad request"
case DefaultErrorUnauthorized: case DefaultErrorUnauthorized:
out = "user unauthorized" out = "user unauthorized"
case DefaultErrorServerError: case DefaultErrorMethodNotAllowed:
out = "method not allowed"
case DefaultErrorInternalServerError:
out = "internal server error" out = "internal server error"
case DefaultErrorNotImplemented:
out = "server not implemented"
default: default:
out = "unhandled error" out = "unhandled error"
} }
return return
} }
func (e DefaultError) Error() (out string) { func (e DefaultError) Response() *Response {
return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus()) body, err := json.Marshal(struct{ Error string }{Error: e.String()})
if err != nil {
panic(err)
}
return &Response{
Status: e.Status(),
Header: Header{"Content-Type": []string{"text/javascript", "charset=utf-8"}},
Body: body,
}
} }
func (e DefaultError) BodyBytes() (body []byte) { type DefaultErrorHandler struct {
body, _ = json.Marshal(struct{ Error string }{Error: e.String()}) DefaultError
return }
func (eh *DefaultErrorHandler) Error(e Error) Error {
return &DefaultErrorHandler{DefaultError: e.(DefaultError)}
} }
type DefaultRequestBuilder struct { type DefaultRequestBuilder struct {
errorHandler ErrorHandler errorHandler Error
allowedMethods *[]string allowedMethods *[]string
url *url.URL url *url.URL
header struct { header struct {
@ -82,167 +81,132 @@ type DefaultRequestBuilder struct {
} }
} }
// /* Leave commented to require services to create their own New method */
// var _ RequestBuilder = (*DefaultRequestBuilder)(nil)
// func (*DefaultRequestBuilder) New() RequestBuilder {
// return NewDefaultRequestBuilder()
// }
func (*DefaultRequestBuilder) mustEmbedDefaultRequestBuilder() {}
func NewDefaultRequestBuilder() *DefaultRequestBuilder { func NewDefaultRequestBuilder() *DefaultRequestBuilder {
return &DefaultRequestBuilder{} return &DefaultRequestBuilder{}
} }
func (rb *DefaultRequestBuilder) SetErrorHandler(errorHandler ErrorHandler) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetErrorHandler(errorHandler Error) *DefaultRequestBuilder {
rb.errorHandler = errorHandler rqb.errorHandler = errorHandler
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods *[]string) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetAllowedMethods(allowedMethods *[]string) *DefaultRequestBuilder {
rb.allowedMethods = allowedMethods rqb.allowedMethods = allowedMethods
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetUrl(url *url.URL) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetUrl(url *url.URL) *DefaultRequestBuilder {
rb.url = url rqb.url = url
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetHeader(parsed *Header, fields header) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetHeader(parsed *Header, fields header) *DefaultRequestBuilder {
rb.header.Header = parsed rqb.header.Header = parsed
rb.header.fields = fields rqb.header.fields = fields
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetValues(parsed *Values, fields values) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetValues(parsed *Values, fields values) *DefaultRequestBuilder {
rb.values.Values = parsed rqb.values.Values = parsed
rb.values.fields = fields rqb.values.fields = fields
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetBody(parsed *Body, fields body) *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetBody(parsed *Body, fields body) *DefaultRequestBuilder {
rb.body.Body = parsed rqb.body.Body = parsed
rb.body.fields = fields rqb.body.fields = fields
return rb return rqb
} }
func (rb *DefaultRequestBuilder) SetDefaults() *DefaultRequestBuilder { func (rqb *DefaultRequestBuilder) SetDefaults() *DefaultRequestBuilder {
if rb.errorHandler == nil { if rqb.errorHandler == nil {
rb.errorHandler = &DefaultErrorHandler{} rqb.errorHandler = &DefaultErrorHandler{}
} }
if rb.allowedMethods == nil { if rqb.allowedMethods == nil {
amd := make([]string, 0) amd := make([]string, 0)
rb.allowedMethods = &amd rqb.allowedMethods = &amd
} }
if rb.url == nil { if rqb.url == nil {
u := url.URL{} u := url.URL{}
rb.url = &u rqb.url = &u
} }
if rb.header.Header == nil || rb.header.fields == nil { if rqb.header.Header == nil || rqb.header.fields == nil {
hd := make(Header) hd := Header{"Content-Type": []string{"text/plain", "charset=utf-8"}}
rb.header.Header = &hd rqb.header.Header = &hd
hfd := struct{ Header }{Header: hd} hfd := struct{ Header }{Header: hd}
rb.header.fields = &hfd rqb.header.fields = &hfd
} }
if rb.values.Values == nil || rb.values.fields == nil { if rqb.values.Values == nil || rqb.values.fields == nil {
vd := make(Values) vd := make(Values)
rb.values.Values = &vd rqb.values.Values = &vd
vfd := struct{ Values }{Values: vd} vfd := struct{ Values }{Values: vd}
rb.values.fields = &vfd rqb.values.fields = &vfd
} }
if rb.body.Body == nil || rb.body.fields == nil { if rqb.body.Body == nil || rqb.body.fields == nil {
bd := make(Body, 0) bd := make(Body, 0)
rb.body.Body = &bd rqb.body.Body = &bd
bfd := struct{ Body }{Body: bd} bfd := struct{ Body }{Body: bd}
rb.body.fields = &bfd rqb.body.fields = &bfd
} }
return rb return rqb
} }
func (rb *DefaultRequestBuilder) Allowed(method string) (errRes ErrorResponse) { func (rqb *DefaultRequestBuilder) Allowed(method string) (e Error) {
var ok bool var ok bool
for _, m := range *rb.allowedMethods { for _, m := range *rqb.allowedMethods {
if m == method { if m == method {
ok = true ok = true
} }
} }
if !ok { if !ok {
return rb.errorHandler.ErrorResponse(DefaultErrorMethodNotAllowed) return rqb.errorHandler.Error(DefaultErrorMethodNotAllowed)
}
return
}
func (rqb *DefaultRequestBuilder) Url(url url.URL) (e Error) {
*rqb.url = url
return
}
func (rqb *DefaultRequestBuilder) Header(header Header) (e Error) {
*rqb.header.Header = header
err := rqb.header.Header.Unmarshal(rqb.header.fields)
if err != nil {
return rqb.errorHandler.Error(DefaultErrorBadRequest)
}
return
}
func (rqb *DefaultRequestBuilder) Body(body io.ReadCloser) (e Error) {
defer body.Close()
json.NewDecoder(body).Decode(rqb.body.fields)
return
}
func (rqb *DefaultRequestBuilder) Values(values Values) (e Error) {
*rqb.values.Values = values
err := rqb.values.Values.Unmarshal(rqb.values.fields)
if err != nil {
return rqb.errorHandler.Error(DefaultErrorBadRequest)
}
return
}
func (rqb *DefaultRequestBuilder) Request() (req *request) {
req = &request{
Url: *rqb.url,
Header: *rqb.header.Header,
Values: *rqb.values.Values,
Body: *rqb.body.Body,
} }
return rb.errorHandler.ErrorResponse(DefaultOk)
}
func (rb *DefaultRequestBuilder) Url(url url.URL) (errRes ErrorResponse) {
*rb.url = url
return rb.errorHandler.ErrorResponse(DefaultOk)
}
func (rb *DefaultRequestBuilder) Header(header Header) (errRes ErrorResponse) {
err := rb.header.Header.Parse(header)
if err != nil {
return rb.errorHandler.ErrorResponse(DefaultErrorBadRequest)
}
return rb.errorHandler.ErrorResponse(DefaultOk)
}
func (rb *DefaultRequestBuilder) ReadBody(body io.ReadCloser) (errRes ErrorResponse) {
defer body.Close()
json.NewDecoder(body).Decode(rb.body.fields)
return rb.errorHandler.ErrorResponse(DefaultOk)
}
func (rb *DefaultRequestBuilder) Values(values Values) (errRes ErrorResponse) {
err := rb.values.Values.Parse(values)
if err != nil {
return rb.errorHandler.ErrorResponse(DefaultErrorBadRequest)
}
return rb.errorHandler.ErrorResponse(DefaultOk)
}
type DefaultResponse struct {
header struct {
*Header
fields header
}
body struct {
*Body
fields body
}
}
func NewDefaultResponse() *DefaultResponse {
return &DefaultResponse{}
}
func (r *DefaultResponse) SetHeader(parsed *Header, fields header) *DefaultResponse {
r.header.Header = parsed
r.header.fields = fields
return r
}
func (r *DefaultResponse) SetBody(parsed *Body, fields body) *DefaultResponse {
r.body.Body = parsed
r.body.fields = fields
return r
}
func (r *DefaultResponse) SetDefaults() *DefaultResponse {
if r.header.Header == nil || r.header.fields == nil {
hd := make(Header)
r.header.Header = &hd
hfd := struct{ Header }{Header: hd}
r.header.fields = &hfd
}
if r.body.Body == nil || r.body.fields == nil {
bd := make(Body, 0)
r.body.Body = &bd
bfd := struct{ Body }{Body: bd}
r.body.fields = &bfd
}
return r
}
func (r *DefaultResponse) Header() (header Header) {
if r.header.Header == nil {
r.header.Header.Parse(r.header.fields)
}
return *r.header.Header
}
func (r *DefaultResponse) BodyBytes() (body []byte) {
body, _ = json.Marshal(r.body.fields)
return return
} }

View File

@ -12,8 +12,12 @@ type Header map[string][]string
func (Header) mustEmbedHeader() {} func (Header) mustEmbedHeader() {}
func (h Header) Parse(data any) error { func (h Header) Marshal(data any) error {
return genericValues(h).Parse(data, "header") return genericValues(h).Marshal(data, "header")
}
func (h Header) Unmarshal(data any) error {
return genericValues(h).Unmarshal(data, "header")
} }
func (h Header) Get(key string) (value string) { func (h Header) Get(key string) (value string) {

View File

@ -1,47 +0,0 @@
package router
import (
"io"
"net/url"
)
type RequestBuilder interface {
Prototype() (req PrototypeRequestBuilder)
Allowed(method string) (errRes ErrorResponse)
Url(url url.URL) (errRes ErrorResponse)
Header(header Header) (errRes ErrorResponse)
ReadBody(body io.ReadCloser) (errRes ErrorResponse)
Values(values Values) (errRes ErrorResponse)
}
type PrototypeRequestBuilder interface {
RequestBuilder() (req RequestBuilder)
}
type Response interface {
Prototype() (req PrototypeResponse)
Header() (header Header)
BodyBytes() (body []byte)
}
type PrototypeResponse interface {
Response() (res Response)
}
type ErrorResponse interface {
Ok() (ok bool)
HttpStatus() (code int)
BodyBytes() (body []byte)
String() (out string)
Error() (out string)
}
type ErrorHandler interface {
ErrorResponse
ErrorResponse(errorResponse ErrorResponse) ErrorResponse
}
type Prototype struct {
PrototypeRequestBuilder
PrototypeResponse
}

View File

@ -13,7 +13,7 @@ type Route struct {
type Router struct { type Router struct {
*http.ServeMux *http.ServeMux
routes map[Route]*server routes map[Route]http.Handler
} }
func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) { func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) {
@ -22,7 +22,7 @@ func NewRouter(mux *http.ServeMux, requiredRoutes []string) (ro *Router) {
} }
ro = &Router{ ro = &Router{
ServeMux: mux, ServeMux: mux,
routes: make(map[Route]*server), routes: make(map[Route]http.Handler),
} }
ro.AddRequiredRoutes(requiredRoutes) ro.AddRequiredRoutes(requiredRoutes)
return return
@ -37,7 +37,7 @@ func (ro *Router) AddRequiredRoutes(requiredRoutes []string) {
} }
} }
func (ro *Router) Register(pattern string, server *server) (err error) { func (ro *Router) Register(pattern string, server http.Handler) (err error) {
if !strings.HasPrefix(pattern, "/") { if !strings.HasPrefix(pattern, "/") {
return fmt.Errorf("missing preceding slash in pattern (%s)", pattern) return fmt.Errorf("missing preceding slash in pattern (%s)", pattern)
} }

179
server.go
View File

@ -1,116 +1,163 @@
package router package router
import ( import (
"io"
"net/http" "net/http"
"net/url"
"somehole.com/common/log" "somehole.com/common/log"
) )
type stage uint8 type Response struct {
Status int
const ( Header Header
pre stage = iota Body Body
main
post
numStages
)
type ServeFunc func(req RequestBuilder, res Response) (errRes ErrorResponse)
type server struct {
prototype Prototype
logger log.Logger
serve [numStages][]ServeFunc
} }
func NewServer[RB RequestBuilder, R Response](serve func(RB) (R, ErrorResponse)) (srv *server) { type Error interface {
var ( Error(e Error) (err Error)
requestBuilder RB Status() (code int)
response R String() (out string)
) Response() (res *Response)
srv = &server{
prototype: Prototype{
PrototypeRequestBuilder: requestBuilder.Prototype(),
PrototypeResponse: response.Prototype(),
},
serve: [numStages][]ServeFunc{},
}
srv.serve[main] = append(srv.serve[main], func(req RequestBuilder, res Response) (errRes ErrorResponse) {
res, errRes = serve(req.(RB))
return
})
return srv
} }
func (srv *server) SetLogger(logger log.Logger) *server { type request struct {
srv.logger = logger Url url.URL
return srv Header Header
Values Values
Body Body
} }
func (srv *server) addServeFunc(when stage, serve ServeFunc) *server { type Request[RQB RequestBuilder] struct {
if srv.serve[when] == nil { request
srv.serve[when] = make([]ServeFunc, 0)
}
srv.serve[when] = append(srv.serve[when], serve)
return srv
} }
func (srv *server) PreServeFunc(serve ServeFunc) *server { func NewRequest[RQB RequestBuilder](rqb RQB) *Request[RQB] {
return srv.addServeFunc(pre, serve) return &Request[RQB]{*rqb.Request()}
} }
func (srv *server) AddServeFunc(serve ServeFunc) *server { func (*Request[RQB]) RequestBuilder() RequestBuilder {
return srv.addServeFunc(main, serve) var rqb RQB
return rqb.New()
} }
func (srv *server) PostServeFunc(serve ServeFunc) *server { type RequestBuilder interface {
return srv.addServeFunc(post, serve) mustEmbedDefaultRequestBuilder()
New() (rqb RequestBuilder)
Allowed(method string) (err Error)
Url(url url.URL) (err Error)
Header(header Header) (err Error)
Values(values Values) (err Error)
Body(body io.ReadCloser) (err Error)
Request() *request
} }
func (srv *server) handleError(errRes ErrorResponse, w http.ResponseWriter) (ok bool) { type writer struct {
if !errRes.Ok() { http.ResponseWriter
w.WriteHeader(errRes.HttpStatus()) log.Logger
w.Write(errRes.BodyBytes()) }
srv.logger.Logf(log.LevelError, errRes.String())
func (w *writer) handleError(err Error) (ok bool) {
if err != nil {
res := err.Response()
for key, value := range res.Header {
for _, v := range value {
w.Header().Add(key, v)
}
}
w.WriteHeader(res.Status)
w.Write(res.Body)
w.Logf(log.LevelError, err.String())
return false return false
} }
return true return true
} }
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { type serveStage uint8
p := srv.prototype
req := p.PrototypeRequestBuilder.RequestBuilder() const (
if ok := srv.handleError(req.Allowed(r.Method), w); !ok { servePre serveStage = iota
serveMain
servePost
numServeStages
)
type ServeFunc[RQB RequestBuilder] func(req *Request[RQB]) (res *Response, err Error)
type server[RQB RequestBuilder] struct {
logger log.Logger
serve [numServeStages][]ServeFunc[RQB]
}
func NewServer[RQB RequestBuilder](serve ServeFunc[RQB]) (srv *server[RQB]) {
srv = &server[RQB]{
serve: [numServeStages][]ServeFunc[RQB]{},
}
srv.serve[serveMain] = append(srv.serve[serveMain], serve)
return srv
}
func (srv *server[RQB]) SetLogger(logger log.Logger) *server[RQB] {
srv.logger = logger
return srv
}
func (srv *server[RQB]) addServeFunc(when serveStage, serve ServeFunc[RQB]) *server[RQB] {
if srv.serve[when] == nil {
srv.serve[when] = make([]ServeFunc[RQB], 0)
}
srv.serve[when] = append(srv.serve[when], serve)
return srv
}
func (srv *server[RQB]) PreServeFunc(serve ServeFunc[RQB]) *server[RQB] {
return srv.addServeFunc(servePre, serve)
}
func (srv *server[RQB]) AddServeFunc(serve ServeFunc[RQB]) *server[RQB] {
return srv.addServeFunc(serveMain, serve)
}
func (srv *server[RQB]) PostServeFunc(serve ServeFunc[RQB]) *server[RQB] {
return srv.addServeFunc(servePost, serve)
}
func (srv *server[RQB]) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var rqb RQB
rqb = rqb.New().(RQB)
wr := writer{ResponseWriter: w, Logger: srv.logger}
if ok := wr.handleError(rqb.Allowed(r.Method)); !ok {
return return
} }
if ok := srv.handleError(req.Url(*r.URL), w); !ok { if ok := wr.handleError(rqb.Url(*r.URL)); !ok {
return return
} }
if ok := srv.handleError(req.Header(Header(r.Header)), w); !ok { if ok := wr.handleError(rqb.Header(Header(r.Header))); !ok {
return return
} }
if ok := srv.handleError(req.ReadBody(r.Body), w); !ok { if ok := wr.handleError(rqb.Body(r.Body)); !ok {
return return
} }
r.ParseForm() r.ParseForm()
if ok := srv.handleError(req.Values(Values(r.Form)), w); !ok { if ok := wr.handleError(rqb.Values(Values(r.Form))); !ok {
return return
} }
res := p.PrototypeResponse.Response()
for _, stage := range srv.serve { for _, stage := range srv.serve {
for _, s := range stage { for _, s := range stage {
if ok := srv.handleError(s(req, res), w); !ok { res, err := s(NewRequest(rqb))
if ok := wr.handleError(err); !ok {
return return
} }
header := res.Header() if res.Header != nil {
if header != nil { for key, value := range res.Header {
for key, value := range res.Header() {
for _, v := range value { for _, v := range value {
w.Header().Add(key, v) w.Header().Add(key, v)
} }
} }
} }
if len(res.Body) > 0 {
wr.Write(res.Body)
}
} }
} }
w.Write(res.BodyBytes())
} }

View File

@ -8,6 +8,10 @@ type Values map[string][]string
func (Values) mustEmbedValues() {} func (Values) mustEmbedValues() {}
func (v Values) Parse(data any) error { func (v Values) Marshal(data any) error {
return genericValues(v).Parse(data, "form") return genericValues(v).Marshal(data, "form")
}
func (v Values) Unmarshal(data any) error {
return genericValues(v).Unmarshal(data, "form")
} }