100 lines
2.1 KiB
Go
100 lines
2.1 KiB
Go
package router
|
|
|
|
import (
|
|
"net/http"
|
|
|
|
"somehole.com/common/log"
|
|
)
|
|
|
|
// stage identifies one of the ordered phases in which a server runs its
// registered ServeFuncs. It is used as an index into server.serve.
type stage uint8
|
|
|
|
const (
|
|
pre stage = iota
|
|
main
|
|
post
|
|
numStages
|
|
)
|
|
|
|
// ServeFunc is one step of request handling. It receives the request builder
// and the response being produced and returns an ErrorResponse; ServeHTTP
// stops running further steps as soon as the returned value's Ok() is false.
type ServeFunc func(req RequestBuilder, res Response) ErrorResponse
|
|
|
|
type server struct {
|
|
prototype Prototype
|
|
logger log.Logger
|
|
serve [numStages][]ServeFunc
|
|
}
|
|
|
|
func NewServer(prototype Prototype, logger log.Logger, serve ServeFunc) (srv *server) {
|
|
srv = &server{
|
|
prototype: prototype,
|
|
logger: logger,
|
|
serve: [numStages][]ServeFunc{},
|
|
}
|
|
srv.serve[main] = append(srv.serve[main], serve)
|
|
return srv
|
|
}
|
|
|
|
func (srv *server) addServeFunc(when stage, serve ServeFunc) *server {
|
|
if srv.serve[when] == nil {
|
|
srv.serve[when] = make([]ServeFunc, 0)
|
|
}
|
|
srv.serve[when] = append(srv.serve[when], serve)
|
|
return srv
|
|
}
|
|
|
|
func (srv *server) PreServeFunc(serve ServeFunc) *server {
|
|
return srv.addServeFunc(pre, serve)
|
|
}
|
|
|
|
func (srv *server) AddServeFunc(serve ServeFunc) *server {
|
|
return srv.addServeFunc(main, serve)
|
|
}
|
|
|
|
func (srv *server) PostServeFunc(serve ServeFunc) *server {
|
|
return srv.addServeFunc(post, serve)
|
|
}
|
|
|
|
func (srv *server) handleError(errRes ErrorResponse, w http.ResponseWriter) (ok bool) {
|
|
if !errRes.Ok() {
|
|
w.WriteHeader(errRes.HttpStatus())
|
|
w.Write(errRes.Bytes())
|
|
srv.logger.Logf(log.LevelError, errRes.String())
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
p := srv.prototype
|
|
req := p.PrototypeRequestBuilder.RequestBuilder()
|
|
if ok := srv.handleError(req.Allowed(r.Method), w); !ok {
|
|
return
|
|
}
|
|
if ok := srv.handleError(req.Header(Header(r.Header)), w); !ok {
|
|
return
|
|
}
|
|
if ok := srv.handleError(req.Body(r.Body), w); !ok {
|
|
return
|
|
}
|
|
r.ParseForm()
|
|
if ok := srv.handleError(req.Values(Values(r.Form)), w); !ok {
|
|
return
|
|
}
|
|
res := p.PrototypeResponse.Response()
|
|
for _, stage := range srv.serve {
|
|
for _, s := range stage {
|
|
if ok := srv.handleError(s(req, res), w); !ok {
|
|
return
|
|
}
|
|
header := res.Header()
|
|
if header != nil {
|
|
for key, value := range res.Header() {
|
|
for _, v := range value {
|
|
w.Header().Add(key, v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
w.Write(res.Bytes())
|
|
}
|