diff --git a/go.mod b/go.mod index e2643ca..647abd7 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.1 require ( somehole.com/common/log v0.1.3 somehole.com/common/security v0.2.1 - somehole.com/service/router v0.13.3 + somehole.com/service/router v0.16.0 ) require ( diff --git a/go.sum b/go.sum index 25ece5b..ea17284 100644 --- a/go.sum +++ b/go.sum @@ -8,5 +8,5 @@ somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0= somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA= somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM= -somehole.com/service/router v0.13.3 h1:q4jGnqMIB+o78ZUmKy9HVMFpO/fjRh6Wud3WGIMH1rM= -somehole.com/service/router v0.13.3/go.mod h1:F+/sNY4/ei7C9QD5+iA+gUHugcOAcrWspHKwX8d8oiM= +somehole.com/service/router v0.16.0 h1:h9swbPfyCmXej26nvoXZYIt3RT/sAZPQZnCz1rbLBQk= +somehole.com/service/router v0.16.0/go.mod h1:F+/sNY4/ei7C9QD5+iA+gUHugcOAcrWspHKwX8d8oiM= diff --git a/router.go b/router.go index 0e9f825..f4383ad 100644 --- a/router.go +++ b/router.go @@ -1,8 +1,6 @@ package oauth2 import ( - "fmt" - "somehole.com/common/log" "somehole.com/service/oauth2/server" "somehole.com/service/router" @@ -46,27 +44,9 @@ func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoin } func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) { - pro := router.Prototype{PrototypeRequestBuilder: server.CallbackRequestBuilder{}, PrototypeResponse: server.CallbackResponse{}} - server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { - callbackRequest, ok := req.(*server.CallbackRequestBuilder) - if !ok { - panic(fmt.Errorf("expected CallbackRequest, got %T", req)) - } - res, errRes = srv.Callback(callbackRequest) - return - } - ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, server)) + ro.Register(ro.CallbackEndpoint, router.NewServer(&server.CallbackRequestBuilder{}, &server.CallbackResponse{}, ro.Logger, srv.Callback)) } func (ro *Router) RegisterTokenServer(srv server.TokenServer) { - pro := router.Prototype{PrototypeRequestBuilder: server.TokenRequestBuilder{}, PrototypeResponse: server.TokenResponse{}} - server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { - callbackRequest, ok := req.(*server.TokenRequestBuilder) - if !ok { - panic(fmt.Errorf("expected TokenRequest, got %T", req)) - } - res, errRes = srv.Token(callbackRequest) - return - } - ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, server)) + ro.Register(ro.TokenEndpoint, router.NewServer(&server.TokenRequestBuilder{}, &server.TokenResponse{}, ro.Logger, srv.Token)) } diff --git a/server/callback.go b/server/callback.go index 57b1e02..af1b475 100644 --- a/server/callback.go +++ b/server/callback.go @@ -15,7 +15,7 @@ type CallbackRequestBuilder struct { } func (rb CallbackRequestBuilder) RequestBuilder() router.RequestBuilder { - rb.DefaultRequestBuilder = router.NewDefaultRequestBuilder(&rb, nil, nil, nil, &rb.values.Values, &rb.values, nil, nil) + rb.DefaultRequestBuilder = router.NewDefaultRequestBuilder(&rb, CallbackRequestBuilder{}, &ErrorHandler{}, nil, nil, nil, &rb.values.Values, &rb.values, nil, nil) return &rb } @@ -28,7 +28,7 @@ type CallbackResponse struct { } func (r CallbackResponse) Response() router.Response { - r.DefaultResponse = router.NewDefaultResponse(&r, nil, nil, &r.Body.Body, &r.Body) + r.DefaultResponse = router.NewDefaultResponse(&r, CallbackResponse{}, nil, nil, &r.Body.Body, &r.Body) return &r } @@ -37,7 +37,7 @@ type UnimplementedCallbackServer struct{} func (UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {} func (UnimplementedCallbackServer) Callback(req *CallbackRequestBuilder) (res *CallbackResponse, errRes router.ErrorResponse) { - errRes = ErrorNotImplemented + errRes = &ErrorHandler{router.DefaultErrorNotImplemented} return } diff --git a/server/error.go b/server/error.go index ad17457..a53d19a 100644 --- a/server/error.go +++ b/server/error.go @@ -2,54 +2,19 @@ package server import ( "encoding/json" - "fmt" - "net/http" + + "somehole.com/service/router" ) -type Error uint32 - -const ( - Ok Error = http.StatusOK - ErrorNotImplemented Error = http.StatusNotImplemented - ErrorMethodNotAllowed Error = http.StatusMethodNotAllowed - ErrorBadRequest Error = http.StatusBadRequest - ErrorUnauthorized Error = http.StatusUnauthorized - ErrorServerError Error = http.StatusInternalServerError -) - -func (e Error) Ok() (ok bool) { - return e == Ok +type ErrorHandler struct { + router.DefaultError } -func (e Error) HttpStatus() (code int) { - return int(e) +func (*ErrorHandler) ErrorResponse(errorResponse router.ErrorResponse) router.ErrorResponse { + return &ErrorHandler{errorResponse.(router.DefaultError)} } -func (e Error) String() (out string) { - switch e { - case Ok: - out = "ok" - case ErrorNotImplemented: - out = "server not implemented" - case ErrorMethodNotAllowed: - out = "method not allowed" - case ErrorBadRequest: - out = "bad request" - case ErrorUnauthorized: - out = "user unauthorized" - case ErrorServerError: - out = "internal server error" - default: - out = "unhandled error" - } - return -} - -func (e Error) Error() (out string) { - return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus()) -} - -func (e Error) BodyBytes() (body []byte) { +func (e *ErrorHandler) BodyBytes() (body []byte) { body, _ = json.Marshal(struct{ Error string }{Error: e.String()}) return } diff --git a/server/token.go b/server/token.go index cb7703f..3239baa 100644 --- a/server/token.go +++ b/server/token.go @@ -16,7 +16,7 @@ type TokenRequestBuilder struct { } func (rb TokenRequestBuilder) RequestBuilder() router.RequestBuilder { - rb.DefaultRequestBuilder = router.NewDefaultRequestBuilder(&rb, nil, nil, nil, &rb.values.Values, &rb.values, nil, nil) + rb.DefaultRequestBuilder = router.NewDefaultRequestBuilder(&rb, TokenRequestBuilder{}, &ErrorHandler{}, nil, nil, nil, &rb.values.Values, &rb.values, nil, nil) return &rb } @@ -33,7 +33,7 @@ type TokenResponse struct { } func (r TokenResponse) Response() router.Response { - r.DefaultResponse = router.NewDefaultResponse(&r, nil, nil, &r.Body.Body, &r.Body) + r.DefaultResponse = router.NewDefaultResponse(&r, TokenResponse{}, nil, nil, &r.Body.Body, &r.Body) return &r } @@ -42,7 +42,7 @@ type UnimplementedTokenServer struct{} func (UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {} func (UnimplementedTokenServer) Token(req *TokenRequestBuilder) (res *TokenResponse, errRes router.ErrorResponse) { - errRes = ErrorNotImplemented + errRes = &ErrorHandler{router.DefaultErrorNotImplemented} return }