73 lines
2.0 KiB
Go
73 lines
2.0 KiB
Go
|
package oauth2
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
|
||
|
"somehole.com/common/log"
|
||
|
"somehole.com/service/oauth2/server"
|
||
|
"somehole.com/service/router"
|
||
|
)
|
||
|
|
||
|
// Endpoints holds the route paths (for example "/callback" and "/token")
// under which the OAuth2 handlers of a Router are registered.
type Endpoints struct {
	// CallbackEndpoint is the path used by RegisterCallbackServer and
	// TokenEndpoint the path used by RegisterTokenServer.
	CallbackEndpoint, TokenEndpoint string
}
|
||
|
|
||
|
var defaultEndpoints = Endpoints{
|
||
|
CallbackEndpoint: "/callback",
|
||
|
TokenEndpoint: "/token",
|
||
|
}
|
||
|
|
||
|
type Router struct {
|
||
|
*router.Router
|
||
|
log.Logger
|
||
|
Endpoints
|
||
|
}
|
||
|
|
||
|
func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) {
|
||
|
var endpoints Endpoints
|
||
|
if r == nil {
|
||
|
r = router.NewRouter(nil, nil)
|
||
|
}
|
||
|
if len(optionalEndpoints) == 1 {
|
||
|
endpoints = optionalEndpoints[0]
|
||
|
} else {
|
||
|
endpoints = defaultEndpoints
|
||
|
}
|
||
|
r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint})
|
||
|
if logger == nil {
|
||
|
logger = log.NewPlainLogger()
|
||
|
}
|
||
|
return &Router{
|
||
|
Router: r,
|
||
|
Logger: logger,
|
||
|
Endpoints: endpoints,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RegisterCallbackServer mounts srv on ro.CallbackEndpoint.
//
// The handler passed to router.NewServer expects every request to arrive as
// a *server.CallbackRequestBuilder. It panics on any other type, because
// that means the router and the prototype below disagree (a programming
// error, not a client error).
func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) {
	pro := router.Prototype{
		PrototypeRequestBuilder: server.CallbackRequestBuilder{},
		PrototypeResponse:       server.CallbackResponse{},
	}

	// The handler is deliberately not named "server": that would shadow the
	// imported server package for the rest of this function.
	handler := func(req router.RequestBuilder, res router.Response) router.ErrorResponse {
		callbackReq, ok := req.(*server.CallbackRequestBuilder)
		if !ok {
			panic(fmt.Errorf("expected CallbackRequest, got %T", req))
		}
		// NOTE(review): the response returned by srv.Callback is dropped here.
		// The original code assigned it to the parameter res, which only
		// rebinds this closure's local copy; router.NewServer sees just the
		// returned error response. Confirm whether srv.Callback's result should
		// instead be copied into the res handed in by the router.
		_, errRes := srv.Callback(callbackReq)
		return errRes
	}

	ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, handler))
}
|
||
|
|
||
|
// RegisterTokenServer mounts srv on ro.TokenEndpoint.
//
// The handler passed to router.NewServer expects every request to arrive as
// a *server.TokenRequestBuilder. It panics on any other type, because that
// means the router and the prototype below disagree (a programming error,
// not a client error).
func (ro *Router) RegisterTokenServer(srv server.TokenServer) {
	pro := router.Prototype{
		PrototypeRequestBuilder: server.TokenRequestBuilder{},
		PrototypeResponse:       server.TokenResponse{},
	}

	// The handler is deliberately not named "server": that would shadow the
	// imported server package for the rest of this function.
	handler := func(req router.RequestBuilder, res router.Response) router.ErrorResponse {
		tokenReq, ok := req.(*server.TokenRequestBuilder)
		if !ok {
			panic(fmt.Errorf("expected TokenRequest, got %T", req))
		}
		// NOTE(review): the response returned by srv.Token is dropped here.
		// The original code assigned it to the parameter res, which only
		// rebinds this closure's local copy; router.NewServer sees just the
		// returned error response. Confirm whether srv.Token's result should
		// instead be copied into the res handed in by the router.
		_, errRes := srv.Token(tokenReq)
		return errRes
	}

	ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, handler))
}
|