47 lines
1.2 KiB
Go
47 lines
1.2 KiB
Go
|
package oauth2
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net/http"
|
||
|
|
||
|
"somehole.com/common/log"
|
||
|
"somehole.com/common/oauth2/server"
|
||
|
)
|
||
|
|
||
|
type Server struct {
|
||
|
*http.ServeMux
|
||
|
Logger log.Logger
|
||
|
}
|
||
|
|
||
|
func NewServer(mux *http.ServeMux, logger log.Logger) *Server {
|
||
|
if mux == nil {
|
||
|
mux = http.NewServeMux()
|
||
|
}
|
||
|
return &Server{
|
||
|
ServeMux: mux,
|
||
|
Logger: logger,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *Server) RegisterCallbackServer(srv server.CallbackServer) {
|
||
|
s.Handle(server.CallbackEndpoint, server.NewServer(&server.CallbackRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||
|
callbackRequest, ok := req.(*server.CallbackRequest)
|
||
|
if !ok {
|
||
|
panic(fmt.Errorf("expected CallbackRequest, got %T", req))
|
||
|
}
|
||
|
res, errRes = srv.Callback(callbackRequest)
|
||
|
return
|
||
|
}))
|
||
|
}
|
||
|
|
||
|
func (s *Server) RegisterTokenServer(srv server.TokenServer) {
|
||
|
s.Handle(server.TokenEndpoint, server.NewServer(&server.TokenRequest{}, []string{http.MethodPost}, s.Logger, func(req server.Request) (res server.Response, errRes server.ErrorResponse) {
|
||
|
tokenRequest, ok := req.(*server.TokenRequest)
|
||
|
if !ok {
|
||
|
panic(fmt.Errorf("expected TokenRequest, got %T", req))
|
||
|
}
|
||
|
res, errRes = srv.Token(tokenRequest)
|
||
|
return
|
||
|
}))
|
||
|
}
|