diff options
Diffstat (limited to 'app/controllers/oauth2_device.go')
-rw-r--r-- | app/controllers/oauth2_device.go | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/app/controllers/oauth2_device.go b/app/controllers/oauth2_device.go new file mode 100644 index 0000000..0ddf653 --- /dev/null +++ b/app/controllers/oauth2_device.go | |||
@@ -0,0 +1,129 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "crypto/subtle" | ||
5 | "fmt" | ||
6 | "net/http" | ||
7 | "strconv" | ||
8 | "time" | ||
9 | |||
10 | "code.crute.us/mcrute/ssh-proxy/app" | ||
11 | "code.crute.us/mcrute/ssh-proxy/app/models" | ||
12 | "github.com/labstack/echo/v4" | ||
13 | ) | ||
14 | |||
15 | func badRequest(c echo.Context, e models.AuthorizationError, d string) error { | ||
16 | return c.JSON(http.StatusBadRequest, models.Oauth2Error{ | ||
17 | Type: e, | ||
18 | Description: d, | ||
19 | }) | ||
20 | } | ||
21 | |||
22 | type OAuth2DeviceController[T app.AppSession] struct { | ||
23 | Logger echo.Logger | ||
24 | OauthClients models.OauthClientStore | ||
25 | AuthSessions models.AuthSessionStore | ||
26 | Hostname string | ||
27 | PollSeconds int | ||
28 | SessionExpiration time.Duration | ||
29 | } | ||
30 | |||
31 | func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error { | ||
32 | ctx := c.Request().Context() | ||
33 | |||
34 | var form models.AuthorizationRequest | ||
35 | if err := (&echo.DefaultBinder{}).BindBody(c, &form); err != nil { | ||
36 | a.Logger.Errorf("Unable to parse form data: %s", err) | ||
37 | return badRequest(c, models.ErrInvalidRequest, "") | ||
38 | } | ||
39 | |||
40 | client, err := a.OauthClients.Get(ctx, form.ClientId) | ||
41 | if err != nil { | ||
42 | a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err) | ||
43 | return badRequest(c, models.ErrUnauthorizedClient, "") | ||
44 | } | ||
45 | |||
46 | if len(form.Challenge) <= 16 { | ||
47 | return badRequest(c, models.ErrInvalidRequest, | ||
48 | "code_challenge is too short, minimum length is 16 bytes") | ||
49 | } | ||
50 | |||
51 | if form.ChallengeMethod != models.ChallengeS256 { | ||
52 | return badRequest(c, models.ErrInvalidRequest, | ||
53 | "code_challenge_method invalid, only S256 supported") | ||
54 | } | ||
55 | |||
56 | session := models.NewAuthSession(client.Id, time.Now().Add(a.SessionExpiration)) | ||
57 | session.SetChallenge(form.Challenge, form.ChallengeMethod) | ||
58 | session.SetScopeString(form.Scope) | ||
59 | |||
60 | if !session.HasAnyScopes() { | ||
61 | return badRequest(c, models.ErrInvalidRequest, "one or more scopes required") | ||
62 | } | ||
63 | |||
64 | for _, s := range session.Scope { | ||
65 | if s != "ssh:proxy" && s != "ca:issue" { | ||
66 | return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s)) | ||
67 | } | ||
68 | } | ||
69 | |||
70 | if err := a.AuthSessions.Insert(ctx, session); err != nil { | ||
71 | a.Logger.Errorf("Error inserting auth session", err) | ||
72 | return c.NoContent(http.StatusInternalServerError) | ||
73 | } | ||
74 | |||
75 | return c.JSON(http.StatusOK, models.DeviceAuthorizationResponse{ | ||
76 | DeviceCode: session.DeviceCode, | ||
77 | UserCode: session.UserCode, | ||
78 | VerificationUri: fmt.Sprintf("%s/login", a.Hostname), | ||
79 | VerificationUriComplete: fmt.Sprintf("%s/login?code=%s", a.Hostname, session.UserCode), | ||
80 | ExpiresIn: int(time.Until(session.Expires).Seconds()), | ||
81 | Interval: a.PollSeconds, | ||
82 | }) | ||
83 | } | ||
84 | |||
85 | func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error { | ||
86 | ctx := c.Request().Context() | ||
87 | |||
88 | var form models.DeviceAccessTokenRequest | ||
89 | if err := (&echo.DefaultBinder{}).BindBody(c, &form); err != nil { | ||
90 | a.Logger.Errorf("Unable to parse form data: %s", err) | ||
91 | return badRequest(c, models.ErrInvalidRequest, "") | ||
92 | } | ||
93 | |||
94 | session, err := a.AuthSessions.Get(ctx, form.DeviceCode) | ||
95 | if err != nil { | ||
96 | return c.NoContent(http.StatusNotFound) | ||
97 | } | ||
98 | |||
99 | if form.GrantType != models.DEVICE_CODE_GRANT_TYPE { | ||
100 | return badRequest(c, models.ErrUnsupportedGrantType, "") | ||
101 | } | ||
102 | |||
103 | if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 { | ||
104 | return badRequest(c, models.ErrUnauthorizedClient, "") | ||
105 | } | ||
106 | |||
107 | if time.Now().After(session.Expires) { | ||
108 | return badRequest(c, models.ErrExpiredToken, "") | ||
109 | } | ||
110 | |||
111 | verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier} | ||
112 | if verifier.EqualString(session.Challenge) { | ||
113 | return badRequest(c, models.ErrInvalidGrant, "") // Per RFC7636 4.6 | ||
114 | } | ||
115 | |||
116 | if session.IsRegistration { | ||
117 | return badRequest(c, models.ErrInvalidGrant, "") | ||
118 | } | ||
119 | |||
120 | if session.AccessCode == "" { | ||
121 | return badRequest(c, models.ErrAuthorizationPending, "") | ||
122 | } | ||
123 | |||
124 | return c.JSON(http.StatusOK, models.AccessTokenResponse{ | ||
125 | AccessToken: session.AccessCode, | ||
126 | TokenType: "Bearer", | ||
127 | ExpiresIn: strconv.FormatInt(int64(time.Until(session.Expires).Seconds()), 10), | ||
128 | }) | ||
129 | } | ||