summaryrefslogtreecommitdiff
path: root/app/controllers/oauth2_device.go
blob: 0ddf65390447867a13cb9d245980a749640a38b6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package controllers

import (
	"crypto/subtle"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"code.crute.us/mcrute/ssh-proxy/app"
	"code.crute.us/mcrute/ssh-proxy/app/models"
	"github.com/labstack/echo/v4"
)

// badRequest responds with HTTP 400 and a JSON models.Oauth2Error body whose
// Type is e and whose Description is d (which may be empty). It returns
// whatever c.JSON returns.
func badRequest(c echo.Context, e models.AuthorizationError, d string) error {
	payload := models.Oauth2Error{Description: d, Type: e}
	return c.JSON(http.StatusBadRequest, payload)
}

// OAuth2DeviceController implements the HTTP handlers for the OAuth2 device
// authorization grant: HandleStart issues device/user codes and HandleToken is
// polled for the resulting access token.
//
// Fields:
//   - Logger receives error messages from the handlers (body binding, client
//     lookup and session insertion failures).
//   - OauthClients is queried by client ID when a flow is started.
//   - AuthSessions stores new sessions in HandleStart and is queried by device
//     code in HandleToken.
//   - Hostname is prefixed to "/login" to build the verification URIs
//     (presumably it includes the scheme, e.g. "https://host" — TODO confirm).
//   - PollSeconds is returned to the client as the polling interval.
//   - SessionExpiration is added to the current time to set a new session's
//     expiry.
//
// The type parameter T is constrained to app.AppSession but is not referenced
// by the handlers in this file.
type OAuth2DeviceController[T app.AppSession] struct {
	Logger            echo.Logger
	OauthClients      models.OauthClientStore
	AuthSessions      models.AuthSessionStore
	Hostname          string
	PollSeconds       int
	SessionExpiration time.Duration
}

// HandleStart begins an OAuth2 device authorization flow (RFC 8628 §3.1).
//
// The request body is bound into a models.AuthorizationRequest. Then:
//   - the client ID must resolve in OauthClients;
//   - a PKCE code_challenge of at least 16 bytes using the S256 method is
//     required;
//   - at least one scope must be requested, and each must be "ssh:proxy" or
//     "ca:issue".
//
// On success a new auth session, expiring SessionExpiration from now, is
// inserted into AuthSessions. The response is JSON containing the device code,
// user code, verification URIs, remaining lifetime in seconds and the poll
// interval. Validation failures produce a 400 OAuth2 error body. A failure to
// persist the session produces an empty 500 response.
func (a *OAuth2DeviceController[T]) HandleStart(c echo.Context) error {
	ctx := c.Request().Context()

	var form models.AuthorizationRequest
	if err := (&echo.DefaultBinder{}).BindBody(c, &form); err != nil {
		a.Logger.Errorf("Unable to parse form data: %s", err)
		return badRequest(c, models.ErrInvalidRequest, "")
	}

	client, err := a.OauthClients.Get(ctx, form.ClientId)
	if err != nil {
		a.Logger.Errorf("Unable to find client ID '%s': %s", form.ClientId, err)
		return badRequest(c, models.ErrUnauthorizedClient, "")
	}

	// The error text promises a minimum of 16 bytes, so a challenge of exactly
	// 16 bytes must be accepted; only strictly shorter ones are rejected.
	if len(form.Challenge) < 16 {
		return badRequest(c, models.ErrInvalidRequest,
			"code_challenge is too short, minimum length is 16 bytes")
	}

	if form.ChallengeMethod != models.ChallengeS256 {
		return badRequest(c, models.ErrInvalidRequest,
			"code_challenge_method invalid, only S256 supported")
	}

	session := models.NewAuthSession(client.Id, time.Now().Add(a.SessionExpiration))
	session.SetChallenge(form.Challenge, form.ChallengeMethod)
	session.SetScopeString(form.Scope)

	if !session.HasAnyScopes() {
		return badRequest(c, models.ErrInvalidRequest, "one or more scopes required")
	}

	// Only these two scopes are known to this service; anything else is refused
	// before the session is persisted.
	for _, s := range session.Scope {
		if s != "ssh:proxy" && s != "ca:issue" {
			return badRequest(c, models.ErrInvalidScope, fmt.Sprintf("scope %s is not recognized", s))
		}
	}

	if err := a.AuthSessions.Insert(ctx, session); err != nil {
		// The format verb is required; without it Errorf appends the error as
		// "%!(EXTRA ...)" noise (and go vet flags the call).
		a.Logger.Errorf("Error inserting auth session: %s", err)
		return c.NoContent(http.StatusInternalServerError)
	}

	return c.JSON(http.StatusOK, models.DeviceAuthorizationResponse{
		DeviceCode:              session.DeviceCode,
		UserCode:                session.UserCode,
		VerificationUri:         fmt.Sprintf("%s/login", a.Hostname),
		VerificationUriComplete: fmt.Sprintf("%s/login?code=%s", a.Hostname, session.UserCode),
		ExpiresIn:               int(time.Until(session.Expires).Seconds()),
		Interval:                a.PollSeconds,
	})
}

// HandleToken is the device access token endpoint that clients poll
// (RFC 8628 §3.4).
//
// The request body is bound into a models.DeviceAccessTokenRequest, and the
// auth session is looked up by device code. Unknown codes get an empty 404.
// The request is rejected with a 400 OAuth2 error if any of these hold:
//   - the grant type is not the device-code grant;
//   - the client ID differs from the session's;
//   - the session has expired;
//   - the PKCE code_verifier does not match the stored challenge;
//   - the session is a registration session.
//
// While the session has no access code, ErrAuthorizationPending is returned so
// the client keeps polling. Once an access code exists, it is returned as a
// Bearer token together with the session's remaining lifetime in seconds.
func (a *OAuth2DeviceController[T]) HandleToken(c echo.Context) error {
	ctx := c.Request().Context()

	var form models.DeviceAccessTokenRequest
	if err := (&echo.DefaultBinder{}).BindBody(c, &form); err != nil {
		a.Logger.Errorf("Unable to parse form data: %s", err)
		return badRequest(c, models.ErrInvalidRequest, "")
	}

	session, err := a.AuthSessions.Get(ctx, form.DeviceCode)
	if err != nil {
		return c.NoContent(http.StatusNotFound)
	}

	if form.GrantType != models.DEVICE_CODE_GRANT_TYPE {
		return badRequest(c, models.ErrUnsupportedGrantType, "")
	}

	// Constant-time comparison avoids leaking the stored client ID via timing.
	if subtle.ConstantTimeCompare([]byte(session.ClientId), []byte(form.ClientId)) != 1 {
		return badRequest(c, models.ErrUnauthorizedClient, "")
	}

	if time.Now().After(session.Expires) {
		return badRequest(c, models.ErrExpiredToken, "")
	}

	// Per RFC 7636 §4.6 the request must fail with invalid_grant when the
	// verifier does NOT match the challenge stored at authorization time. The
	// previous condition was inverted and rejected correct verifiers.
	// NOTE(review): assumes EqualString reports whether the verifier matches the
	// given challenge — confirm against models.PKCEChallenge.
	verifier := &models.PKCEChallenge{Verifier: form.CodeVerifier}
	if !verifier.EqualString(session.Challenge) {
		return badRequest(c, models.ErrInvalidGrant, "")
	}

	if session.IsRegistration {
		return badRequest(c, models.ErrInvalidGrant, "")
	}

	// No access code yet means the user has not completed login; the client
	// should keep polling.
	if session.AccessCode == "" {
		return badRequest(c, models.ErrAuthorizationPending, "")
	}

	// NOTE(review): nothing in this handler marks the session as consumed after
	// the token is issued; confirm whether single-use enforcement of the device
	// code happens elsewhere.
	return c.JSON(http.StatusOK, models.AccessTokenResponse{
		AccessToken: session.AccessCode,
		TokenType:   "Bearer",
		ExpiresIn:   strconv.FormatInt(int64(time.Until(session.Expires).Seconds()), 10),
	})
}