Move to using router package
This commit is contained in:
parent
333b70522f
commit
0ea4adb96c
@ -5,7 +5,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"somehole.com/common/oauth2/session"
|
"somehole.com/service/oauth2/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AuthorizationUrl struct {
|
type AuthorizationUrl struct {
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"somehole.com/common/oauth2/session"
|
|
||||||
"somehole.com/common/security/signature"
|
"somehole.com/common/security/signature"
|
||||||
|
"somehole.com/service/oauth2/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"somehole.com/common/oauth2/session"
|
"somehole.com/service/oauth2/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenRevokationUrl struct {
|
type TokenRevokationUrl struct {
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"somehole.com/common/oauth2/session"
|
"somehole.com/service/oauth2/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TokenUrl struct {
|
type TokenUrl struct {
|
||||||
|
1
go.mod
1
go.mod
@ -5,6 +5,7 @@ go 1.23.1
|
|||||||
require (
|
require (
|
||||||
somehole.com/common/log v0.1.3
|
somehole.com/common/log v0.1.3
|
||||||
somehole.com/common/security v0.2.1
|
somehole.com/common/security v0.2.1
|
||||||
|
somehole.com/service/router v0.6.4
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
6
go.sum
6
go.sum
@ -1,14 +1,12 @@
|
|||||||
github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY=
|
github.com/cloudflare/circl v1.4.0 h1:BV7h5MgrktNzytKmWjpOtdYrf0lkkbF8YMlBGPhJQrY=
|
||||||
github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU=
|
github.com/cloudflare/circl v1.4.0/go.mod h1:PDRU+oXvdD7KCtgKxW95M5Z8BpSCJXQORiZFnBQS5QU=
|
||||||
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d h1:LiA25/KWKuXfIq5pMIBq1s5hz3HQxhJJSu/SUGlD+SM=
|
|
||||||
golang.org/x/crypto v0.11.1-0.20230711161743-2e82bdd1719d/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
|
|
||||||
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
|
||||||
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
|
||||||
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
|
|
||||||
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||||
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0=
|
somehole.com/common/log v0.1.3 h1:2PAui0+5EryTAHqVUZQeepLcJTGssHKz2OL+jBknHI0=
|
||||||
somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
|
somehole.com/common/log v0.1.3/go.mod h1:NS2eHnN120GA6oFbBm3XhB5yHww0eXTbLuMQYZxYNyA=
|
||||||
somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA=
|
somehole.com/common/security v0.2.1 h1:H7TpErYKsCOxKYYie75jU7azO5vEgeqrQ9uA+AjXafA=
|
||||||
somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM=
|
somehole.com/common/security v0.2.1/go.mod h1:6SMKdIrfxT460XXWafpD6A9yT/ykwipGsGztA5g5vVM=
|
||||||
|
somehole.com/service/router v0.6.4 h1:tA734uvqQqG0LX8UZa7hjmHVMN6LBbNGJraw1CH7kDQ=
|
||||||
|
somehole.com/service/router v0.6.4/go.mod h1:F+/sNY4/ei7C9QD5+iA+gUHugcOAcrWspHKwX8d8oiM=
|
||||||
|
72
router.go
Normal file
72
router.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"somehole.com/common/log"
|
||||||
|
"somehole.com/service/oauth2/server"
|
||||||
|
"somehole.com/service/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Endpoints struct {
|
||||||
|
CallbackEndpoint string
|
||||||
|
TokenEndpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultEndpoints = Endpoints{
|
||||||
|
CallbackEndpoint: "/callback",
|
||||||
|
TokenEndpoint: "/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Router struct {
|
||||||
|
*router.Router
|
||||||
|
log.Logger
|
||||||
|
Endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) {
|
||||||
|
var endpoints Endpoints
|
||||||
|
if r == nil {
|
||||||
|
r = router.NewRouter(nil, nil)
|
||||||
|
}
|
||||||
|
if len(optionalEndpoints) == 1 {
|
||||||
|
endpoints = optionalEndpoints[0]
|
||||||
|
} else {
|
||||||
|
endpoints = defaultEndpoints
|
||||||
|
}
|
||||||
|
r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint})
|
||||||
|
if logger == nil {
|
||||||
|
logger = log.NewPlainLogger()
|
||||||
|
}
|
||||||
|
return &Router{
|
||||||
|
Router: r,
|
||||||
|
Logger: logger,
|
||||||
|
Endpoints: endpoints,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) {
|
||||||
|
pro := router.Prototype{PrototypeRequestBuilder: server.CallbackRequestBuilder{}, PrototypeResponse: server.CallbackResponse{}}
|
||||||
|
server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
|
||||||
|
callbackRequest, ok := req.(*server.CallbackRequestBuilder)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
||||||
|
}
|
||||||
|
res, errRes = srv.Callback(callbackRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, server))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ro *Router) RegisterTokenServer(srv server.TokenServer) {
|
||||||
|
pro := router.Prototype{PrototypeRequestBuilder: server.TokenRequestBuilder{}, PrototypeResponse: server.TokenResponse{}}
|
||||||
|
server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
|
||||||
|
callbackRequest, ok := req.(*server.TokenRequestBuilder)
|
||||||
|
if !ok {
|
||||||
|
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
||||||
|
}
|
||||||
|
res, errRes = srv.Token(callbackRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, server))
|
||||||
|
}
|
50
server.go
50
server.go
@ -1,50 +0,0 @@
|
|||||||
package oauth2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"somehole.com/common/log"
|
|
||||||
"somehole.com/common/oauth2/server"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Server struct {
|
|
||||||
*http.ServeMux
|
|
||||||
Logger log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
|
|
||||||
if mux == nil {
|
|
||||||
mux = http.NewServeMux()
|
|
||||||
}
|
|
||||||
return &Server{
|
|
||||||
ServeMux: mux,
|
|
||||||
Logger: logger,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) Handle(pattern string, handler http.Handler) {
|
|
||||||
s.ServeMux.Handle(pattern, handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
|
||||||
s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
|
||||||
callbackRequest, ok := req.(*server.CallbackRequest)
|
|
||||||
if !ok {
|
|
||||||
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
|
||||||
}
|
|
||||||
res, errRes = srv.Callback(callbackRequest)
|
|
||||||
return
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
|
||||||
s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
|
||||||
tokenRequest, ok := req.(*server.TokenRequest)
|
|
||||||
if !ok {
|
|
||||||
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
|
||||||
}
|
|
||||||
res, errRes = srv.Token(tokenRequest)
|
|
||||||
return
|
|
||||||
}))
|
|
||||||
}
|
|
@ -1,114 +1,96 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"somehole.com/common/oauth2/session"
|
"somehole.com/service/oauth2/session"
|
||||||
|
"somehole.com/service/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
const CallbackEndpoint = "/callback"
|
type CallbackRequestBuilder struct {
|
||||||
|
allowedMethods []string
|
||||||
type CallbackError uint32
|
header struct {
|
||||||
|
router.Header
|
||||||
const (
|
|
||||||
CallbackOk CallbackError = iota
|
|
||||||
CallbackErrorUnimplemented
|
|
||||||
CallbackErrorUnauthorized
|
|
||||||
CallbackErrorServerError
|
|
||||||
)
|
|
||||||
|
|
||||||
func (ce CallbackError) Ok() bool {
|
|
||||||
return ce == CallbackOk
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ce CallbackError) HttpStatus() (code int) {
|
|
||||||
switch ce {
|
|
||||||
case CallbackOk:
|
|
||||||
code = http.StatusOK
|
|
||||||
case CallbackErrorUnimplemented:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
case CallbackErrorUnauthorized:
|
|
||||||
code = http.StatusUnauthorized
|
|
||||||
case CallbackErrorServerError:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
default:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
}
|
}
|
||||||
return
|
values struct {
|
||||||
}
|
router.Values
|
||||||
|
State session.State `form:"state"`
|
||||||
func (ce CallbackError) String() (out string) {
|
Code session.Code `form:"code"`
|
||||||
switch ce {
|
|
||||||
case CallbackOk:
|
|
||||||
out = "authenticated"
|
|
||||||
case CallbackErrorUnimplemented:
|
|
||||||
out = "callback server unimplemented"
|
|
||||||
case CallbackErrorUnauthorized:
|
|
||||||
out = "user unauthorized"
|
|
||||||
case CallbackErrorServerError:
|
|
||||||
out = "internal server error"
|
|
||||||
default:
|
|
||||||
out = "unhandled error"
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ce CallbackError) ErrorResponse() []byte {
|
func (crb CallbackRequestBuilder) RequestBuilder() router.RequestBuilder {
|
||||||
var msg string
|
return &crb
|
||||||
switch ce {
|
}
|
||||||
default:
|
|
||||||
msg = "internal_server_error"
|
func (crb *CallbackRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) {
|
||||||
|
var ok bool
|
||||||
|
for _, m := range crb.allowedMethods {
|
||||||
|
if m == method {
|
||||||
|
ok = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return mustMarshalJson(struct {
|
if !ok {
|
||||||
Error string `json:"error"`
|
return ErrorMethodNotAllowed
|
||||||
}{
|
|
||||||
Error: msg,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallbackRequest struct {
|
|
||||||
State session.State
|
|
||||||
Code session.Code
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr CallbackRequest) Request() Request {
|
|
||||||
return &cr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *CallbackRequest) Parse(data *url.Values) (err error) {
|
|
||||||
if !data.Has("code") {
|
|
||||||
err = fmt.Errorf("missing code paramater")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if !data.Has("state") {
|
return Ok
|
||||||
err = fmt.Errorf("missing state parameter")
|
}
|
||||||
return
|
|
||||||
|
func (crb *CallbackRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) {
|
||||||
|
err := crb.header.Header.Parse(header)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorBadRequest
|
||||||
}
|
}
|
||||||
cr.State = session.State(data.Get("state"))
|
return Ok
|
||||||
cr.Code = session.Code(data.Get("code"))
|
}
|
||||||
return
|
|
||||||
|
func (crb *CallbackRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) {
|
||||||
|
defer body.Close()
|
||||||
|
return Ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (crb *CallbackRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) {
|
||||||
|
err := crb.values.Values.Parse(values)
|
||||||
|
if err != nil {
|
||||||
|
return ErrorBadRequest
|
||||||
|
}
|
||||||
|
return Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackResponse struct {
|
type CallbackResponse struct {
|
||||||
Message string `json:"message"`
|
header struct {
|
||||||
|
router.Header
|
||||||
|
}
|
||||||
|
Body struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CallbackResponse) Response() []byte {
|
func (cr CallbackResponse) Response() router.Response {
|
||||||
return mustMarshalJson(cr)
|
return &cr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CallbackResponse) Header() (header router.Header) {
|
||||||
|
if cr.header.Header == nil {
|
||||||
|
cr.header.Header.Parse(cr.header)
|
||||||
|
}
|
||||||
|
return cr.header.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cr *CallbackResponse) Bytes() []byte {
|
||||||
|
return mustMarshalJson(cr.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UnimplementedCallbackServer struct{}
|
type UnimplementedCallbackServer struct{}
|
||||||
|
|
||||||
func (u UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {}
|
func (UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {}
|
||||||
|
|
||||||
func (u UnimplementedCallbackServer) Callback(req *CallbackRequest) (res *CallbackResponse, errRes ErrorResponse) {
|
func (UnimplementedCallbackServer) Callback(req *CallbackRequestBuilder) (res *CallbackResponse, errRes router.ErrorResponse) {
|
||||||
errRes = CallbackErrorUnimplemented
|
errRes = ErrorNotImplemented
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackServer interface {
|
type CallbackServer interface {
|
||||||
mustEmbedUnimplementedCallbackServer()
|
mustEmbedUnimplementedCallbackServer()
|
||||||
Callback(*CallbackRequest) (*CallbackResponse, ErrorResponse)
|
Callback(*CallbackRequestBuilder) (*CallbackResponse, router.ErrorResponse)
|
||||||
}
|
}
|
||||||
|
63
server/error.go
Normal file
63
server/error.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Error uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
Ok Error = http.StatusOK
|
||||||
|
ErrorNotImplemented Error = http.StatusNotImplemented
|
||||||
|
ErrorMethodNotAllowed Error = http.StatusMethodNotAllowed
|
||||||
|
ErrorBadRequest Error = http.StatusBadRequest
|
||||||
|
ErrorUnauthorized Error = http.StatusUnauthorized
|
||||||
|
ErrorServerError Error = http.StatusInternalServerError
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e Error) Ok() (ok bool) {
|
||||||
|
return e == Ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Error) HttpStatus() (code int) {
|
||||||
|
return int(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Error) String() (out string) {
|
||||||
|
switch e {
|
||||||
|
case Ok:
|
||||||
|
out = "ok"
|
||||||
|
case ErrorNotImplemented:
|
||||||
|
out = "server not implemented"
|
||||||
|
case ErrorMethodNotAllowed:
|
||||||
|
out = "method not allowed"
|
||||||
|
case ErrorBadRequest:
|
||||||
|
out = "bad request"
|
||||||
|
case ErrorUnauthorized:
|
||||||
|
out = "user unauthorized"
|
||||||
|
case ErrorServerError:
|
||||||
|
out = "internal server error"
|
||||||
|
default:
|
||||||
|
out = "unhandled error"
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Error) Error() (out string) {
|
||||||
|
return fmt.Sprintf("%s (%d)", e.String(), e.HttpStatus())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e Error) Bytes() (body []byte) {
|
||||||
|
var msg string
|
||||||
|
switch e {
|
||||||
|
default:
|
||||||
|
msg = strings.Join(strings.Split(e.String(), " "), "_")
|
||||||
|
}
|
||||||
|
return mustMarshalJson(struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}{
|
||||||
|
Error: msg,
|
||||||
|
})
|
||||||
|
}
|
@ -1,75 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"somehole.com/common/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Request interface {
|
|
||||||
Parse(*url.Values) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type EmptyRequest interface {
|
|
||||||
Request() Request
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response interface {
|
|
||||||
Response() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrorResponse interface {
|
|
||||||
Ok() bool
|
|
||||||
HttpStatus() int
|
|
||||||
ErrorResponse() []byte
|
|
||||||
String() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type server struct {
|
|
||||||
emptyReq EmptyRequest
|
|
||||||
allowed []string
|
|
||||||
logger log.Logger
|
|
||||||
do func(req Request) (res Response, errRes ErrorResponse)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServer(req EmptyRequest, allowed []string, logger log.Logger, do func(req Request) (res Response, errRes ErrorResponse)) *server {
|
|
||||||
return &server{
|
|
||||||
emptyReq: req,
|
|
||||||
allowed: allowed,
|
|
||||||
logger: logger,
|
|
||||||
do: do,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type defaultError struct {
|
|
||||||
Error string `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
req := srv.emptyReq
|
|
||||||
request := req.Request()
|
|
||||||
var allowed bool
|
|
||||||
for _, method := range srv.allowed {
|
|
||||||
if method == r.Method {
|
|
||||||
allowed = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !allowed {
|
|
||||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
||||||
w.Write(mustMarshalJson(&defaultError{Error: "method_not_allowed"}))
|
|
||||||
srv.logger.Logf(log.LevelError, "requested method (%s) not one of %v", r.Method, srv.allowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.ParseForm()
|
|
||||||
request.Parse(&r.Form)
|
|
||||||
res, errRes := srv.do(request)
|
|
||||||
if !errRes.Ok() {
|
|
||||||
w.WriteHeader(errRes.HttpStatus())
|
|
||||||
w.Write(errRes.ErrorResponse())
|
|
||||||
srv.logger.Logf(log.LevelError, "request failed: %s", errRes.String())
|
|
||||||
}
|
|
||||||
w.WriteHeader(errRes.HttpStatus())
|
|
||||||
w.Write(res.Response())
|
|
||||||
}
|
|
165
server/token.go
165
server/token.go
@ -1,134 +1,101 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"io"
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"somehole.com/common/oauth2/session"
|
"somehole.com/service/oauth2/session"
|
||||||
|
"somehole.com/service/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
const TokenEndpoint = "/token"
|
type TokenRequestBuilder struct {
|
||||||
|
allowedMethods []string
|
||||||
type TokenError uint32
|
header struct {
|
||||||
|
router.Header
|
||||||
const (
|
}
|
||||||
TokenOk TokenError = iota
|
values struct {
|
||||||
TokenErrorUnimplemented
|
router.Values
|
||||||
TokenErrorUnauthorized
|
ResponseType string `form:"response_type"`
|
||||||
TokenErrorServerError
|
ClientId string `form:"client_id"`
|
||||||
TokenErrorSlowDown
|
Code session.Code `form:"code"`
|
||||||
TokenErrorPending
|
}
|
||||||
)
|
|
||||||
|
|
||||||
func (te TokenError) Ok() bool {
|
|
||||||
return te == TokenOk
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (te TokenError) HttpStatus() (code int) {
|
func (trb TokenRequestBuilder) RequestBuilder() router.RequestBuilder {
|
||||||
switch te {
|
return &trb
|
||||||
case TokenOk:
|
|
||||||
code = http.StatusOK
|
|
||||||
case TokenErrorUnimplemented:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
case TokenErrorUnauthorized:
|
|
||||||
code = http.StatusUnauthorized
|
|
||||||
case TokenErrorServerError:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
case TokenErrorSlowDown:
|
|
||||||
code = http.StatusBadRequest
|
|
||||||
case TokenErrorPending:
|
|
||||||
code = http.StatusBadRequest
|
|
||||||
default:
|
|
||||||
code = http.StatusInternalServerError
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (te TokenError) String() (out string) {
|
func (trb *TokenRequestBuilder) Allowed(method string) (errRes router.ErrorResponse) {
|
||||||
switch te {
|
var ok bool
|
||||||
case TokenOk:
|
for _, m := range trb.allowedMethods {
|
||||||
out = "authenticated"
|
if m == method {
|
||||||
case TokenErrorUnimplemented:
|
ok = true
|
||||||
out = "token server unimplemented"
|
}
|
||||||
case TokenErrorUnauthorized:
|
|
||||||
out = "user unauthorized"
|
|
||||||
case TokenErrorServerError:
|
|
||||||
out = "internal server error"
|
|
||||||
case TokenErrorSlowDown:
|
|
||||||
out = "slow down"
|
|
||||||
case TokenErrorPending:
|
|
||||||
out = "authorization pending"
|
|
||||||
default:
|
|
||||||
out = "unhandled error"
|
|
||||||
}
|
}
|
||||||
return
|
if !ok {
|
||||||
|
return ErrorMethodNotAllowed
|
||||||
|
}
|
||||||
|
return Ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (te TokenError) ErrorResponse() []byte {
|
func (trb *TokenRequestBuilder) Header(header router.Header) (errRes router.ErrorResponse) {
|
||||||
var msg string
|
err := trb.header.Header.Parse(header)
|
||||||
switch te {
|
if err != nil {
|
||||||
case TokenErrorSlowDown:
|
return ErrorBadRequest
|
||||||
msg = "slow_down"
|
|
||||||
case TokenErrorPending:
|
|
||||||
msg = "authorization_pending"
|
|
||||||
default:
|
|
||||||
msg = "internal_server_error"
|
|
||||||
}
|
}
|
||||||
return mustMarshalJson(struct {
|
return Ok
|
||||||
Error string `json:"error"`
|
|
||||||
}{
|
|
||||||
Error: msg,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
func (trb *TokenRequestBuilder) Body(body io.ReadCloser) (errRes router.ErrorResponse) {
|
||||||
ResponseType string
|
defer body.Close()
|
||||||
ClientId string
|
return Ok
|
||||||
Code session.Code
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr TokenRequest) Request() Request {
|
func (trb *TokenRequestBuilder) Values(values router.Values) (errRes router.ErrorResponse) {
|
||||||
return &tr
|
err := trb.values.Values.Parse(values)
|
||||||
}
|
if err != nil {
|
||||||
|
return ErrorBadRequest
|
||||||
func (tr *TokenRequest) Parse(data *url.Values) (err error) {
|
|
||||||
if !data.Has("ClientId") {
|
|
||||||
err = fmt.Errorf("missing required client_id paramater")
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
tr.ClientId = data.Get("client_id")
|
return Ok
|
||||||
if data.Has("code") {
|
|
||||||
tr.Code = session.Code(data.Get("code"))
|
|
||||||
}
|
|
||||||
if data.Has("response_type") {
|
|
||||||
tr.ResponseType = data.Get("response_type")
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
VerificationUri string `json:"verification_uri"`
|
header struct {
|
||||||
UserCode session.Code `json:"user_code"`
|
router.Header
|
||||||
DeviceCode session.Code `json:"device_code"`
|
}
|
||||||
Interval int `json:"interval"`
|
Body struct {
|
||||||
ExpiresIn int `json:"expires_in"`
|
VerificationUri string `json:"verification_uri"`
|
||||||
|
UserCode session.Code `json:"user_code"`
|
||||||
|
DeviceCode session.Code `json:"device_code"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tr *TokenResponse) Response() []byte {
|
func (tr TokenResponse) Response() router.Response {
|
||||||
return mustMarshalJson(tr)
|
return &tr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TokenResponse) Header() (header router.Header) {
|
||||||
|
if tr.header.Header == nil {
|
||||||
|
tr.header.Header.Parse(tr.header)
|
||||||
|
}
|
||||||
|
return tr.header.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *TokenResponse) Bytes() []byte {
|
||||||
|
return mustMarshalJson(tr.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
type UnimplementedTokenServer struct{}
|
type UnimplementedTokenServer struct{}
|
||||||
|
|
||||||
func (u UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {}
|
func (UnimplementedTokenServer) mustEmbedUnimplementedTokenServer() {}
|
||||||
|
|
||||||
func (u UnimplementedTokenServer) Token(req *TokenRequest) (res *TokenResponse, errRes ErrorResponse) {
|
func (UnimplementedTokenServer) Token(req *TokenRequestBuilder) (res *TokenResponse, errRes router.ErrorResponse) {
|
||||||
errRes = TokenErrorUnimplemented
|
errRes = ErrorNotImplemented
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenServer interface {
|
type TokenServer interface {
|
||||||
mustEmbedUnimplementedTokenServer()
|
mustEmbedUnimplementedTokenServer()
|
||||||
Token(*TokenRequest) (*TokenResponse, ErrorResponse)
|
Token(*TokenRequestBuilder) (*TokenResponse, router.ErrorResponse)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user