package client

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"code.crute.us/mcrute/ssh-proxy/app/models"
	"github.com/google/go-querystring/query"
)

// devicePollDefaultInterval is the polling interval used when the server does
// not supply a usable (positive) one.
const devicePollDefaultInterval = 5 * time.Second

// devicePollSlowDownStep is added to the polling interval every time the
// token endpoint answers with a slow-down error.
const devicePollSlowDownStep = 5 * time.Second

// Oauth2PKCEDeviceClient drives a device-code authorization exchange that is
// protected by a PKCE challenge: Authorize starts the exchange and AwaitToken
// polls until a token is issued.
//
// Oauth2PKCEDeviceClient is not safe for concurrent use and should be
// created anew for each request.
type Oauth2PKCEDeviceClient struct {
	// Host is the authority (host[:port]) of the authorization server. All
	// requests are sent to it over HTTPS.
	Host string

	// ClientId is sent as the client identifier in both requests.
	ClientId string

	// Scope is the scope value sent with the authorization request.
	Scope string

	// pkce is the challenge generated by Authorize; its verifier is sent
	// with every token request.
	pkce *models.PKCEChallenge

	// interval is the current delay between token polls. It is set by
	// Authorize and grows whenever the server asks the client to slow down.
	interval time.Duration
}

// postForm POSTs the url-encoded body to https://<Host><path> and decodes the
// JSON response into out.
//
// Any status other than 200 is decoded as a models.Oauth2Error, which is
// returned unwrapped so callers can inspect its Type. Request construction and
// transport errors are returned as-is; JSON decoding failures are wrapped with
// the endpoint (and status, for error responses) for context.
func (c *Oauth2PKCEDeviceClient) postForm(ctx context.Context, path, body string, out interface{}) error {
	endpoint := fmt.Sprintf("https://%s%s", c.Host, path)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer res.Body.Close()

	if res.StatusCode != http.StatusOK {
		var resError models.Oauth2Error
		if err := json.NewDecoder(res.Body).Decode(&resError); err != nil {
			return fmt.Errorf("decoding error response (HTTP %d) from %s: %w", res.StatusCode, path, err)
		}
		return resError
	}

	if err := json.NewDecoder(res.Body).Decode(out); err != nil {
		return fmt.Errorf("decoding response from %s: %w", path, err)
	}
	return nil
}

// Authorize generates a fresh PKCE challenge, submits the device
// authorization request to /auth/device, and returns the server's response.
//
// As a side effect it stores the challenge and the polling interval on the
// client for use by AwaitToken. A missing or non-positive interval in the
// response falls back to five seconds. A non-200 reply is returned as a
// models.Oauth2Error.
func (c *Oauth2PKCEDeviceClient) Authorize(ctx context.Context) (*models.DeviceAuthorizationResponse, error) {
	challenge, err := models.NewPKCEChallenge()
	if err != nil {
		return nil, err
	}
	c.pkce = challenge

	values, err := query.Values(models.AuthorizationRequest{
		Challenge:       c.pkce.Challenge(),
		ChallengeMethod: models.ChallengeS256,
		ClientId:        c.ClientId,
		Scope:           c.Scope,
	})
	if err != nil {
		return nil, err
	}

	var resp models.DeviceAuthorizationResponse
	if err := c.postForm(ctx, "/auth/device", values.Encode(), &resp); err != nil {
		return nil, err
	}

	c.interval = time.Duration(resp.Interval) * time.Second
	// Guard against negative values as well as zero: a non-positive interval
	// is never a usable polling delay.
	if c.interval <= 0 {
		c.interval = devicePollDefaultInterval
	}

	return &resp, nil
}

// fetchToken performs a single token request for deviceCode against
// /auth/token.
//
// When the server replies with a slow-down error the client's polling
// interval is increased by five seconds before the error is returned. It
// returns an error, rather than panicking, if Authorize has not been called.
func (c *Oauth2PKCEDeviceClient) fetchToken(ctx context.Context, deviceCode string) (*models.AccessTokenResponse, error) {
	if c.pkce == nil {
		return nil, errors.New("oauth2 device client: Authorize must be called before requesting a token")
	}

	values, err := query.Values(models.DeviceAccessTokenRequest{
		GrantType:    models.DEVICE_CODE_GRANT_TYPE,
		DeviceCode:   deviceCode,
		ClientId:     c.ClientId,
		CodeVerifier: c.pkce.Verifier,
	})
	if err != nil {
		return nil, err
	}

	var resp models.AccessTokenResponse
	if err := c.postForm(ctx, "/auth/token", values.Encode(), &resp); err != nil {
		var oauthErr models.Oauth2Error
		if errors.As(err, &oauthErr) && oauthErr.Type == models.ErrSlowDown {
			c.interval += devicePollSlowDownStep
		}
		return nil, err
	}

	return &resp, nil
}

// AwaitToken polls the token endpoint for deviceCode until a token is issued,
// the server reports a terminal error, or ctx is done.
//
// The first request is sent immediately. After an authorization-pending or
// slow-down reply the client waits one full polling interval (already
// lengthened by fetchToken in the slow-down case) before trying again. Any
// other error is returned unchanged. When ctx ends while waiting, the returned
// error wraps ctx.Err() so callers can tell cancellation from a deadline.
func (c *Oauth2PKCEDeviceClient) AwaitToken(ctx context.Context, deviceCode string) (*models.AccessTokenResponse, error) {
	// Authorize normally sets the interval; fall back to the default so a
	// zero or negative value can never reach the timer.
	if c.interval <= 0 {
		c.interval = devicePollDefaultInterval
	}

	for {
		res, err := c.fetchToken(ctx, deviceCode)
		if err == nil {
			return res, nil
		}

		var oauthErr models.Oauth2Error
		if !errors.As(err, &oauthErr) {
			return nil, err
		}
		if oauthErr.Type != models.ErrAuthorizationPending && oauthErr.Type != models.ErrSlowDown {
			return nil, err
		}

		// A fresh timer per wait guarantees a full interval between the end
		// of one request and the start of the next, and picks up any
		// slow-down increase applied by fetchToken.
		wait := time.NewTimer(c.interval)
		select {
		case <-wait.C:
		case <-ctx.Done():
			wait.Stop()
			return nil, fmt.Errorf("awaiting token: %w", ctx.Err())
		}
	}
}