oauth2/router.go

73 lines
2.0 KiB
Go
Raw Normal View History

2024-10-09 00:33:23 +00:00
package oauth2
import (
"fmt"
"somehole.com/common/log"
"somehole.com/service/oauth2/server"
"somehole.com/service/router"
)
// Endpoints holds the route paths under which the OAuth2 handlers of this
// package are registered.
type Endpoints struct {
	// CallbackEndpoint is the path used by RegisterCallbackServer.
	CallbackEndpoint string
	// TokenEndpoint is the path used by RegisterTokenServer.
	TokenEndpoint string
}
// defaultEndpoints is the fallback path set NewRouter uses when the caller
// does not provide its own Endpoints.
var defaultEndpoints = Endpoints{CallbackEndpoint: "/callback", TokenEndpoint: "/token"}
// Router combines an underlying *router.Router with a logger and the endpoint
// paths used when registering the OAuth2 handlers. All three are embedded, so
// their fields and methods are promoted onto Router.
type Router struct {
	// Router is the underlying router that handlers are registered on.
	*router.Router

	// Logger is handed to every server created by the Register* methods.
	log.Logger

	// Endpoints supplies the paths for the callback and token handlers.
	Endpoints
}
// NewRouter wraps r in an OAuth2 Router.
//
// Parameters:
//   - r: the underlying router; when nil, one is created with
//     router.NewRouter(nil, nil).
//   - logger: the logger stored on the result; when nil, log.NewPlainLogger()
//     is used.
//   - optionalEndpoints: optional custom endpoint paths. The first value, if
//     any, is used; with no value the package defaults ("/callback" and
//     "/token") apply. Additional values are ignored.
//
// Both endpoint paths are passed to r.AddRequiredRoutes before the Router is
// returned.
func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) {
	if r == nil {
		r = router.NewRouter(nil, nil)
	}

	// Honour the caller's endpoints whenever at least one value is supplied.
	// Testing for exactly one value would silently fall back to the defaults
	// if a caller passed more than one, discarding explicit configuration.
	endpoints := defaultEndpoints
	if len(optionalEndpoints) > 0 {
		endpoints = optionalEndpoints[0]
	}

	r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint})

	if logger == nil {
		logger = log.NewPlainLogger()
	}

	return &Router{
		Router:    r,
		Logger:    logger,
		Endpoints: endpoints,
	}
}
// RegisterCallbackServer registers srv under ro.CallbackEndpoint.
//
// The handler given to router.NewServer expects its request to be a
// *server.CallbackRequestBuilder and panics for any other type. It forwards
// the request to srv.Callback and returns the resulting error response.
func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) {
	proto := router.Prototype{
		PrototypeRequestBuilder: server.CallbackRequestBuilder{},
		PrototypeResponse:       server.CallbackResponse{},
	}

	// Named "handle" rather than "server" so the imported package server is
	// not shadowed for the rest of this function.
	handle := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
		cbReq, ok := req.(*server.CallbackRequestBuilder)
		if !ok {
			panic(fmt.Errorf("expected CallbackRequest, got %T", req))
		}
		// NOTE(review): res is a by-value parameter, so this assignment only
		// updates the closure's local copy — confirm how router.NewServer is
		// expected to receive the response.
		res, errRes = srv.Callback(cbReq)
		return errRes
	}

	ro.Register(ro.CallbackEndpoint, router.NewServer(proto, ro.Logger, handle))
}
// RegisterTokenServer registers srv under ro.TokenEndpoint.
//
// The handler given to router.NewServer expects its request to be a
// *server.TokenRequestBuilder and panics for any other type. It forwards the
// request to srv.Token and returns the resulting error response.
func (ro *Router) RegisterTokenServer(srv server.TokenServer) {
	proto := router.Prototype{
		PrototypeRequestBuilder: server.TokenRequestBuilder{},
		PrototypeResponse:       server.TokenResponse{},
	}

	// Named "handle" rather than "server" so the imported package server is
	// not shadowed for the rest of this function.
	handle := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) {
		tokenReq, ok := req.(*server.TokenRequestBuilder)
		if !ok {
			panic(fmt.Errorf("expected TokenRequest, got %T", req))
		}
		// NOTE(review): res is a by-value parameter, so this assignment only
		// updates the closure's local copy — confirm how router.NewServer is
		// expected to receive the response.
		res, errRes = srv.Token(tokenReq)
		return errRes
	}

	ro.Register(ro.TokenEndpoint, router.NewServer(proto, ro.Logger, handle))
}