From 0ea4adb96c3b673ed350b1cf19477612528d5529 Mon Sep 17 00:00:00 2001 From: some Date: Tue, 8 Oct 2024 20:33:23 -0400 Subject: [PATCH] Move to using router package --- client/authorization.go | 2 +- client/client.go | 2 +- client/revokation.go | 2 +- client/token.go | 2 +- go.mod | 1 + go.sum | 6 +- router.go | 72 ++++++++++++++++++ server.go | 50 ------------ server/callback.go | 152 ++++++++++++++++-------------------- server/error.go | 63 +++++++++++++++ server/server.go | 75 ------------------ server/token.go | 165 ++++++++++++++++------------------------ 12 files changed, 275 insertions(+), 317 deletions(-) create mode 100644 router.go delete mode 100644 server.go create mode 100644 server/error.go delete mode 100644 server/server.go diff --git a/client/authorization.go b/client/authorization.go index 5d185d3..53ea488 100644 --- a/client/authorization.go +++ b/client/authorization.go @@ -5,7 +5,7 @@ import ( "net/url" "strings" - "somehole.com/common/oauth2/session" + "somehole.com/service/oauth2/session" ) type AuthorizationUrl struct { diff --git a/client/client.go b/client/client.go index a6acc24..a792287 100644 --- a/client/client.go +++ b/client/client.go @@ -1,8 +1,8 @@ package client import ( - "somehole.com/common/oauth2/session" "somehole.com/common/security/signature" + "somehole.com/service/oauth2/session" ) type Client struct { diff --git a/client/revokation.go b/client/revokation.go index 50919cb..c08ed7b 100644 --- a/client/revokation.go +++ b/client/revokation.go @@ -4,7 +4,7 @@ import ( "fmt" "net/url" - "somehole.com/common/oauth2/session" + "somehole.com/service/oauth2/session" ) type TokenRevokationUrl struct { diff --git a/client/token.go b/client/token.go index 0282339..aaf7b7e 100644 --- a/client/token.go +++ b/client/token.go @@ -4,7 +4,7 @@ import ( "fmt" "net/url" - "somehole.com/common/oauth2/session" + "somehole.com/service/oauth2/session" ) type TokenUrl struct { diff --git a/go.mod b/go.mod index a1685dd..5098036 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.23.1 require ( somehole.com/common/log v0.1.3 somehole.com/common/security v0.2.1 + somehole.com/service/router v0.6.4 ) require ( diff --git a/go.sum b/go.sum index 1a58100..cd9a518 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,12 @@ github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY= github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU= -golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIBq1s5hz3HQxhJJSu/SUGlD+SM= -golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0= somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA= somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM= +somehole.com/service/router v0.6.4 h1:tA734uvqQqG0LX8UZa7hjmHVMN6LBbNGJraw1CH7kDQ= +somehole.com/service/router v0.6.4/go.mod h1:F+/sNY4/ei7C9QD5+iA+gUHugcOAcrWspHKwX8d8oiM= diff --git a/router.go b/router.go new file mode 100644 index 0000000..0e9f825 --- /dev/null +++ b/router.go @@ -0,0 +1,72 @@ +package oauth2 + +import ( + "fmt" + + "somehole.com/common/log" + "somehole.com/service/oauth2/server" + "somehole.com/service/router" +) + +type Endpoints struct { + CallbackEndpoint string + TokenEndpoint string +} + +var defaultEndpoints = Endpoints{ + CallbackEndpoint: "/callback", + TokenEndpoint: "/token", +} + +type Router struct { + *router.Router + log.Logger + Endpoints +} + +func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) { + var endpoints Endpoints + if r == nil { + r = router.NewRouter(nil, nil) + } + if len(optionalEndpoints) == 1 { + endpoints = optionalEndpoints[0] + } else { + endpoints = defaultEndpoints + } + r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint}) + if logger == nil { + logger = log.NewPlainLogger() + } + return &Router{ + Router: r, + Logger: logger, + Endpoints: endpoints, + } +} + +func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) { + pro := router.Prototype{PrototypeRequestBuilder: server.CallbackRequestBuilder{}, PrototypeResponse: server.CallbackResponse{}} + server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { + callbackRequest, ok := req.(*server.CallbackRequestBuilder) + if !ok { + panic(fmt.Errorf("expected CallbackRequest, got %T", req)) + } + res, errRes = srv.Callback(callbackRequest) + return + } + ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, server)) +} + +func (ro *Router) RegisterTokenServer(srv server.TokenServer) { + pro := router.Prototype{PrototypeRequestBuilder: server.TokenRequestBuilder{}, PrototypeResponse: server.TokenResponse{}} + server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { + callbackRequest, ok := req.(*server.TokenRequestBuilder) + if !ok { + panic(fmt.Errorf("expected TokenRequest, got %T", req)) + } + res, errRes = srv.Token(callbackRequest) + return + } + ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, server)) +} diff --git a/server.go b/server.go deleted file mode 100644 index 08fb1c3..0000000 --- a/server.go +++ /dev/null @@ -1,50 +0,0 @@ -package oauth2 - -import ( - "fmt" - "net/http" - - "somehole.com/common/log" - "somehole.com/common/oauth2/server" -) - -type Server struct { - *http.ServeMux - Logger log.Logger -} - -func NewServer(mux *http.ServeMux, logger log.Logger) *Server { - if mux == nil { - mux = http.NewServeMux() - } - return &Server{ - ServeMux: mux, - Logger: logger, - } -} - -func (s *Server) Handle(pattern string, handler http.Handler) { - s.ServeMux.Handle(pattern, handler) -} - -func (s *Server) RegisterCallbackServer(srv server.CallbackServer) { - s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { - callbackRequest, ok := req.(*server.CallbackRequest) - if !ok { - panic(fmt.Errorf("expected CallbackRequest, got %T", req)) - } - res, errRes = srv.Callback(callbackRequest) - return - })) -} - -func (s *Server) RegisterTokenServer(srv server.TokenServer) { - s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { - tokenRequest, ok := req.(*server.TokenRequest) - if !ok { - panic(fmt.Errorf("expected TokenRequest, got %T", req)) - } - res, errRes = srv.Token(tokenRequest) - return - })) -} diff --git a/server/callback.go b/server/callback.go index c600201..fffdb50 100644 --- a/server/callback.go +++ b/server/callback.go @@ -1,114 +1,96 @@ package server import ( - "fmt" - "net/http" - "net/url" + "io" - "somehole.com/common/oauth2/session" + "somehole.com/service/oauth2/session" + "somehole.com/service/router" ) -const CallbackEndpoint = "/callback" - -type CallbackError uint32 - -const ( - CallbackOk CallbackError = iota - CallbackErrorUnimplemented - CallbackErrorUnauthorized - CallbackErrorServerError -) - -func (ce CallbackError) Ok() bool { - return ce == CallbackOk -} - -func (ce CallbackError) HttpStatus() (code int) { - switch ce { - case CallbackOk: - code = http.StatusOK - case CallbackErrorUnimplemented: - code = http.StatusInternalServerError - case CallbackErrorUnauthorized: - code = http.StatusUnauthorized - case CallbackErrorServerError: - code = http.StatusInternalServerError - default: - code = http.StatusInternalServerError +type CallbackRequestBuilder struct { + allowedMethods []string + header struct { + router.Header } - return -} - -func (ce CallbackError) String() (out string) { - switch ce { - case CallbackOk: - out = "authenticated" - case CallbackErrorUnimplemented: - out = "callback server unimplemented" - case CallbackErrorUnauthorized: - out = "user unauthorized" - case CallbackErrorServerError: - out = "internal server error" - default: - out = "unhandled error" + values struct { + router.Values + State session.State `form:"state"` + Code session.Code `form:"code"` } - return } -func (ce CallbackError) ErrorResponse() []byte { - var msg string - switch ce { - default: - msg = "internal_server_error" +func (crb CallbackRequestBuilder) RequestBuilder() router.RequestBuilder { + return &crb +} + +func (crb *CallbackRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) { + var ok bool + for _, m := range crb.allowedMethods { + if m == method { + ok = true + } } - return mustMarshalJson(struct { - Error string `json:"error"` - }{ - Error: msg, - }) -} - -type CallbackRequest struct { - State session.State - Code session.Code -} - -func (cr CallbackRequest) Request() Request { - return &cr -} - -func (cr *CallbackRequest) Parse(data *url.Values) (err error) { - if !data.Has("code") { - err = fmt.Errorf("missing code paramater") - return + if !ok { + return ErrorMethodNotAllowed } - if !data.Has("state") { - err = fmt.Errorf("missing state parameter") - return + return Ok +} + +func (crb *CallbackRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) { + err := crb.header.Header.Parse(header) + if err != nil { + return ErrorBadRequest } - cr.State = session.State(data.Get("state")) - cr.Code = session.Code(data.Get("code")) - return + return Ok +} + +func (crb *CallbackRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) { + defer body.Close() + return Ok +} + +func (crb *CallbackRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) { + err := crb.values.Values.Parse(values) + if err != nil { + return ErrorBadRequest + } + return Ok } type CallbackResponse struct { - Message string `json:"message"` + header struct { + router.Header + } + Body struct { + Message string `json:"message"` + } } -func (cr *CallbackResponse) Response() []byte { - return mustMarshalJson(cr) +func (cr CallbackResponse) Response() router.Response { + return &cr +} + +func (cr *CallbackResponse) Header() (header router.Header) { + if cr.header.Header == nil { + cr.header.Header.Parse(cr.header) + } + return cr.header.Header +} + +func (cr *CallbackResponse) Bytes() []byte { + return mustMarshalJson(cr.Body) } type UnimplementedCallbackServer struct{} -func (u UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {} +func (UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {} -func (u UnimplementedCallbackServer) Callback(req *CallbackRequest) (res *CallbackResponse, errRes ErrorResponse) { - errRes = CallbackErrorUnimplemented +func (UnimplementedCallbackServer) Callback(req *CallbackRequestBuilder) (res *CallbackResponse, errRes router.ErrorResponse) { + errRes = ErrorNotImplemented return } type CallbackServer interface { mustEmbedUnimplementedCallbackServer() - Callback(*CallbackRequest) (*CallbackResponse, ErrorResponse) + Callback(*CallbackRequestBuilder) (*CallbackResponse, router.ErrorResponse) } diff --git a/server/error.go b/server/error.go new file mode 100644 index 0000000..2fa25cf --- /dev/null +++ b/server/error.go @@ -0,0 +1,63 @@ +package server + +import ( + "fmt" + "net/http" + "strings" +) + +type Error uint32 + +const ( + Ok Error = http.StatusOK + ErrorNotImplemented Error = http.StatusNotImplemented + ErrorMethodNotAllowed Error = http.StatusMethodNotAllowed + ErrorBadRequest Error = http.StatusBadRequest + ErrorUnauthorized Error = http.StatusUnauthorized + ErrorServerError Error = http.StatusInternalServerError +) + +func (e Error) Ok() (ok bool) { + return e == Ok +} + +func (e Error) HttpStatus() (code int) { + return int(e) +} + +func (e Error) String() (out string) { + switch e { + case Ok: + out = "ok" + case ErrorNotImplemented: + out = "server not implemented" + case ErrorMethodNotAllowed: + out = "method not allowed" + case ErrorBadRequest: + out = "bad request" + case ErrorUnauthorized: + out = "user unauthorized" + case ErrorServerError: + out = "internal server error" + default: + out = "unhandled error" + } + return +} + +func (e Error) Error() (out string) { + return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus()) +} + +func (e Error) Bytes() (body []byte) { + var msg string + switch e { + default: + msg = strings.Join(strings.Split(e.String(), " "), "_") + } + return mustMarshalJson(struct { + Error string `json:"error"` + }{ + Error: msg, + }) +} diff --git a/server/server.go b/server/server.go deleted file mode 100644 index 22e9ea1..0000000 --- a/server/server.go +++ /dev/null @@ -1,75 +0,0 @@ -package server - -import ( - "net/http" - "net/url" - - "somehole.com/common/log" -) - -type Request interface { - Parse(*url.Values) error -} - -type EmptyRequest interface { - Request() Request -} - -type Response interface { - Response() []byte -} - -type ErrorResponse interface { - Ok() bool - HttpStatus() int - ErrorResponse() []byte - String() string -} - -type server struct { - emptyReq EmptyRequest - allowed []string - logger log.Logger - do func(req Request) (res Response, errRes ErrorResponse) -} - -func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server { - return &server{ - emptyReq: req, - allowed: allowed, - logger: logger, - do: do, - } -} - -type defaultError struct { - Error string `json:"error"` -} - -func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - req := srv.emptyReq - request := req.Request() - var allowed bool - for _, method := range srv.allowed { - if method == r.Method { - allowed = true - break - } - } - if !allowed { - w.WriteHeader(http.StatusMethodNotAllowed) - w.Write(mustMarshalJson(&defaultError{Error: "method_not_allowed"})) - srv.logger.Logf(log.LevelError, "requested method (%s) not one of %v", r.Method, srv.allowed) - return - } - r.ParseForm() - request.Parse(&r.Form) - res, errRes := srv.do(request) - if !errRes.Ok() { - w.WriteHeader(errRes.HttpStatus()) - w.Write(errRes.ErrorResponse()) - srv.logger.Logf(log.LevelError, "request failed: %s", errRes.String()) - } - w.WriteHeader(errRes.HttpStatus()) - w.Write(res.Response()) -} diff --git a/server/token.go b/server/token.go index 16df2f0..3b02f0e 100644 --- a/server/token.go +++ b/server/token.go @@ -1,134 +1,101 @@ package server import ( - "fmt" - "net/http" - "net/url" + "io" - "somehole.com/common/oauth2/session" + "somehole.com/service/oauth2/session" + "somehole.com/service/router" ) -const TokenEndpoint = "/token" - -type TokenError uint32 - -const ( - TokenOk TokenError = iota - TokenErrorUnimplemented - TokenErrorUnauthorized - TokenErrorServerError - TokenErrorSlowDown - TokenErrorPending -) - -func (te TokenError) Ok() bool { - return te == TokenOk +type TokenRequestBuilder struct { + allowedMethods []string + header struct { + router.Header + } + values struct { + router.Values + ResponseType string `form:"response_type"` + ClientId string `form:"client_id"` + Code session.Code `form:"code"` + } } -func (te TokenError) HttpStatus() (code int) { - switch te { - case TokenOk: - code = http.StatusOK - case TokenErrorUnimplemented: - code = http.StatusInternalServerError - case TokenErrorUnauthorized: - code = http.StatusUnauthorized - case TokenErrorServerError: - code = http.StatusInternalServerError - case TokenErrorSlowDown: - code = http.StatusBadRequest - case TokenErrorPending: - code = http.StatusBadRequest - default: - code = http.StatusInternalServerError - } - return +func (trb TokenRequestBuilder) RequestBuilder() router.RequestBuilder { + return &trb } -func (te TokenError) String() (out string) { - switch te { - case TokenOk: - out = "authenticated" - case TokenErrorUnimplemented: - out = "token server unimplemented" - case TokenErrorUnauthorized: - out = "user unauthorized" - case TokenErrorServerError: - out = "internal server error" - case TokenErrorSlowDown: - out = "slow down" - case TokenErrorPending: - out = "authorization pending" - default: - out = "unhandled error" +func (trb *TokenRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) { + var ok bool + for _, m := range trb.allowedMethods { + if m == method { + ok = true + } } - return + if !ok { + return ErrorMethodNotAllowed + } + return Ok } -func (te TokenError) ErrorResponse() []byte { - var msg string - switch te { - case TokenErrorSlowDown: - msg = "slow_down" - case TokenErrorPending: - msg = "authorization_pending" - default: - msg = "internal_server_error" +func (trb *TokenRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) { + err := trb.header.Header.Parse(header) + if err != nil { + return ErrorBadRequest } - return mustMarshalJson(struct { - Error string `json:"error"` - }{ - Error: msg, - }) + return Ok } -type TokenRequest struct { - ResponseType string - ClientId string - Code session.Code +func (trb *TokenRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) { + defer body.Close() + return Ok } -func (tr TokenRequest) Request() Request { - return &tr -} - -func (tr *TokenRequest) Parse(data *url.Values) (err error) { - if !data.Has("ClientId") { - err = fmt.Errorf("missing required client_id paramater") - return +func (trb *TokenRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) { + err := trb.values.Values.Parse(values) + if err != nil { + return ErrorBadRequest } - tr.ClientId = data.Get("client_id") - if data.Has("code") { - tr.Code = session.Code(data.Get("code")) - } - if data.Has("response_type") { - tr.ResponseType = data.Get("response_type") - } - return + return Ok } type TokenResponse struct { - VerificationUri string `json:"verification_uri"` - UserCode session.Code `json:"user_code"` - DeviceCode session.Code `json:"device_code"` - Interval int `json:"interval"` - ExpiresIn int `json:"expires_in"` + header struct { + router.Header + } + Body struct { + VerificationUri string `json:"verification_uri"` + UserCode session.Code `json:"user_code"` + DeviceCode session.Code `json:"device_code"` + Interval int `json:"interval"` + ExpiresIn int `json:"expires_in"` + } } -func (tr *TokenResponse) Response() []byte { - return mustMarshalJson(tr) +func (tr TokenResponse) Response() router.Response { + return &tr +} + +func (tr *TokenResponse) Header() (header router.Header) { + if tr.header.Header == nil { + tr.header.Header.Parse(tr.header) + } + return tr.header.Header +} + +func (tr *TokenResponse) Bytes() []byte { + return mustMarshalJson(tr.Body) } type UnimplementedTokenServer struct{} -func (u UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {} +func (UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {} -func (u UnimplementedTokenServer) Token(req *TokenRequest) (res *TokenResponse, errRes ErrorResponse) { - errRes = TokenErrorUnimplemented +func (UnimplementedTokenServer) Token(req *TokenRequestBuilder) (res *TokenResponse, errRes router.ErrorResponse) { + errRes = ErrorNotImplemented return } type TokenServer interface { mustEmbedUnimplementedTokenServer() - Token(*TokenRequest) (*TokenResponse, ErrorResponse) + Token(*TokenRequestBuilder) (*TokenResponse, router.ErrorResponse) }