Fix empty request
This commit is contained in:
parent
3e1cbec541
commit
b8c6465785
1
go.mod
1
go.mod
@ -3,7 +3,6 @@ module somehole.com/common/oauth2
|
||||
go 1.23.1
|
||||
|
||||
require (
|
||||
somehole.com/common/defaults v0.1.2
|
||||
somehole.com/common/log v0.1.2
|
||||
somehole.com/common/security v0.1.0
|
||||
)
|
||||
|
2
go.sum
2
go.sum
@ -4,8 +4,6 @@ golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIB
|
||||
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
somehole.com/common/defaults v0.1.2 h1:y523TtBEP1y415akn4DELk/7iNxNvawKDyKtQq9QdNk=
|
||||
somehole.com/common/defaults v0.1.2/go.mod h1:jPc/GeCBxJHwPJK8UUh3phDxZ1L9KCd9xXEyGSnRhCo=
|
||||
somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw=
|
||||
somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
|
||||
somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE=
|
||||
|
@ -24,7 +24,7 @@ func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
|
||||
}
|
||||
|
||||
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
||||
s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||
s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||
callbackRequest, ok := req.(*server.CallbackRequest)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
||||
@ -35,7 +35,7 @@ func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
||||
}
|
||||
|
||||
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
||||
s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||
s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||
tokenRequest, ok := req.(*server.TokenRequest)
|
||||
if !ok {
|
||||
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
||||
|
@ -69,6 +69,10 @@ type CallbackRequest struct {
|
||||
Code session.Code
|
||||
}
|
||||
|
||||
func (cr CallbackRequest) Request() Request {
|
||||
return &cr
|
||||
}
|
||||
|
||||
func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
|
||||
if !data.Has("code") {
|
||||
err = fmt.Errorf("missing code paramater")
|
||||
|
@ -4,21 +4,17 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"somehole.com/common/defaults"
|
||||
"somehole.com/common/log"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
req Request
|
||||
allowed []string
|
||||
logger log.Logger
|
||||
do func(req Request) (res Response, errRes ErrorResponse)
|
||||
}
|
||||
|
||||
type Request interface {
|
||||
Parse(*url.Values) error
|
||||
}
|
||||
|
||||
type EmptyRequest interface {
|
||||
Request() Request
|
||||
}
|
||||
|
||||
type Response interface {
|
||||
HttpStatus() int
|
||||
Response() []byte
|
||||
@ -31,12 +27,19 @@ type ErrorResponse interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
func NewServer(req Request, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
|
||||
type server struct {
|
||||
emptyReq EmptyRequest
|
||||
allowed []string
|
||||
logger log.Logger
|
||||
do func(req Request) (res Response, errRes ErrorResponse)
|
||||
}
|
||||
|
||||
func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
|
||||
return &server{
|
||||
req: req,
|
||||
allowed: allowed,
|
||||
logger: logger,
|
||||
do: do,
|
||||
emptyReq: req,
|
||||
allowed: allowed,
|
||||
logger: logger,
|
||||
do: do,
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,13 +48,8 @@ type defaultError struct {
|
||||
}
|
||||
|
||||
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !defaults.Empty(srv.req) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write(mustMarshalJson(&defaultError{Error: "internal_server_error"}))
|
||||
srv.logger.Logf(log.LevelError, "expected empty server request template")
|
||||
return
|
||||
}
|
||||
req := srv.req
|
||||
req := srv.emptyReq
|
||||
request := req.Request()
|
||||
var allowed bool
|
||||
for _, method := range srv.allowed {
|
||||
if method == r.Method {
|
||||
@ -66,8 +64,8 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
req.Parse(&r.Form)
|
||||
res, errRes := srv.do(req)
|
||||
request.Parse(&r.Form)
|
||||
res, errRes := srv.do(request)
|
||||
if !errRes.Ok() {
|
||||
w.WriteHeader(errRes.HttpStatus())
|
||||
w.Write(errRes.ErrorResponse())
|
||||
|
@ -79,21 +79,27 @@ func (te TokenError) ErrorResponse() []byte {
|
||||
}
|
||||
|
||||
type TokenRequest struct {
|
||||
State session.State
|
||||
Code session.Code
|
||||
ResponseType string
|
||||
ClientId string
|
||||
Code session.Code
|
||||
}
|
||||
|
||||
func (tr TokenRequest) Request() Request {
|
||||
return &tr
|
||||
}
|
||||
|
||||
func (tr *TokenRequest) Parse(data *url.Values) (err error) {
|
||||
if !data.Has("code") {
|
||||
err = fmt.Errorf("missing code paramater")
|
||||
if !data.Has("ClientId") {
|
||||
err = fmt.Errorf("missing required client_id paramater")
|
||||
return
|
||||
}
|
||||
if !data.Has("state") {
|
||||
err = fmt.Errorf("missing state parameter")
|
||||
return
|
||||
tr.ClientId = data.Get("client_id")
|
||||
if data.Has("code") {
|
||||
tr.Code = session.Code(data.Get("code"))
|
||||
}
|
||||
if data.Has("response_type") {
|
||||
tr.ResponseType = data.Get("response_type")
|
||||
}
|
||||
tr.State = session.State(data.Get("state"))
|
||||
tr.Code = session.Code(data.Get("code"))
|
||||
return
|
||||
}
|
||||
|
||||
@ -102,7 +108,8 @@ type TokenResponse struct {
|
||||
VerificationUri string `json:"verification_uri"`
|
||||
UserCode session.Code `json:"user_code"`
|
||||
DeviceCode session.Code `json:"device_code"`
|
||||
Interval uint8 `json:"interval"`
|
||||
Interval int `json:"interval"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
func (tr *TokenResponse) HttpStatus() (code int) {
|
||||
|
Loading…
Reference in New Issue
Block a user