53 lines
1.4 KiB
Go
53 lines
1.4 KiB
Go
package oauth2
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"somehole.com/common/log"
|
|
"somehole.com/common/oauth2/server"
|
|
)
|
|
|
|
// Server wraps an http.ServeMux so that OAuth2 endpoints registered through
// it (see RegisterCallbackServer and RegisterTokenServer) are mounted under a
// common base path.
type Server struct {
	// ServeMux is the underlying router. It is embedded, so its methods
	// (including ServeHTTP and HandleFunc) are promoted onto Server; Server
	// shadows only Handle.
	*http.ServeMux

	// basePath is the prefix that Handle combines with every pattern it
	// registers.
	basePath string

	// Logger is passed to server.NewServer by the Register* methods.
	Logger log.Logger
}
|
|
|
|
func NewServer(mux *http.ServeMux, basePath string, logger log.Logger) *Server {
|
|
if mux == nil {
|
|
mux = http.NewServeMux()
|
|
}
|
|
return &Server{
|
|
ServeMux: mux,
|
|
basePath: basePath,
|
|
Logger: logger,
|
|
}
|
|
}
|
|
|
|
func (s *Server) Handle(pattern string, handler http.Handler) {
|
|
s.ServeMux.Handle(s.basePath+pattern, handler)
|
|
}
|
|
|
|
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
|
s.Handle(server.CallbackEndpoint, server.NewServer(server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
|
callbackRequest, ok := req.(*server.CallbackRequest)
|
|
if !ok {
|
|
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
|
}
|
|
res, errRes = srv.Callback(callbackRequest)
|
|
return
|
|
}))
|
|
}
|
|
|
|
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
|
s.Handle(server.TokenEndpoint, server.NewServer(server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
|
tokenRequest, ok := req.(*server.TokenRequest)
|
|
if !ok {
|
|
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
|
}
|
|
res, errRes = srv.Token(tokenRequest)
|
|
return
|
|
}))
|
|
}
|