Move to using router package

This commit is contained in:
some 2024-10-08 20:33:23 -04:00
parent 333b70522f
commit 0ea4adb96c
Signed by: some
GPG Key ID: 65D0589220B9BFC8
12 changed files with 275 additions and 317 deletions

View File

@ -5,7 +5,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"somehole.com/common/oauth2/session" "somehole.com/service/oauth2/session"
) )
type AuthorizationUrl struct { type AuthorizationUrl struct {

View File

@ -1,8 +1,8 @@
package client package client
import ( import (
"somehole.com/common/oauth2/session"
"somehole.com/common/security/signature" "somehole.com/common/security/signature"
"somehole.com/service/oauth2/session"
) )
type Client struct { type Client struct {

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"somehole.com/common/oauth2/session" "somehole.com/service/oauth2/session"
) )
type TokenRevokationUrl struct { type TokenRevokationUrl struct {

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"somehole.com/common/oauth2/session" "somehole.com/service/oauth2/session"
) )
type TokenUrl struct { type TokenUrl struct {

1
go.mod
View File

@ -5,6 +5,7 @@ go 1.23.1
require ( require (
somehole.com/common/log v0.1.3 somehole.com/common/log v0.1.3
somehole.com/common/security v0.2.1 somehole.com/common/security v0.2.1
somehole.com/service/router v0.6.4
) )
require ( require (

6
go.sum
View File

@ -1,14 +1,12 @@
github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY= github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY=
github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU= github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU=
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIBq1s5hz3HQxhJJSu/SUGlD+SM=
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0= somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0=
somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA= somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA=
somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM= somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM=
somehole.com/service/router v0.6.4 h1:tA734uvqQqG0LX8UZa7hjmHVMN6LBbNGJraw1CH7kDQ=
somehole.com/service/router v0.6.4/go.mod h1:F+/sNY4/ei7C9QD5+iA+gUHugcOAcrWspHKwX8d8oiM=

72
router.go Normal file
View File

@ -0,0 +1,72 @@
package oauth2
import (
"fmt"
"somehole.com/common/log"
"somehole.com/service/oauth2/server"
"somehole.com/service/router"
)
type Endpoints struct {
CallbackEndpoint string
TokenEndpoint string
}
var defaultEndpoints = Endpoints{
CallbackEndpoint: "/callback",
TokenEndpoint: "/token",
}
type Router struct {
*router.Router
log.Logger
Endpoints
}
func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) {
var endpoints Endpoints
if r == nil {
r = router.NewRouter(nil, nil)
}
if len(optionalEndpoints) == 1 {
endpoints = optionalEndpoints[0]
} else {
endpoints = defaultEndpoints
}
r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint})
if logger == nil {
logger = log.NewPlainLogger()
}
return &Router{
Router: r,
Logger: logger,
Endpoints: endpoints,
}
}
func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) {
pro := router.Prototype{PrototypeRequestBuilder: server.CallbackRequestBuilder{}, PrototypeResponse: server.CallbackResponse{}}
server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
callbackRequest, ok := req.(*server.CallbackRequestBuilder)
if !ok {
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
}
res, errRes = srv.Callback(callbackRequest)
return
}
ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, server))
}
func (ro *Router) RegisterTokenServer(srv server.TokenServer) {
pro := router.Prototype{PrototypeRequestBuilder: server.TokenRequestBuilder{}, PrototypeResponse: server.TokenResponse{}}
server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
callbackRequest, ok := req.(*server.TokenRequestBuilder)
if !ok {
panic(fmt.Errorf("expected TokenRequest, got %T", req))
}
res, errRes = srv.Token(callbackRequest)
return
}
ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, server))
}

View File

@ -1,50 +0,0 @@
package oauth2
import (
"fmt"
"net/http"
"somehole.com/common/log"
"somehole.com/common/oauth2/server"
)
type Server struct {
*http.ServeMux
Logger log.Logger
}
func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
if mux == nil {
mux = http.NewServeMux()
}
return &Server{
ServeMux: mux,
Logger: logger,
}
}
func (s *Server) Handle(pattern string, handler http.Handler) {
s.ServeMux.Handle(pattern, handler)
}
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
callbackRequest, ok := req.(*server.CallbackRequest)
if !ok {
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
}
res, errRes = srv.Callback(callbackRequest)
return
}))
}
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
tokenRequest, ok := req.(*server.TokenRequest)
if !ok {
panic(fmt.Errorf("expected TokenRequest, got %T", req))
}
res, errRes = srv.Token(tokenRequest)
return
}))
}

View File

@ -1,114 +1,96 @@
package server package server
import ( import (
"fmt" "io"
"net/http"
"net/url"
"somehole.com/common/oauth2/session" "somehole.com/service/oauth2/session"
"somehole.com/service/router"
) )
const CallbackEndpoint = "/callback" type CallbackRequestBuilder struct {
allowedMethods []string
type CallbackError uint32 header struct {
router.Header
const (
CallbackOk CallbackError = iota
CallbackErrorUnimplemented
CallbackErrorUnauthorized
CallbackErrorServerError
)
func (ce CallbackError) Ok() bool {
return ce == CallbackOk
}
func (ce CallbackError) HttpStatus() (code int) {
switch ce {
case CallbackOk:
code = http.StatusOK
case CallbackErrorUnimplemented:
code = http.StatusInternalServerError
case CallbackErrorUnauthorized:
code = http.StatusUnauthorized
case CallbackErrorServerError:
code = http.StatusInternalServerError
default:
code = http.StatusInternalServerError
} }
return values struct {
} router.Values
State session.State `form:"state"`
func (ce CallbackError) String() (out string) { Code session.Code `form:"code"`
switch ce {
case CallbackOk:
out = "authenticated"
case CallbackErrorUnimplemented:
out = "callback server unimplemented"
case CallbackErrorUnauthorized:
out = "user unauthorized"
case CallbackErrorServerError:
out = "internal server error"
default:
out = "unhandled error"
} }
return
} }
func (ce CallbackError) ErrorResponse() []byte { func (crb CallbackRequestBuilder) RequestBuilder() router.RequestBuilder {
var msg string return &crb
switch ce { }
default:
msg = "internal_server_error" func (crb *CallbackRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) {
var ok bool
for _, m := range crb.allowedMethods {
if m == method {
ok = true
}
} }
return mustMarshalJson(struct { if !ok {
Error string `json:"error"` return ErrorMethodNotAllowed
}{
Error: msg,
})
}
type CallbackRequest struct {
State session.State
Code session.Code
}
func (cr CallbackRequest) Request() Request {
return &cr
}
func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
if !data.Has("code") {
err = fmt.Errorf("missing code paramater")
return
} }
if !data.Has("state") { return Ok
err = fmt.Errorf("missing state parameter") }
return
func (crb *CallbackRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) {
err := crb.header.Header.Parse(header)
if err != nil {
return ErrorBadRequest
} }
cr.State = session.State(data.Get("state")) return Ok
cr.Code = session.Code(data.Get("code")) }
return
func (crb *CallbackRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) {
defer body.Close()
return Ok
}
func (crb *CallbackRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) {
err := crb.values.Values.Parse(values)
if err != nil {
return ErrorBadRequest
}
return Ok
} }
type CallbackResponse struct { type CallbackResponse struct {
Message string `json:"message"` header struct {
router.Header
}
Body struct {
Message string `json:"message"`
}
} }
func (cr *CallbackResponse) Response() []byte { func (cr CallbackResponse) Response() router.Response {
return mustMarshalJson(cr) return &cr
}
func (cr *CallbackResponse) Header() (header router.Header) {
if cr.header.Header == nil {
cr.header.Header.Parse(cr.header)
}
return cr.header.Header
}
func (cr *CallbackResponse) Bytes() []byte {
return mustMarshalJson(cr.Body)
} }
type UnimplementedCallbackServer struct{} type UnimplementedCallbackServer struct{}
func (u UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {} func (UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {}
func (u UnimplementedCallbackServer) Callback(req *CallbackRequest) (res *CallbackResponse, errRes ErrorResponse) { func (UnimplementedCallbackServer) Callback(req *CallbackRequestBuilder) (res *CallbackResponse, errRes router.ErrorResponse) {
errRes = CallbackErrorUnimplemented errRes = ErrorNotImplemented
return return
} }
type CallbackServer interface { type CallbackServer interface {
mustEmbedUnimplementedCallbackServer() mustEmbedUnimplementedCallbackServer()
Callback(*CallbackRequest) (*CallbackResponse, ErrorResponse) Callback(*CallbackRequestBuilder) (*CallbackResponse, router.ErrorResponse)
} }

63
server/error.go Normal file
View File

@ -0,0 +1,63 @@
package server
import (
"fmt"
"net/http"
"strings"
)
type Error uint32
const (
Ok Error = http.StatusOK
ErrorNotImplemented Error = http.StatusNotImplemented
ErrorMethodNotAllowed Error = http.StatusMethodNotAllowed
ErrorBadRequest Error = http.StatusBadRequest
ErrorUnauthorized Error = http.StatusUnauthorized
ErrorServerError Error = http.StatusInternalServerError
)
func (e Error) Ok() (ok bool) {
return e == Ok
}
func (e Error) HttpStatus() (code int) {
return int(e)
}
func (e Error) String() (out string) {
switch e {
case Ok:
out = "ok"
case ErrorNotImplemented:
out = "server not implemented"
case ErrorMethodNotAllowed:
out = "method not allowed"
case ErrorBadRequest:
out = "bad request"
case ErrorUnauthorized:
out = "user unauthorized"
case ErrorServerError:
out = "internal server error"
default:
out = "unhandled error"
}
return
}
func (e Error) Error() (out string) {
return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus())
}
func (e Error) Bytes() (body []byte) {
var msg string
switch e {
default:
msg = strings.Join(strings.Split(e.String(), " "), "_")
}
return mustMarshalJson(struct {
Error string `json:"error"`
}{
Error: msg,
})
}

View File

@ -1,75 +0,0 @@
package server
import (
"net/http"
"net/url"
"somehole.com/common/log"
)
type Request interface {
Parse(*url.Values) error
}
type EmptyRequest interface {
Request() Request
}
type Response interface {
Response() []byte
}
type ErrorResponse interface {
Ok() bool
HttpStatus() int
ErrorResponse() []byte
String() string
}
type server struct {
emptyReq EmptyRequest
allowed []string
logger log.Logger
do func(req Request) (res Response, errRes ErrorResponse)
}
func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
return &server{
emptyReq: req,
allowed: allowed,
logger: logger,
do: do,
}
}
type defaultError struct {
Error string `json:"error"`
}
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
req := srv.emptyReq
request := req.Request()
var allowed bool
for _, method := range srv.allowed {
if method == r.Method {
allowed = true
break
}
}
if !allowed {
w.WriteHeader(http.StatusMethodNotAllowed)
w.Write(mustMarshalJson(&defaultError{Error: "method_not_allowed"}))
srv.logger.Logf(log.LevelError, "requested method (%s) not one of %v", r.Method, srv.allowed)
return
}
r.ParseForm()
request.Parse(&r.Form)
res, errRes := srv.do(request)
if !errRes.Ok() {
w.WriteHeader(errRes.HttpStatus())
w.Write(errRes.ErrorResponse())
srv.logger.Logf(log.LevelError, "request failed: %s", errRes.String())
}
w.WriteHeader(errRes.HttpStatus())
w.Write(res.Response())
}

View File

@ -1,134 +1,101 @@
package server package server
import ( import (
"fmt" "io"
"net/http"
"net/url"
"somehole.com/common/oauth2/session" "somehole.com/service/oauth2/session"
"somehole.com/service/router"
) )
const TokenEndpoint = "/token" type TokenRequestBuilder struct {
allowedMethods []string
type TokenError uint32 header struct {
router.Header
const ( }
TokenOk TokenError = iota values struct {
TokenErrorUnimplemented router.Values
TokenErrorUnauthorized ResponseType string `form:"response_type"`
TokenErrorServerError ClientId string `form:"client_id"`
TokenErrorSlowDown Code session.Code `form:"code"`
TokenErrorPending }
)
func (te TokenError) Ok() bool {
return te == TokenOk
} }
func (te TokenError) HttpStatus() (code int) { func (trb TokenRequestBuilder) RequestBuilder() router.RequestBuilder {
switch te { return &trb
case TokenOk:
code = http.StatusOK
case TokenErrorUnimplemented:
code = http.StatusInternalServerError
case TokenErrorUnauthorized:
code = http.StatusUnauthorized
case TokenErrorServerError:
code = http.StatusInternalServerError
case TokenErrorSlowDown:
code = http.StatusBadRequest
case TokenErrorPending:
code = http.StatusBadRequest
default:
code = http.StatusInternalServerError
}
return
} }
func (te TokenError) String() (out string) { func (trb *TokenRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) {
switch te { var ok bool
case TokenOk: for _, m := range trb.allowedMethods {
out = "authenticated" if m == method {
case TokenErrorUnimplemented: ok = true
out = "token server unimplemented" }
case TokenErrorUnauthorized:
out = "user unauthorized"
case TokenErrorServerError:
out = "internal server error"
case TokenErrorSlowDown:
out = "slow down"
case TokenErrorPending:
out = "authorization pending"
default:
out = "unhandled error"
} }
return if !ok {
return ErrorMethodNotAllowed
}
return Ok
} }
func (te TokenError) ErrorResponse() []byte { func (trb *TokenRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) {
var msg string err := trb.header.Header.Parse(header)
switch te { if err != nil {
case TokenErrorSlowDown: return ErrorBadRequest
msg = "slow_down"
case TokenErrorPending:
msg = "authorization_pending"
default:
msg = "internal_server_error"
} }
return mustMarshalJson(struct { return Ok
Error string `json:"error"`
}{
Error: msg,
})
} }
type TokenRequest struct { func (trb *TokenRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) {
ResponseType string defer body.Close()
ClientId string return Ok
Code session.Code
} }
func (tr TokenRequest) Request() Request { func (trb *TokenRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) {
return &tr err := trb.values.Values.Parse(values)
} if err != nil {
return ErrorBadRequest
func (tr *TokenRequest) Parse(data *url.Values) (err error) {
if !data.Has("ClientId") {
err = fmt.Errorf("missing required client_id paramater")
return
} }
tr.ClientId = data.Get("client_id") return Ok
if data.Has("code") {
tr.Code = session.Code(data.Get("code"))
}
if data.Has("response_type") {
tr.ResponseType = data.Get("response_type")
}
return
} }
type TokenResponse struct { type TokenResponse struct {
VerificationUri string `json:"verification_uri"` header struct {
UserCode session.Code `json:"user_code"` router.Header
DeviceCode session.Code `json:"device_code"` }
Interval int `json:"interval"` Body struct {
ExpiresIn int `json:"expires_in"` VerificationUri string `json:"verification_uri"`
UserCode session.Code `json:"user_code"`
DeviceCode session.Code `json:"device_code"`
Interval int `json:"interval"`
ExpiresIn int `json:"expires_in"`
}
} }
func (tr *TokenResponse) Response() []byte { func (tr TokenResponse) Response() router.Response {
return mustMarshalJson(tr) return &tr
}
func (tr *TokenResponse) Header() (header router.Header) {
if tr.header.Header == nil {
tr.header.Header.Parse(tr.header)
}
return tr.header.Header
}
func (tr *TokenResponse) Bytes() []byte {
return mustMarshalJson(tr.Body)
} }
type UnimplementedTokenServer struct{} type UnimplementedTokenServer struct{}
func (u UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {} func (UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {}
func (u UnimplementedTokenServer) Token(req *TokenRequest) (res *TokenResponse, errRes ErrorResponse) { func (UnimplementedTokenServer) Token(req *TokenRequestBuilder) (res *TokenResponse, errRes router.ErrorResponse) {
errRes = TokenErrorUnimplemented errRes = ErrorNotImplemented
return return
} }
type TokenServer interface { type TokenServer interface {
mustEmbedUnimplementedTokenServer() mustEmbedUnimplementedTokenServer()
Token(*TokenRequest) (*TokenResponse, ErrorResponse) Token(*TokenRequestBuilder) (*TokenResponse, router.ErrorResponse)
} }