commit 4c69ddb99fd6513e7458ad00c3d3c9dbda0ea161 Author: some Date: Tue Oct 1 13:38:43 2024 -0400 Initial commit diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..0cf80d3 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,35 @@ +root = true + +[*] +charset = utf-8 + +end_of_line = LF +insert_final_newline = true +trim_trailing_whitespace = true + +indent_style = space +indent_size = 2 + +[*.sh] +indent_style = space +indent_size = 4 + +[{*.html,*.js,*.css,*.scss}] +indent_style = space +indent_size = 4 + +[Makefile] +indent_style = tab +indent_size = 8 + +[{{*.,}[Dd]ockerfile{.*,},{*.,}[Cc]ontainerfile{.*,}}] +indent_style = space +indent_size = 4 + +[*.proto] +indent_style = space +indent_size = 2 + +[{*.go,go.mod}] +indent_style = tab +indent_size = 8 diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ed0e7b3 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +BSD 2-Clause License + +Copyright (c) 2024, some +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/client/authorization.go b/client/authorization.go new file mode 100644 index 0000000..5d185d3 --- /dev/null +++ b/client/authorization.go @@ -0,0 +1,48 @@ +package client + +import ( + "fmt" + "net/url" + "strings" + + "somehole.com/common/oauth2/session" +) + +type AuthorizationUrl struct { + *Client + *session.Session + ResponseType string +} + +func NewAuthorizationUrl(client *Client, id session.SessionId, responseType string) (url *AuthorizationUrl, err error) { + ses, ok := client.sessions[id] + if !ok { + err = fmt.Errorf("no session found") + } + url = &AuthorizationUrl{ + Client: client, + Session: ses, + ResponseType: responseType, + } + return +} + +func (a *AuthorizationUrl) Url() *url.URL { + v := url.Values{ + "response_type": {a.ResponseType}, + "client_id": {a.ClientId}, + "redirect_uri": {a.RedirectUri}, + "scope": {strings.Join(a.Scopes, " ")}, + "state": {string(a.State)}, + } + return &url.URL{ + Scheme: "https", + Host: a.Host, + Path: a.AuthorizationUrlPath, + RawQuery: v.Encode(), + } +} + +func (a *AuthorizationUrl) String() string { + return a.Url().String() +} diff --git a/client/client.go b/client/client.go new file mode 100644 index 0000000..a6acc24 --- /dev/null +++ b/client/client.go @@ -0,0 +1,32 @@ +package client + +import ( + "somehole.com/common/oauth2/session" + "somehole.com/common/security/signature" +) + +type Client struct { + *IdentityProvider + *signature.Keypair + ClientId string + ClientSecret string + RedirectUri string + ResponseType string + Scopes []string + sessions map[session.SessionId]*session.Session +} + +func NewClient(idp *IdentityProvider, signer *signature.Keypair, id string, secret string, redirectUri string, responseType string, scopes []string) *Client { + if signer == nil { + signer, _ = signature.NewKeypair() + } + return &Client{ + IdentityProvider: idp, + Keypair: signer, + ClientId: id, + ClientSecret: secret, + RedirectUri: redirectUri, + ResponseType: responseType, + Scopes: scopes, + } +} diff --git a/client/provider.go b/client/provider.go new file mode 100644 index 0000000..f72eec5 --- /dev/null +++ b/client/provider.go @@ -0,0 +1,17 @@ +package client + +type IdentityProvider struct { + Host string + AuthorizationUrlPath string + TokenUrlPath string + TokenRevokationUrlPath string +} + +func NewIdentityProvider(host string, authorizationUrlPath string, tokenUrlPath string, tokenRevocationUrlPath string) *IdentityProvider { + return &IdentityProvider{ + Host: host, + AuthorizationUrlPath: authorizationUrlPath, + TokenUrlPath: tokenUrlPath, + TokenRevokationUrlPath: tokenUrlPath, + } +} diff --git a/client/revokation.go b/client/revokation.go new file mode 100644 index 0000000..50919cb --- /dev/null +++ b/client/revokation.go @@ -0,0 +1,41 @@ +package client + +import ( + "fmt" + "net/url" + + "somehole.com/common/oauth2/session" +) + +type TokenRevokationUrl struct { + *Client + *session.Session + TokenChoice session.TokenChoice +} + +func NewTokenRevokationUrl(client *Client, id session.SessionId, choice session.TokenChoice) (url *TokenRevokationUrl, err error) { + ses, ok := client.sessions[id] + if !ok { + err = fmt.Errorf("no session found") + } + url = &TokenRevokationUrl{ + Client: client, + Session: ses, + TokenChoice: choice, + } + return +} + +func (t *TokenRevokationUrl) Url() *url.URL { + v := url.Values{ + "token": {t.GetToken(t.TokenChoice)}, + "token_type_hint": {t.TokenChoice.String()}, + } + return &url.URL{ + Scheme: "https", + User: url.UserPassword(t.ClientId, t.ClientSecret), + Host: t.Host, + Path: t.TokenUrlPath, + RawQuery: v.Encode(), + } +} diff --git a/client/token.go b/client/token.go new file mode 100644 index 0000000..0282339 --- /dev/null +++ b/client/token.go @@ -0,0 +1,50 @@ +package client + +import ( + "fmt" + "net/url" + + "somehole.com/common/oauth2/session" +) + +type TokenUrl struct { + *Client + *session.Session + TokenChoice session.TokenChoice +} + +func NewTokenUrl(client *Client, id session.SessionId, choice session.TokenChoice) (url *TokenUrl, err error) { + ses, ok := client.sessions[id] + if !ok { + err = fmt.Errorf("no session found") + } + url = &TokenUrl{ + Client: client, + Session: ses, + TokenChoice: choice, + } + return +} + +func (t *TokenUrl) Url() *url.URL { + v := url.Values{} + if t.TokenChoice == session.TokenChoiceAccess { + v.Set("grant_type", "authorization_code") + v.Set("code", string(t.Code)) + v.Set("redirect_uri", string(t.RedirectUri)) + } else if t.TokenChoice == session.TokenChoiceRefresh { + v.Set("grant_type", "refresh_token") + v.Set("refresh_token", string(t.RefreshToken)) + } + return &url.URL{ + Scheme: "https", + User: url.UserPassword(t.ClientId, t.ClientSecret), + Host: t.Host, + Path: t.TokenUrlPath, + RawQuery: v.Encode(), + } +} + +func (t *TokenUrl) String() string { + return t.Url().String() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..4c40e82 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module somehole.com/common/oauth2 + +go 1.23.1 + +require ( + somehole.com/common/defaults v0.1.2 + somehole.com/common/log v0.1.2 + somehole.com/common/security v0.1.0 +) + +require ( + github.com/cloudflare/circl v1.4.0 // indirect + golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d // indirect + golang.org/x/sys v0.10.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4daa12d --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY= +github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU= +golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIBq1s5hz3HQxhJJSu/SUGlD+SM= +golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +somehole.com/common/defaults v0.1.2 h1:y523TtBEP1y415akn4DELk/7iNxNvawKDyKtQq9QdNk= +somehole.com/common/defaults v0.1.2/go.mod h1:jPc/GeCBxJHwPJK8UUh3phDxZ1L9KCd9xXEyGSnRhCo= +somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw= +somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= +somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE= +somehole.com/common/security v0.1.0/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM= diff --git a/server.go b/server.go new file mode 100644 index 0000000..f063f0f --- /dev/null +++ b/server.go @@ -0,0 +1,46 @@ +package oauth2 + +import ( + "fmt" + "net/http" + + "somehole.com/common/log" + "somehole.com/common/oauth2/server" +) + +type Server struct { + *http.ServeMux + Logger log.Logger +} + +func NewServer(mux *http.ServeMux, logger log.Logger) *Server { + if mux == nil { + mux = http.NewServeMux() + } + return &Server{ + ServeMux: mux, + Logger: logger, + } +} + +func (s *Server) RegisterCallbackServer(srv server.CallbackServer) { + s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { + callbackRequest, ok := req.(*server.CallbackRequest) + if !ok { + panic(fmt.Errorf("expected CallbackRequest, got %T", req)) + } + res, errRes = srv.Callback(callbackRequest) + return + })) +} + +func (s *Server) RegisterTokenServer(srv server.TokenServer) { + s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { + tokenRequest, ok := req.(*server.TokenRequest) + if !ok { + panic(fmt.Errorf("expected TokenRequest, got %T", req)) + } + res, errRes = srv.Token(tokenRequest) + return + })) +} diff --git a/server/callback.go b/server/callback.go new file mode 100644 index 0000000..5828d0c --- /dev/null +++ b/server/callback.go @@ -0,0 +1,112 @@ +package server + +import ( + "fmt" + "net/http" + "net/url" + + "somehole.com/common/oauth2/session" +) + +const CallbackEndpoint = "/callback" + +type CallbackError uint32 + +const ( + CallbackOk CallbackError = iota + CallbackErrorUnimplemented + CallbackErrorUnauthorized + CallbackErrorServerError +) + +func (ce *CallbackError) Ok() bool { + return *ce == CallbackOk +} + +func (ce *CallbackError) HttpStatus() (code int) { + switch *ce { + case CallbackOk: + code = http.StatusOK + case CallbackErrorUnimplemented: + code = http.StatusInternalServerError + case CallbackErrorUnauthorized: + code = http.StatusUnauthorized + case CallbackErrorServerError: + code = http.StatusInternalServerError + } + return +} + +func (ce *CallbackError) String() (out string) { + switch *ce { + case CallbackOk: + out = "authenticated" + case CallbackErrorUnimplemented: + out = "callback server unimplemented" + case CallbackErrorUnauthorized: + out = "user unauthorized" + case CallbackErrorServerError: + out = "internal server error" + } + return +} + +func (ce *CallbackError) ErrorResponse() []byte { + var msg string + switch *ce { + default: + msg = "internal_server_error" + } + return mustMarshalJson(struct { + Error string `json:"error"` + }{ + Error: msg, + }) +} + +type CallbackRequest struct { + State session.State + Code session.Code +} + +func (cr *CallbackRequest) Parse(data *url.Values) (err error) { + if !data.Has("code") { + err = fmt.Errorf("missing code paramater") + return + } + if !data.Has("state") { + err = fmt.Errorf("missing state parameter") + return + } + cr.State = session.State(data.Get("state")) + cr.Code = session.Code(data.Get("code")) + return +} + +type CallbackResponse struct { + Status int `json:"-"` + Message string `json:"message"` +} + +func (cr *CallbackResponse) HttpStatus() (code int) { + return cr.Status +} + +func (cr *CallbackResponse) Response() []byte { + return mustMarshalJson(cr) +} + +type UnimplementedCallbackServer struct{} + +func (u UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {} + +func (u UnimplementedCallbackServer) Callback(callback *CallbackRequest) (callbackResp *CallbackResponse, callbackErr *CallbackError) { + ce := CallbackErrorUnimplemented + callbackErr = &ce + return +} + +type CallbackServer interface { + mustEmbedUnimplementedCallbackServer() + Callback(callback *CallbackRequest) (callbackResp *CallbackResponse, callbackErr *CallbackError) +} diff --git a/server/common.go b/server/common.go new file mode 100644 index 0000000..156602c --- /dev/null +++ b/server/common.go @@ -0,0 +1,14 @@ +package server + +import ( + "encoding/json" + "fmt" +) + +func mustMarshalJson(in any) []byte { + out, err := json.Marshal(in) + if err != nil { + panic(fmt.Errorf("could not marshal %#v: %v", in, err)) + } + return out +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..c0370f2 --- /dev/null +++ b/server/server.go @@ -0,0 +1,78 @@ +package server + +import ( + "net/http" + "net/url" + + "somehole.com/common/defaults" + "somehole.com/common/log" +) + +type server struct { + req Request + allowed []string + logger log.Logger + do func(req Request) (res Response, errRes ErrorResponse) +} + +type Request interface { + Parse(*url.Values) error +} + +type Response interface { + HttpStatus() int + Response() []byte +} + +type ErrorResponse interface { + Ok() bool + HttpStatus() int + ErrorResponse() []byte + String() string +} + +func NewServer(req Request, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server { + return &server{ + req: req, + allowed: allowed, + logger: logger, + do: do, + } +} + +type defaultError struct { + Error string `json:"error"` +} + +func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !defaults.Empty(srv.req) { + w.WriteHeader(http.StatusInternalServerError) + w.Write(mustMarshalJson(&defaultError{Error: "internal_server_error"})) + srv.logger.Logf(log.LevelError, "expected empty server request template") + return + } + req := srv.req + var allowed bool + for _, method := range srv.allowed { + if method == r.Method { + allowed = true + break + } + } + if !allowed { + w.WriteHeader(http.StatusMethodNotAllowed) + w.Write(mustMarshalJson(&defaultError{Error: "method_not_allowed"})) + srv.logger.Logf(log.LevelError, "requested method (%s) not one of %v", r.Method, srv.allowed) + return + } + r.ParseForm() + req.Parse(&r.Form) + res, errRes := srv.do(req) + if !errRes.Ok() { + w.WriteHeader(errRes.HttpStatus()) + w.Write(errRes.ErrorResponse()) + srv.logger.Logf(log.LevelError, "request failed: %s", errRes.String()) + } + w.WriteHeader(res.HttpStatus()) + w.Write(res.Response()) +} diff --git a/server/token.go b/server/token.go new file mode 100644 index 0000000..f674465 --- /dev/null +++ b/server/token.go @@ -0,0 +1,129 @@ +package server + +import ( + "fmt" + "net/http" + "net/url" + + "somehole.com/common/oauth2/session" +) + +const TokenEndpoint = "/token" + +type TokenError uint32 + +const ( + TokenOk TokenError = iota + TokenErrorUnimplemented + TokenErrorUnauthorized + TokenErrorServerError + TokenErrorSlowDown + TokenErrorPending +) + +func (te *TokenError) Ok() bool { + return *te == TokenOk +} + +func (te *TokenError) HttpStatus() (code int) { + switch *te { + case TokenOk: + code = http.StatusOK + case TokenErrorUnimplemented: + code = http.StatusInternalServerError + case TokenErrorUnauthorized: + code = http.StatusUnauthorized + case TokenErrorServerError: + code = http.StatusInternalServerError + case TokenErrorSlowDown: + code = http.StatusBadRequest + case TokenErrorPending: + code = http.StatusBadRequest + } + return +} + +func (te *TokenError) String() (out string) { + switch *te { + case TokenOk: + out = "authenticated" + case TokenErrorUnimplemented: + out = "token server unimplemented" + case TokenErrorUnauthorized: + out = "user unauthorized" + case TokenErrorServerError: + out = "internal server error" + case TokenErrorSlowDown: + out = "slow down" + case TokenErrorPending: + out = "authorization pending" + } + return +} + +func (te *TokenError) ErrorResponse() []byte { + var msg string + switch *te { + case TokenErrorSlowDown: + msg = "slow_down" + case TokenErrorPending: + msg = "authorization_pending" + default: + msg = "internal_server_error" + } + return mustMarshalJson(struct { + Error string `json:"error"` + }{ + Error: msg, + }) +} + +type TokenRequest struct { + State session.State + Code session.Code +} + +func (tr *TokenRequest) Parse(data *url.Values) (err error) { + if !data.Has("code") { + err = fmt.Errorf("missing code paramater") + return + } + if !data.Has("state") { + err = fmt.Errorf("missing state parameter") + return + } + tr.State = session.State(data.Get("state")) + tr.Code = session.Code(data.Get("code")) + return +} + +type TokenResponse struct { + Status int `json:"-"` + VerificationUri string `json:"verification_uri"` + UserCode session.Code `json:"user_code"` + DeviceCode session.Code `json:"device_code"` + Interval uint8 `json:"interval"` +} + +func (tr *TokenResponse) HttpStatus() (code int) { + return tr.Status +} + +func (tr *TokenResponse) Response() []byte { + return mustMarshalJson(tr) +} + +type UnimplementedTokenServer struct{} + +func (u UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {} + +func (u UnimplementedTokenServer) Token(token *TokenRequest) (tokenResp *TokenResponse, tokenErr *TokenError) { + te := TokenErrorUnimplemented + tokenErr = &te + return +} + +type TokenServer interface { + mustEmbedUnimplementedTokenServer() + Token(token *TokenRequest) (tokenResp *TokenResponse, tokenErr *TokenError) +} diff --git a/session/session.go b/session/session.go new file mode 100644 index 0000000..22b8e77 --- /dev/null +++ b/session/session.go @@ -0,0 +1,74 @@ +package session + +import ( + "crypto/rand" + "encoding/hex" + + "somehole.com/common/security/signature" +) + +type SessionId string + +func NewSessionId() SessionId { + b := make([]byte, 8) + rand.Read(b) + return SessionId(hex.EncodeToString(b)) +} + +type State string + +func NewState(id SessionId, signer *signature.Keypair) State { + sig, err := signer.Sign([]byte(id)) + if err != nil { + panic(err) + } + return State(hex.EncodeToString(sig[:])) +} + +type Code string + +type AccessToken string + +type RefreshToken string + +type TokenChoice uint8 + +const ( + TokenChoiceAccess TokenChoice = iota + TokenChoiceRefresh +) + +func (t TokenChoice) String() (out string) { + switch t { + case TokenChoiceAccess: + out = "access_token" + case TokenChoiceRefresh: + out = "refresh_token" + } + return +} + +type Session struct { + SessionId SessionId + State State + Code Code + AccessToken AccessToken + RefreshToken RefreshToken +} + +func NewSession() *Session { + id := NewSessionId() + return &Session{ + SessionId: id, + } +} + +func (s *Session) GetToken(choice TokenChoice) (token string) { + switch choice { + case TokenChoiceAccess: + token = string(s.AccessToken) + case TokenChoiceRefresh: + token = string(s.RefreshToken) + } + return +}