diff options
Diffstat (limited to 'app/controllers/oauth2_device.go')
-rw-r--r-- | app/controllers/oauth2_device.go | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/app/controllers/oauth2_device.go b/app/controllers/oauth2_device.go index 0ddf653..c431495 100644 --- a/app/controllers/oauth2_device.go +++ b/app/controllers/oauth2_device.go | |||
@@ -9,7 +9,23 @@ import ( | |||
9 | 9 | ||
10 | "code.crute.us/mcrute/ssh-proxy/app" | 10 | "code.crute.us/mcrute/ssh-proxy/app" |
11 | "code.crute.us/mcrute/ssh-proxy/app/models" | 11 | "code.crute.us/mcrute/ssh-proxy/app/models" |
12 | |||
12 | "github.com/labstack/echo/v4" | 13 | "github.com/labstack/echo/v4" |
14 | "github.com/prometheus/client_golang/prometheus" | ||
15 | "github.com/prometheus/client_golang/prometheus/promauto" | ||
16 | ) | ||
17 | |||
18 | var ( | ||
19 | oauth2DeviceError = promauto.NewCounterVec(prometheus.CounterOpts{ | ||
20 | Namespace: "ssh_proxy", | ||
21 | Name: "oauth2_device_error", | ||
22 | Help: "Total number of errors during oauth2 device operations", | ||
23 | }, []string{"type"}) | ||
24 | oauth2DeviceSuccess = promauto.NewCounter(prometheus.CounterOpts{ | ||
25 | Namespace: "ssh_proxy", | ||
26 | Name: "oauth2_device_success", | ||
27 | Help: "Total number of successful oauth2 device auths", | ||
28 | }) | ||
13 | ) | 29 | ) |
14 | 30 | ||
15 | func badRequest(c echo.Context, e models.AuthorizationError, d string) error { | 31 | func badRequest(c echo.Context, e models.AuthorizationError, d string) error { |
@@ -40,15 +56,18 @@ func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error { | |||
40 | client, err := a.OauthClients.Get(ctx, form.ClientId) | 56 | client, err := a.OauthClients.Get(ctx, form.ClientId) |
41 | if err != nil { | 57 | if err != nil { |
42 | a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err) | 58 | a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err) |
59 | oauth2DeviceError.With(prometheus.Labels{"type": "invalid_client_id"}).Inc() | ||
43 | return badRequest(c, models.ErrUnauthorizedClient, "") | 60 | return badRequest(c, models.ErrUnauthorizedClient, "") |
44 | } | 61 | } |
45 | 62 | ||
46 | if len(form.Challenge) <= 16 { | 63 | if len(form.Challenge) <= 16 { |
64 | oauth2DeviceError.With(prometheus.Labels{"type": "challenge_length"}).Inc() | ||
47 | return badRequest(c, models.ErrInvalidRequest, | 65 | return badRequest(c, models.ErrInvalidRequest, |
48 | "code_challenge is too short, minimum length is 16 bytes") | 66 | "code_challenge is too short, minimum length is 16 bytes") |
49 | } | 67 | } |
50 | 68 | ||
51 | if form.ChallengeMethod != models.ChallengeS256 { | 69 | if form.ChallengeMethod != models.ChallengeS256 { |
70 | oauth2DeviceError.With(prometheus.Labels{"type": "challenge_type"}).Inc() | ||
52 | return badRequest(c, models.ErrInvalidRequest, | 71 | return badRequest(c, models.ErrInvalidRequest, |
53 | "code_challenge_method invalid, only S256 supported") | 72 | "code_challenge_method invalid, only S256 supported") |
54 | } | 73 | } |
@@ -58,11 +77,13 @@ func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error { | |||
58 | session.SetScopeString(form.Scope) | 77 | session.SetScopeString(form.Scope) |
59 | 78 | ||
60 | if !session.HasAnyScopes() { | 79 | if !session.HasAnyScopes() { |
80 | oauth2DeviceError.With(prometheus.Labels{"type": "no_scopes"}).Inc() | ||
61 | return badRequest(c, models.ErrInvalidRequest, "one or more scopes required") | 81 | return badRequest(c, models.ErrInvalidRequest, "one or more scopes required") |
62 | } | 82 | } |
63 | 83 | ||
64 | for _, s := range session.Scope { | 84 | for _, s := range session.Scope { |
65 | if s != "ssh:proxy" && s != "ca:issue" { | 85 | if s != "ssh:proxy" && s != "ca:issue" { |
86 | oauth2DeviceError.With(prometheus.Labels{"type": "invalid_scope"}).Inc() | ||
66 | return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s)) | 87 | return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s)) |
67 | } | 88 | } |
68 | } | 89 | } |
@@ -93,27 +114,33 @@ func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error { | |||
93 | 114 | ||
94 | session, err := a.AuthSessions.Get(ctx, form.DeviceCode) | 115 | session, err := a.AuthSessions.Get(ctx, form.DeviceCode) |
95 | if err != nil { | 116 | if err != nil { |
117 | oauth2DeviceError.With(prometheus.Labels{"type": "no_auth_session"}).Inc() | ||
96 | return c.NoContent(http.StatusNotFound) | 118 | return c.NoContent(http.StatusNotFound) |
97 | } | 119 | } |
98 | 120 | ||
99 | if form.GrantType != models.DEVICE_CODE_GRANT_TYPE { | 121 | if form.GrantType != models.DEVICE_CODE_GRANT_TYPE { |
122 | oauth2DeviceError.With(prometheus.Labels{"type": "invalid_grant_type"}).Inc() | ||
100 | return badRequest(c, models.ErrUnsupportedGrantType, "") | 123 | return badRequest(c, models.ErrUnsupportedGrantType, "") |
101 | } | 124 | } |
102 | 125 | ||
103 | if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 { | 126 | if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 { |
127 | oauth2DeviceError.With(prometheus.Labels{"type": "client_id_mismatch"}).Inc() | ||
104 | return badRequest(c, models.ErrUnauthorizedClient, "") | 128 | return badRequest(c, models.ErrUnauthorizedClient, "") |
105 | } | 129 | } |
106 | 130 | ||
107 | if time.Now().After(session.Expires) { | 131 | if time.Now().After(session.Expires) { |
132 | oauth2DeviceError.With(prometheus.Labels{"type": "expired_session"}).Inc() | ||
108 | return badRequest(c, models.ErrExpiredToken, "") | 133 | return badRequest(c, models.ErrExpiredToken, "") |
109 | } | 134 | } |
110 | 135 | ||
111 | verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier} | 136 | verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier} |
112 | if verifier.EqualString(session.Challenge) { | 137 | if verifier.EqualString(session.Challenge) { |
138 | oauth2DeviceError.With(prometheus.Labels{"type": "pkce_mismatch"}).Inc() | ||
113 | return badRequest(c, models.ErrInvalidGrant, "") // Per RFC7636 4.6 | 139 | return badRequest(c, models.ErrInvalidGrant, "") // Per RFC7636 4.6 |
114 | } | 140 | } |
115 | 141 | ||
116 | if session.IsRegistration { | 142 | if session.IsRegistration { |
143 | oauth2DeviceError.With(prometheus.Labels{"type": "is_registration_session"}).Inc() | ||
117 | return badRequest(c, models.ErrInvalidGrant, "") | 144 | return badRequest(c, models.ErrInvalidGrant, "") |
118 | } | 145 | } |
119 | 146 | ||
@@ -121,6 +148,7 @@ func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error { | |||
121 | return badRequest(c, models.ErrAuthorizationPending, "") | 148 | return badRequest(c, models.ErrAuthorizationPending, "") |
122 | } | 149 | } |
123 | 150 | ||
151 | oauth2DeviceSuccess.Inc() | ||
124 | return c.JSON(http.StatusOK, models.AccessTokenResponse{ | 152 | return c.JSON(http.StatusOK, models.AccessTokenResponse{ |
125 | AccessToken: session.AccessCode, | 153 | AccessToken: session.AccessCode, |
126 | TokenType: "Bearer", | 154 | TokenType: "Bearer", |