diff --git a/go.mod b/go.mod index 4c40e82..254c34e 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module somehole.com/common/oauth2 go 1.23.1 require ( - somehole.com/common/defaults v0.1.2 somehole.com/common/log v0.1.2 somehole.com/common/security v0.1.0 ) diff --git a/go.sum b/go.sum index 4daa12d..a49faeb 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIB golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -somehole.com/common/defaults v0.1.2 h1:y523TtBEP1y415akn4DELk/7iNxNvawKDyKtQq9QdNk= -somehole.com/common/defaults v0.1.2/go.mod h1:jPc/GeCBxJHwPJK8UUh3phDxZ1L9KCd9xXEyGSnRhCo= somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw= somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE= diff --git a/server.go b/server.go index f063f0f..8fec25c 100644 --- a/server.go +++ b/server.go @@ -24,7 +24,7 @@ func NewServer(mux *http.ServeMux, logger log.Logger) *Server { } func (s *Server) RegisterCallbackServer(srv server.CallbackServer) { - s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { + s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { callbackRequest, ok := req.(*server.CallbackRequest) if !ok { panic(fmt.Errorf("expected CallbackRequest, got %T", req)) @@ -35,7 +35,7 @@ func (s *Server) RegisterCallbackServer(srv server.CallbackServer) { } func (s *Server) RegisterTokenServer(srv server.TokenServer) { - s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { + s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { tokenRequest, ok := req.(*server.TokenRequest) if !ok { panic(fmt.Errorf("expected TokenRequest, got %T", req)) diff --git a/server/callback.go b/server/callback.go index a4df6d1..1481f6f 100644 --- a/server/callback.go +++ b/server/callback.go @@ -69,6 +69,10 @@ type CallbackRequest struct { Code session.Code } +func (cr CallbackRequest) Request() Request { + return &cr +} + func (cr *CallbackRequest) Parse(data *url.Values) (err error) { if !data.Has("code") { err = fmt.Errorf("missing code paramater") diff --git a/server/server.go b/server/server.go index c0370f2..7d77d7d 100644 --- a/server/server.go +++ b/server/server.go @@ -4,21 +4,17 @@ import ( "net/http" "net/url" - "somehole.com/common/defaults" "somehole.com/common/log" ) -type server struct { - req Request - allowed []string - logger log.Logger - do func(req Request) (res Response, errRes ErrorResponse) -} - type Request interface { Parse(*url.Values) error } +type EmptyRequest interface { + Request() Request +} + type Response interface { HttpStatus() int Response() []byte @@ -31,12 +27,19 @@ type ErrorResponse interface { String() string } -func NewServer(req Request, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server { +type server struct { + emptyReq EmptyRequest + allowed []string + logger log.Logger + do func(req Request) (res Response, errRes ErrorResponse) +} + +func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server { return &server{ - req: req, - allowed: allowed, - logger: logger, - do: do, + emptyReq: req, + allowed: allowed, + logger: logger, + do: do, } } @@ -45,13 +48,8 @@ type defaultError struct { } func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if !defaults.Empty(srv.req) { - w.WriteHeader(http.StatusInternalServerError) - w.Write(mustMarshalJson(&defaultError{Error: "internal_server_error"})) - srv.logger.Logf(log.LevelError, "expected empty server request template") - return - } - req := srv.req + req := srv.emptyReq + request := req.Request() var allowed bool for _, method := range srv.allowed { if method == r.Method { @@ -66,8 +64,8 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } r.ParseForm() - req.Parse(&r.Form) - res, errRes := srv.do(req) + request.Parse(&r.Form) + res, errRes := srv.do(request) if !errRes.Ok() { w.WriteHeader(errRes.HttpStatus()) w.Write(errRes.ErrorResponse()) diff --git a/server/token.go b/server/token.go index 324dfc6..cf87b49 100644 --- a/server/token.go +++ b/server/token.go @@ -79,21 +79,27 @@ func (te TokenError) ErrorResponse() []byte { } type TokenRequest struct { - State session.State - Code session.Code + ResponseType string + ClientId string + Code session.Code +} + +func (tr TokenRequest) Request() Request { + return &tr } func (tr *TokenRequest) Parse(data *url.Values) (err error) { - if !data.Has("code") { - err = fmt.Errorf("missing code paramater") + if !data.Has("ClientId") { + err = fmt.Errorf("missing required client_id paramater") return } - if !data.Has("state") { - err = fmt.Errorf("missing state parameter") - return + tr.ClientId = data.Get("client_id") + if data.Has("code") { + tr.Code = session.Code(data.Get("code")) + } + if data.Has("response_type") { + tr.ResponseType = data.Get("response_type") } - tr.State = session.State(data.Get("state")) - tr.Code = session.Code(data.Get("code")) return } @@ -102,7 +108,8 @@ type TokenResponse struct { VerificationUri string `json:"verification_uri"` UserCode session.Code `json:"user_code"` DeviceCode session.Code `json:"device_code"` - Interval uint8 `json:"interval"` + Interval int `json:"interval"` + ExpiresIn int `json:"expires_in"` } func (tr *TokenResponse) HttpStatus() (code int) {