oauth2/server.go
2024-10-02 13:45:42 -04:00

53 lines
1.4 KiB
Go

package oauth2
import (
"fmt"
"net/http"
"somehole.com/common/log"
"somehole.com/common/oauth2/server"
)
// Server groups OAuth2 HTTP endpoints on a single http.ServeMux and registers
// them under a shared base path.
//
// Because *http.ServeMux is embedded, its methods are promoted onto Server.
// Server.Handle shadows ServeMux.Handle and applies the base path; the other
// promoted registration methods (for example HandleFunc) go straight to the
// mux and do NOT apply the base path.
type Server struct {
// Embedded mux that receives every registration and serves requests.
*http.ServeMux
// basePath is prepended verbatim to each pattern passed to Handle.
basePath string
// Logger is handed to every endpoint server created by the Register* methods.
Logger log.Logger
}
// NewServer returns a Server that registers its handlers on mux, prefixing
// every pattern given to Handle with basePath.
//
// A nil mux is replaced by a freshly allocated http.ServeMux, so callers that
// do not need to share a mux may simply pass nil. basePath and logger are
// stored unchanged.
func NewServer(mux *http.ServeMux, basePath string, logger log.Logger) *Server {
	srv := &Server{
		basePath: basePath,
		Logger:   logger,
	}
	if mux != nil {
		srv.ServeMux = mux
	} else {
		srv.ServeMux = http.NewServeMux()
	}
	return srv
}
// Handle registers handler on the embedded ServeMux for the pattern formed by
// concatenating the server's base path and pattern. No separator is inserted
// or normalized; the two strings are joined exactly as given.
func (s *Server) Handle(pattern string, handler http.Handler) {
	fullPattern := s.basePath + pattern
	// Call the embedded mux explicitly: s.Handle would recurse into this method.
	s.ServeMux.Handle(fullPattern, handler)
}
// RegisterCallbackServer mounts srv at server.CallbackEndpoint (relative to the
// server's base path).
//
// The endpoint handler is built by server.NewServer from a zero-valued
// server.CallbackRequest, the method list [POST], the server's Logger, and an
// adapter that forwards each request to srv.Callback. The adapter panics if
// the request it receives is not a *server.CallbackRequest.
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
	// Adapter from the generic server.Request to the concrete callback type.
	dispatch := func(req server.Request) (server.Response, server.ErrorResponse) {
		typed, ok := req.(*server.CallbackRequest)
		if !ok {
			panic(fmt.Errorf("expected CallbackRequest, got %T", req))
		}
		return srv.Callback(typed)
	}

	methods := []string{http.MethodPost}
	endpoint := server.NewServer(server.CallbackRequest{}, methods, s.Logger, dispatch)
	s.Handle(server.CallbackEndpoint, endpoint)
}
// RegisterTokenServer mounts srv at server.TokenEndpoint (relative to the
// server's base path).
//
// The endpoint handler is built by server.NewServer from a zero-valued
// server.TokenRequest, the method list [POST], the server's Logger, and an
// adapter that forwards each request to srv.Token. The adapter panics if the
// request it receives is not a *server.TokenRequest.
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
	// Adapter from the generic server.Request to the concrete token type.
	dispatch := func(req server.Request) (server.Response, server.ErrorResponse) {
		typed, ok := req.(*server.TokenRequest)
		if !ok {
			panic(fmt.Errorf("expected TokenRequest, got %T", req))
		}
		return srv.Token(typed)
	}

	methods := []string{http.MethodPost}
	endpoint := server.NewServer(server.TokenRequest{}, methods, s.Logger, dispatch)
	s.Handle(server.TokenEndpoint, endpoint)
}