summaryrefslogtreecommitdiff
path: root/app/controllers/oauth2_device.go
diff options
context:
space:
mode:
Diffstat (limited to 'app/controllers/oauth2_device.go')
-rw-r--r--app/controllers/oauth2_device.go28
1 files changed, 28 insertions, 0 deletions
diff --git a/app/controllers/oauth2_device.go b/app/controllers/oauth2_device.go
index 0ddf653..c431495 100644
--- a/app/controllers/oauth2_device.go
+++ b/app/controllers/oauth2_device.go
@@ -9,7 +9,23 @@ import (
9 9
10 "code.crute.us/mcrute/ssh-proxy/app" 10 "code.crute.us/mcrute/ssh-proxy/app"
11 "code.crute.us/mcrute/ssh-proxy/app/models" 11 "code.crute.us/mcrute/ssh-proxy/app/models"
12
12 "github.com/labstack/echo/v4" 13 "github.com/labstack/echo/v4"
14 "github.com/prometheus/client_golang/prometheus"
15 "github.com/prometheus/client_golang/prometheus/promauto"
16)
17
18var (
19 oauth2DeviceError = promauto.NewCounterVec(prometheus.CounterOpts{
20 Namespace: "ssh_proxy",
21 Name: "oauth2_device_error",
22 Help: "Total number of errors during oauth2 device operations",
23 }, []string{"type"})
24 oauth2DeviceSuccess = promauto.NewCounter(prometheus.CounterOpts{
25 Namespace: "ssh_proxy",
26 Name: "oauth2_device_success",
27 Help: "Total number of successful oauth2 device auths",
28 })
13) 29)
14 30
15func badRequest(c echo.Context, e models.AuthorizationError, d string) error { 31func badRequest(c echo.Context, e models.AuthorizationError, d string) error {
@@ -40,15 +56,18 @@ func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error {
40 client, err := a.OauthClients.Get(ctx, form.ClientId) 56 client, err := a.OauthClients.Get(ctx, form.ClientId)
41 if err != nil { 57 if err != nil {
42 a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err) 58 a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err)
59 oauth2DeviceError.With(prometheus.Labels{"type": "invalid_client_id"}).Inc()
43 return badRequest(c, models.ErrUnauthorizedClient, "") 60 return badRequest(c, models.ErrUnauthorizedClient, "")
44 } 61 }
45 62
46 if len(form.Challenge) <= 16 { 63 if len(form.Challenge) <= 16 {
64 oauth2DeviceError.With(prometheus.Labels{"type": "challenge_length"}).Inc()
47 return badRequest(c, models.ErrInvalidRequest, 65 return badRequest(c, models.ErrInvalidRequest,
48 "code_challenge is too short, minimum length is 16 bytes") 66 "code_challenge is too short, minimum length is 16 bytes")
49 } 67 }
50 68
51 if form.ChallengeMethod != models.ChallengeS256 { 69 if form.ChallengeMethod != models.ChallengeS256 {
70 oauth2DeviceError.With(prometheus.Labels{"type": "challenge_type"}).Inc()
52 return badRequest(c, models.ErrInvalidRequest, 71 return badRequest(c, models.ErrInvalidRequest,
53 "code_challenge_method invalid, only S256 supported") 72 "code_challenge_method invalid, only S256 supported")
54 } 73 }
@@ -58,11 +77,13 @@ func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error {
58 session.SetScopeString(form.Scope) 77 session.SetScopeString(form.Scope)
59 78
60 if !session.HasAnyScopes() { 79 if !session.HasAnyScopes() {
80 oauth2DeviceError.With(prometheus.Labels{"type": "no_scopes"}).Inc()
61 return badRequest(c, models.ErrInvalidRequest, "one or more scopes required") 81 return badRequest(c, models.ErrInvalidRequest, "one or more scopes required")
62 } 82 }
63 83
64 for _, s := range session.Scope { 84 for _, s := range session.Scope {
65 if s != "ssh:proxy" && s != "ca:issue" { 85 if s != "ssh:proxy" && s != "ca:issue" {
86 oauth2DeviceError.With(prometheus.Labels{"type": "invalid_scope"}).Inc()
66 return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s)) 87 return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s))
67 } 88 }
68 } 89 }
@@ -93,27 +114,33 @@ func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error {
93 114
94 session, err := a.AuthSessions.Get(ctx, form.DeviceCode) 115 session, err := a.AuthSessions.Get(ctx, form.DeviceCode)
95 if err != nil { 116 if err != nil {
117 oauth2DeviceError.With(prometheus.Labels{"type": "no_auth_session"}).Inc()
96 return c.NoContent(http.StatusNotFound) 118 return c.NoContent(http.StatusNotFound)
97 } 119 }
98 120
99 if form.GrantType != models.DEVICE_CODE_GRANT_TYPE { 121 if form.GrantType != models.DEVICE_CODE_GRANT_TYPE {
122 oauth2DeviceError.With(prometheus.Labels{"type": "invalid_grant_type"}).Inc()
100 return badRequest(c, models.ErrUnsupportedGrantType, "") 123 return badRequest(c, models.ErrUnsupportedGrantType, "")
101 } 124 }
102 125
103 if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 { 126 if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 {
127 oauth2DeviceError.With(prometheus.Labels{"type": "client_id_mismatch"}).Inc()
104 return badRequest(c, models.ErrUnauthorizedClient, "") 128 return badRequest(c, models.ErrUnauthorizedClient, "")
105 } 129 }
106 130
107 if time.Now().After(session.Expires) { 131 if time.Now().After(session.Expires) {
132 oauth2DeviceError.With(prometheus.Labels{"type": "expired_session"}).Inc()
108 return badRequest(c, models.ErrExpiredToken, "") 133 return badRequest(c, models.ErrExpiredToken, "")
109 } 134 }
110 135
111 verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier} 136 verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier}
112 if verifier.EqualString(session.Challenge) { 137 if verifier.EqualString(session.Challenge) {
138 oauth2DeviceError.With(prometheus.Labels{"type": "pkce_mismatch"}).Inc()
113 return badRequest(c, models.ErrInvalidGrant, "") // Per RFC7636 4.6 139 return badRequest(c, models.ErrInvalidGrant, "") // Per RFC7636 4.6
114 } 140 }
115 141
116 if session.IsRegistration { 142 if session.IsRegistration {
143 oauth2DeviceError.With(prometheus.Labels{"type": "is_registration_session"}).Inc()
117 return badRequest(c, models.ErrInvalidGrant, "") 144 return badRequest(c, models.ErrInvalidGrant, "")
118 } 145 }
119 146
@@ -121,6 +148,7 @@ func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error {
121 return badRequest(c, models.ErrAuthorizationPending, "") 148 return badRequest(c, models.ErrAuthorizationPending, "")
122 } 149 }
123 150
151 oauth2DeviceSuccess.Inc()
124 return c.JSON(http.StatusOK, models.AccessTokenResponse{ 152 return c.JSON(http.StatusOK, models.AccessTokenResponse{
125 AccessToken: session.AccessCode, 153 AccessToken: session.AccessCode,
126 TokenType: "Bearer", 154 TokenType: "Bearer",