package server

import (
	"fmt"
	"net/http"
	"net/url"

	"somehole.com/common/oauth2/session"
)

// CallbackEndpoint is the URL path of the callback endpoint.
const CallbackEndpoint = "/callback"

// CallbackError is the result code reported by a CallbackServer. CallbackOk
// indicates success; every other value describes a failure.
type CallbackError uint32

const (
	// CallbackOk reports that the callback succeeded.
	CallbackOk CallbackError = iota
	// CallbackErrorUnimplemented is returned by UnimplementedCallbackServer,
	// i.e. when no real Callback implementation was provided.
	CallbackErrorUnimplemented
	// CallbackErrorUnauthorized reports that the user is not authorized.
	CallbackErrorUnauthorized
	// CallbackErrorServerError reports an internal server failure.
	CallbackErrorServerError
)

// Ok reports whether ce is CallbackOk.
func (ce CallbackError) Ok() bool {
	return ce == CallbackOk
}

// HttpStatus returns the HTTP status code associated with ce: 200 for
// CallbackOk, 401 for CallbackErrorUnauthorized and 500 for everything else,
// including unknown values.
//
// NOTE(review): the name keeps its original casing (not HTTPStatus) because it
// presumably satisfies the package's ErrorResponse interface — confirm before
// renaming.
func (ce CallbackError) HttpStatus() int {
	switch ce {
	case CallbackOk:
		return http.StatusOK
	case CallbackErrorUnauthorized:
		return http.StatusUnauthorized
	default:
		// CallbackErrorUnimplemented, CallbackErrorServerError and any value
		// outside the declared constants.
		return http.StatusInternalServerError
	}
}

// String returns a short human-readable description of ce.
func (ce CallbackError) String() string {
	switch ce {
	case CallbackOk:
		return "authenticated"
	case CallbackErrorUnimplemented:
		return "callback server unimplemented"
	case CallbackErrorUnauthorized:
		return "user unauthorized"
	case CallbackErrorServerError:
		return "internal server error"
	default:
		return "unhandled error"
	}
}

// ErrorResponse returns the JSON error body for ce, produced by
// mustMarshalJson. Every value currently yields
// {"error":"internal_server_error"}.
//
// NOTE(review): HttpStatus distinguishes 401 from 500, but the body does not;
// confirm whether CallbackErrorUnauthorized should carry its own error code.
func (ce CallbackError) ErrorResponse() []byte {
	return mustMarshalJson(struct {
		Error string `json:"error"`
	}{
		Error: "internal_server_error",
	})
}

// CallbackRequest holds the parameters delivered to the callback endpoint.
type CallbackRequest struct {
	State session.State
	Code  session.Code
}

// Request returns a pointer to a copy of cr as a Request. Because the
// receiver is a value, calling Parse on the result fills that copy, not cr.
func (cr CallbackRequest) Request() Request {
	return &cr
}

// Parse fills cr from the "code" and "state" parameters in data. It returns
// an error, leaving cr unchanged, if either parameter is absent or empty. A
// nil data is treated as an empty parameter set.
func (cr *CallbackRequest) Parse(data *url.Values) error {
	// Dereferencing a nil *url.Values would panic; a nil url.Values map is
	// safe to read and simply reports every key as absent.
	var values url.Values
	if data != nil {
		values = *data
	}

	// An empty value carries no authorization code or session state, so it
	// is rejected exactly like a missing one. Code is checked first to keep
	// the original error precedence.
	code := values.Get("code")
	if code == "" {
		return fmt.Errorf("missing code parameter")
	}
	state := values.Get("state")
	if state == "" {
		return fmt.Errorf("missing state parameter")
	}

	cr.State = session.State(state)
	cr.Code = session.Code(code)
	return nil
}

// CallbackResponse is the successful result of a callback. Status is used as
// the HTTP status code and is excluded from the JSON body; Message is encoded
// as "message".
type CallbackResponse struct {
	Status  int    `json:"-"`
	Message string `json:"message"`
}

// HttpStatus returns cr.Status.
func (cr *CallbackResponse) HttpStatus() int {
	return cr.Status
}

// Response returns cr encoded as JSON via mustMarshalJson.
func (cr *CallbackResponse) Response() []byte {
	return mustMarshalJson(cr)
}

// UnimplementedCallbackServer must be embedded by CallbackServer
// implementations. Its Callback method always fails with
// CallbackErrorUnimplemented.
type UnimplementedCallbackServer struct{}

func (UnimplementedCallbackServer) mustEmbedUnimplementedCallbackServer() {}

// Callback always returns a nil response and CallbackErrorUnimplemented.
func (UnimplementedCallbackServer) Callback(*CallbackRequest) (*CallbackResponse, ErrorResponse) {
	return nil, CallbackErrorUnimplemented
}

// CallbackServer handles requests to the callback endpoint. Because of the
// unexported mustEmbedUnimplementedCallbackServer method, implementations
// outside this package can only satisfy it by embedding
// UnimplementedCallbackServer.
type CallbackServer interface {
	mustEmbedUnimplementedCallbackServer()
	Callback(*CallbackRequest) (*CallbackResponse, ErrorResponse)
}

// Compile-time check that the default implementation satisfies CallbackServer.
var _ CallbackServer = UnimplementedCallbackServer{}