Fix empty request
This commit is contained in:
parent
3e1cbec541
commit
b8c6465785
1
go.mod
1
go.mod
@ -3,7 +3,6 @@ module somehole.com/common/oauth2
|
|||||||
go 1.23.1
|
go 1.23.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
somehole.com/common/defaults v0.1.2
|
|
||||||
somehole.com/common/log v0.1.2
|
somehole.com/common/log v0.1.2
|
||||||
somehole.com/common/security v0.1.0
|
somehole.com/common/security v0.1.0
|
||||||
)
|
)
|
||||||
|
2
go.sum
2
go.sum
@ -4,8 +4,6 @@ golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIB
|
|||||||
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
||||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
somehole.com/common/defaults v0.1.2 h1:y523TtBEP1y415akn4DELk/7iNxNvawKDyKtQq9QdNk=
|
|
||||||
somehole.com/common/defaults v0.1.2/go.mod h1:jPc/GeCBxJHwPJK8UUh3phDxZ1L9KCd9xXEyGSnRhCo=
|
|
||||||
somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw=
|
somehole.com/common/log v0.1.2 h1:e3rHAKL4IR7K79oEB4eNsrUNTad25H25GpPYYBJcDcw=
|
||||||
somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
|
somehole.com/common/log v0.1.2/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
|
||||||
somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE=
|
somehole.com/common/security v0.1.0 h1:Mm4GvO+eV2/qxHSGF5vzcYziEj1SDsDt3c/Vs5+KMGE=
|
||||||
|
@ -24,7 +24,7 @@ func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
||||||
s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||||
callbackRequest, ok := req.(*server.CallbackRequest)
|
callbackRequest, ok := req.(*server.CallbackRequest)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
||||||
@ -35,7 +35,7 @@ func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
||||||
s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||||||
tokenRequest, ok := req.(*server.TokenRequest)
|
tokenRequest, ok := req.(*server.TokenRequest)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
||||||
|
@ -69,6 +69,10 @@ type CallbackRequest struct {
|
|||||||
Code session.Code
|
Code session.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cr CallbackRequest) Request() Request {
|
||||||
|
return &cr
|
||||||
|
}
|
||||||
|
|
||||||
func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
|
func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
|
||||||
if !data.Has("code") {
|
if !data.Has("code") {
|
||||||
err = fmt.Errorf("missing code paramater")
|
err = fmt.Errorf("missing code paramater")
|
||||||
|
@ -4,21 +4,17 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"somehole.com/common/defaults"
|
|
||||||
"somehole.com/common/log"
|
"somehole.com/common/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type server struct {
|
|
||||||
req Request
|
|
||||||
allowed []string
|
|
||||||
logger log.Logger
|
|
||||||
do func(req Request) (res Response, errRes ErrorResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request interface {
|
type Request interface {
|
||||||
Parse(*url.Values) error
|
Parse(*url.Values) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EmptyRequest interface {
|
||||||
|
Request() Request
|
||||||
|
}
|
||||||
|
|
||||||
type Response interface {
|
type Response interface {
|
||||||
HttpStatus() int
|
HttpStatus() int
|
||||||
Response() []byte
|
Response() []byte
|
||||||
@ -31,9 +27,16 @@ type ErrorResponse interface {
|
|||||||
String() string
|
String() string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(req Request, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
|
type server struct {
|
||||||
|
emptyReq EmptyRequest
|
||||||
|
allowed []string
|
||||||
|
logger log.Logger
|
||||||
|
do func(req Request) (res Response, errRes ErrorResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
|
||||||
return &server{
|
return &server{
|
||||||
req: req,
|
emptyReq: req,
|
||||||
allowed: allowed,
|
allowed: allowed,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
do: do,
|
do: do,
|
||||||
@ -45,13 +48,8 @@ type defaultError struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
if !defaults.Empty(srv.req) {
|
req := srv.emptyReq
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
request := req.Request()
|
||||||
w.Write(mustMarshalJson(&defaultError{Error: "internal_server_error"}))
|
|
||||||
srv.logger.Logf(log.LevelError, "expected empty server request template")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req := srv.req
|
|
||||||
var allowed bool
|
var allowed bool
|
||||||
for _, method := range srv.allowed {
|
for _, method := range srv.allowed {
|
||||||
if method == r.Method {
|
if method == r.Method {
|
||||||
@ -66,8 +64,8 @@ func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
r.ParseForm()
|
r.ParseForm()
|
||||||
req.Parse(&r.Form)
|
request.Parse(&r.Form)
|
||||||
res, errRes := srv.do(req)
|
res, errRes := srv.do(request)
|
||||||
if !errRes.Ok() {
|
if !errRes.Ok() {
|
||||||
w.WriteHeader(errRes.HttpStatus())
|
w.WriteHeader(errRes.HttpStatus())
|
||||||
w.Write(errRes.ErrorResponse())
|
w.Write(errRes.ErrorResponse())
|
||||||
|
@ -79,21 +79,27 @@ func (te TokenError) ErrorResponse() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
State session.State
|
ResponseType string
|
||||||
|
ClientId string
|
||||||
Code session.Code
|
Code session.Code
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tr TokenRequest) Request() Request {
|
||||||
|
return &tr
|
||||||
|
}
|
||||||
|
|
||||||
func (tr *TokenRequest) Parse(data *url.Values) (err error) {
|
func (tr *TokenRequest) Parse(data *url.Values) (err error) {
|
||||||
if !data.Has("code") {
|
if !data.Has("ClientId") {
|
||||||
err = fmt.Errorf("missing code paramater")
|
err = fmt.Errorf("missing required client_id paramater")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !data.Has("state") {
|
tr.ClientId = data.Get("client_id")
|
||||||
err = fmt.Errorf("missing state parameter")
|
if data.Has("code") {
|
||||||
return
|
|
||||||
}
|
|
||||||
tr.State = session.State(data.Get("state"))
|
|
||||||
tr.Code = session.Code(data.Get("code"))
|
tr.Code = session.Code(data.Get("code"))
|
||||||
|
}
|
||||||
|
if data.Has("response_type") {
|
||||||
|
tr.ResponseType = data.Get("response_type")
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,7 +108,8 @@ type TokenResponse struct {
|
|||||||
VerificationUri string `json:"verification_uri"`
|
VerificationUri string `json:"verification_uri"`
|
||||||
UserCode session.Code `json:"user_code"`
|
UserCode session.Code `json:"user_code"`
|
||||||
DeviceCode session.Code `json:"device_code"`
|
DeviceCode session.Code `json:"device_code"`
|
||||||
Interval uint8 `json:"interval"`
|
Interval int `json:"interval"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *TokenResponse) HttpStatus() (code int) {
|
func (tr *TokenResponse) HttpStatus() (code int) {
|
||||||
|
Loading…
Reference in New Issue
Block a user