package router import ( "net/http" "somehole.com/common/log" ) type stage uint8 const ( pre stage = iota main post numStages ) type ServeFunc func(req RequestBuilder, res Response) (errRes ErrorResponse) type server struct { prototype Prototype logger log.Logger serve [numStages][]ServeFunc } func NewServer(prototype Prototype, logger log.Logger, serve ServeFunc) (srv *server) { srv = &server{ prototype: prototype, logger: logger, serve: [numStages][]ServeFunc{}, } srv.serve[main] = append(srv.serve[main], serve) return srv } func (srv *server) addServeFunc(when stage, serve ServeFunc) *server { if srv.serve[when] == nil { srv.serve[when] = make([]ServeFunc, 0) } srv.serve[when] = append(srv.serve[when], serve) return srv } func (srv *server) PreServeFunc(serve ServeFunc) *server { return srv.addServeFunc(pre, serve) } func (srv *server) AddServeFunc(serve ServeFunc) *server { return srv.addServeFunc(main, serve) } func (srv *server) PostServeFunc(serve ServeFunc) *server { return srv.addServeFunc(post, serve) } func (srv *server) handleError(errRes ErrorResponse, w http.ResponseWriter) (ok bool) { if !errRes.Ok() { w.WriteHeader(errRes.HttpStatus()) w.Write(errRes.Bytes()) srv.logger.Logf(log.LevelError, errRes.String()) return false } return true } func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { p := srv.prototype req := p.PrototypeRequestBuilder.RequestBuilder() if ok := srv.handleError(req.Allowed(r.Method), w); !ok { return } if ok := srv.handleError(req.Header(Header(r.Header)), w); !ok { return } if ok := srv.handleError(req.ReadBody(r.Body), w); !ok { return } r.ParseForm() if ok := srv.handleError(req.Values(Values(r.Form)), w); !ok { return } res := p.PrototypeResponse.Response() for _, stage := range srv.serve { for _, s := range stage { if ok := srv.handleError(s(req, res), w); !ok { return } header := res.Header() if header != nil { for key, value := range res.Header() { for _, v := range value { w.Header().Add(key, v) } } } } } w.Write(res.Bytes()) }