Fix empty request

This commit is contained in:
some 2024-10-01 18:57:46 -04:00
parent 3e1cbec541
commit b8c6465785
Signed by: some
GPG Key ID: 65D0589220B9BFC8
6 changed files with 43 additions and 37 deletions

1
go.mod
View File

@ -3,7 +3,6 @@ module somehole.com/common/oauth2
go 1.23.1 go 1.23.1
require ( require (
somehole.com/common/defaults v0.1.2
somehole.com/common/log v0.1.2 somehole.com/common/log v0.1.2
somehole.com/common/security v0.1.0 somehole.com/common/security v0.1.0
) )

2
go.sum
View File

@ -4,8 +4,6 @@ golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIB
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
somehole.com/common/defaults v0.1.2 h1:y523TtBEP1y415akn4DELk/7iNxNvawKDyKtQq9QdNk=
somehole.com/common/defaults v0.1.2/go.mod h1:jPc/GeCBxJHwPJK8UUh3phDxZ1L9KCd9xXEyGSnRhCo=
somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw= somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw=
somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA= somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE= somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE=

View File

@ -24,7 +24,7 @@ func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
} }
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) { func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
callbackRequest, ok := req.(*server.CallbackRequest) callbackRequest, ok := req.(*server.CallbackRequest)
if !ok { if !ok {
panic(fmt.Errorf("expected CallbackRequest, got %T", req)) panic(fmt.Errorf("expected CallbackRequest, got %T", req))
@ -35,7 +35,7 @@ func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
} }
func (s *Server) RegisterTokenServer(srv server.TokenServer) { func (s *Server) RegisterTokenServer(srv server.TokenServer) {
s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) { s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
tokenRequest, ok := req.(*server.TokenRequest) tokenRequest, ok := req.(*server.TokenRequest)
if !ok { if !ok {
panic(fmt.Errorf("expected TokenRequest, got %T", req)) panic(fmt.Errorf("expected TokenRequest, got %T", req))

View File

@ -69,6 +69,10 @@ type CallbackRequest struct {
Code session.Code Code session.Code
} }
func (cr CallbackRequest) Request() Request {
return &cr
}
func (cr *CallbackRequest) Parse(data *url.Values) (err error) { func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
if !data.Has("code") { if !data.Has("code") {
err = fmt.Errorf("missing code paramater") err = fmt.Errorf("missing code paramater")

View File

@ -4,21 +4,17 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"somehole.com/common/defaults"
"somehole.com/common/log" "somehole.com/common/log"
) )
type server struct {
req Request
allowed []string
logger log.Logger
do func(req Request) (res Response, errRes ErrorResponse)
}
type Request interface { type Request interface {
Parse(*url.Values) error Parse(*url.Values) error
} }
type EmptyRequest interface {
Request() Request
}
type Response interface { type Response interface {
HttpStatus() int HttpStatus() int
Response() []byte Response() []byte
@ -31,9 +27,16 @@ type ErrorResponse interface {
String() string String() string
} }
func NewServer(req Request, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server { type server struct {
emptyReq EmptyRequest
allowed []string
logger log.Logger
do func(req Request) (res Response, errRes ErrorResponse)
}
func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
return &server{ return &server{
req: req, emptyReq: req,
allowed: allowed, allowed: allowed,
logger: logger, logger: logger,
do: do, do: do,
@ -45,13 +48,8 @@ type defaultError struct {
} }
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !defaults.Empty(srv.req) { req := srv.emptyReq
w.WriteHeader(http.StatusInternalServerError) request := req.Request()
w.Write(mustMarshalJson(&defaultError{Error: "internal_server_error"}))
srv.logger.Logf(log.LevelError, "expected empty server request template")
return
}
req := srv.req
var allowed bool var allowed bool
for _, method := range srv.allowed { for _, method := range srv.allowed {
if method == r.Method { if method == r.Method {
@ -66,8 +64,8 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
r.ParseForm() r.ParseForm()
req.Parse(&r.Form) request.Parse(&r.Form)
res, errRes := srv.do(req) res, errRes := srv.do(request)
if !errRes.Ok() { if !errRes.Ok() {
w.WriteHeader(errRes.HttpStatus()) w.WriteHeader(errRes.HttpStatus())
w.Write(errRes.ErrorResponse()) w.Write(errRes.ErrorResponse())

View File

@ -79,21 +79,27 @@ func (te TokenError) ErrorResponse() []byte {
} }
type TokenRequest struct { type TokenRequest struct {
State session.State ResponseType string
ClientId string
Code session.Code Code session.Code
} }
func (tr TokenRequest) Request() Request {
return &tr
}
func (tr *TokenRequest) Parse(data *url.Values) (err error) { func (tr *TokenRequest) Parse(data *url.Values) (err error) {
if !data.Has("code") { if !data.Has("ClientId") {
err = fmt.Errorf("missing code paramater") err = fmt.Errorf("missing required client_id paramater")
return return
} }
if !data.Has("state") { tr.ClientId = data.Get("client_id")
err = fmt.Errorf("missing state parameter") if data.Has("code") {
return
}
tr.State = session.State(data.Get("state"))
tr.Code = session.Code(data.Get("code")) tr.Code = session.Code(data.Get("code"))
}
if data.Has("response_type") {
tr.ResponseType = data.Get("response_type")
}
return return
} }
@ -102,7 +108,8 @@ type TokenResponse struct {
VerificationUri string `json:"verification_uri"` VerificationUri string `json:"verification_uri"`
UserCode session.Code `json:"user_code"` UserCode session.Code `json:"user_code"`
DeviceCode session.Code `json:"device_code"` DeviceCode session.Code `json:"device_code"`
Interval uint8 `json:"interval"` Interval int `json:"interval"`
ExpiresIn int `json:"expires_in"`
} }
func (tr *TokenResponse) HttpStatus() (code int) { func (tr *TokenResponse) HttpStatus() (code int) {