package oauth2 import ( "fmt" "somehole.com/common/log" "somehole.com/service/oauth2/server" "somehole.com/service/router" ) type Endpoints struct { CallbackEndpoint string TokenEndpoint string } var defaultEndpoints = Endpoints{ CallbackEndpoint: "/callback", TokenEndpoint: "/token", } type Router struct { *router.Router log.Logger Endpoints } func NewRouter(r *router.Router, logger log.Logger, optionalEndpoints ...Endpoints) (ro *Router) { var endpoints Endpoints if r == nil { r = router.NewRouter(nil, nil) } if len(optionalEndpoints) == 1 { endpoints = optionalEndpoints[0] } else { endpoints = defaultEndpoints } r.AddRequiredRoutes([]string{endpoints.CallbackEndpoint, endpoints.TokenEndpoint}) if logger == nil { logger = log.NewPlainLogger() } return &Router{ Router: r, Logger: logger, Endpoints: endpoints, } } func (ro *Router) RegisterCallbackServer(srv server.CallbackServer) { pro := router.Prototype{PrototypeRequestBuilder: server.CallbackRequestBuilder{}, PrototypeResponse: server.CallbackResponse{}} server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { callbackRequest, ok := req.(*server.CallbackRequestBuilder) if !ok { panic(fmt.Errorf("expected CallbackRequest, got %T", req)) } res, errRes = srv.Callback(callbackRequest) return } ro.Register(ro.CallbackEndpoint, router.NewServer(pro, ro.Logger, server)) } func (ro *Router) RegisterTokenServer(srv server.TokenServer) { pro := router.Prototype{PrototypeRequestBuilder: server.TokenRequestBuilder{}, PrototypeResponse: server.TokenResponse{}} server := func(req router.RequestBuilder, res router.Response) (errRes router.ErrorResponse) { callbackRequest, ok := req.(*server.TokenRequestBuilder) if !ok { panic(fmt.Errorf("expected TokenRequest, got %T", req)) } res, errRes = srv.Token(callbackRequest) return } ro.Register(ro.TokenEndpoint, router.NewServer(pro, ro.Logger, server)) }