summaryrefslogtreecommitdiff
path: root/jws_validator.go
blob: 0b2467f2868aad1d68a15514931db2a31f729afc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package main

import (
	"github.com/pkg/errors"
	"gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
	"net/url"
	"time"
)

// TODO
// validate amr claim contains requested acr values (selective_mfa will be just mfa)
// validate acr claim is the same as requested acr_values
//
// acr_values can be mfa or selective_mfa (mfa only for external users)
// mfa amr values:
// pas - password
// otp - OTP code
// u2f - U2F code
// mfa - multi-factor
// hrd - hardware OTP device used
// sft - software OTP device used

// Claims is the payload decoded from a successfully verified token: the
// registered JWT claims (embedded jwt.Claims) plus the "nonce" claim.
type Claims struct {
	// Nonce is the token's "nonce" claim; validateNonce compares it against
	// Sha256Hex of the nonce supplied by the caller of Validate.
	Nonce string `json:"nonce,omitempty"`
	jwt.Claims
}

// JWSValidationContext carries the settings NewJWSValidator copies into a
// new validator.
type JWSValidationContext struct {
	// KeyFetcher supplies the verification key for a token, looked up by the
	// token's "kid" header.
	KeyFetcher   JWKSFetcher
	// Issuer is the "iss" value every accepted token must carry.
	Issuer       string
	// ClientId, rendered with String(), is the audience every accepted token
	// must list in its "aud" claim.
	ClientId     *url.URL
	// ClockSkew is passed as the leeway to the standard time-based claim
	// checks (ValidateWithLeeway). It is not applied to the MaxLiftetime check.
	ClockSkew    time.Duration
	// MaxLiftetime is the maximum accepted token age, measured from the "iat"
	// claim. NOTE: the field name is misspelled; it is kept as-is because
	// renaming an exported field would break existing callers.
	MaxLiftetime time.Duration
}

// JWSValidator checks a signed token and, when every check passes, hands back
// the claims it carries.
type JWSValidator interface {
	// Validate verifies token and checks its nonce claim against nonce.
	// It returns the decoded claims on success and a non-nil error otherwise.
	Validate(token string, nonce string) (*Claims, error)
}

// jwsValidator is the JWSValidator implementation built by NewJWSValidator.
type jwsValidator struct {
	algorithms  *stringSet    // accepted "alg" header values
	jwks        JWKSFetcher   // resolves verification keys by "kid"
	issuer      string        // required "iss" value
	clientId    *url.URL      // String() must appear in "aud"
	clockSkew   time.Duration // leeway for ValidateWithLeeway
	maxLifetime time.Duration // max token age measured from "iat"
}

// NewJWSValidator returns a JWSValidator configured from c.
//
// Only RSASSA-PSS signatures (PS256, PS384, PS512) are accepted. The issuer,
// audience (client ID), clock-skew leeway and maximum token lifetime are all
// taken from c.
func NewJWSValidator(c *JWSValidationContext) JWSValidator {
	return &jwsValidator{
		// These must match the JWA "alg" identifiers exactly (RFC 7518 §3.1).
		// The SHA-384 variant is "PS384". Any other spelling never matches a
		// real token and would silently reject every PS384-signed token.
		algorithms:  NewStringSet("PS256", "PS384", "PS512"),
		jwks:        c.KeyFetcher,
		issuer:      c.Issuer,
		clientId:    c.ClientId,
		clockSkew:   c.ClockSkew,
		maxLifetime: c.MaxLiftetime,
	}
}

// Validate runs the full acceptance pipeline for the signed token j:
//  1. parse j as a signed JWT;
//  2. enforce the header policy (validateHeaders);
//  3. fetch the verification key named by the single header's "kid";
//  4. verify the signature and the standard claims (validateClaims);
//  5. check the "nonce" claim against nonce (validateNonce).
//
// The first failing step's error is returned with a stack trace attached.
// On success the decoded claims are returned.
func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
	token, parseErr := jwt.ParseSigned(j)
	if parseErr != nil {
		return nil, errors.WithStack(parseErr)
	}

	headers := token.Headers
	if hdrErr := v.validateHeaders(headers); hdrErr != nil {
		return nil, errors.WithStack(hdrErr)
	}

	// validateHeaders has just guaranteed exactly one header, so index 0 is safe.
	signingKey, keyErr := v.jwks.GetKey(headers[0].KeyID)
	if keyErr != nil {
		return nil, errors.WithStack(keyErr)
	}

	claims, claimsErr := v.validateClaims(token, signingKey)
	if claimsErr != nil {
		return nil, errors.WithStack(claimsErr)
	}

	if nonceErr := v.validateNonce(nonce, claims.Nonce); nonceErr != nil {
		return nil, errors.WithStack(nonceErr)
	}
	return claims, nil
}

// validateHeaders enforces the protected-header policy. The token must:
//   - carry exactly one signature;
//   - use an algorithm from v.algorithms;
//   - declare a "typ" header equal to "JWS".
//
// It returns nil when all three hold, otherwise an error naming the first
// violated rule.
func (v *jwsValidator) validateHeaders(h []jose.Header) error {
	// Cases are evaluated in order, so h[0] is only touched once len(h) == 1.
	switch {
	case len(h) != 1:
		return errors.Errorf("Invalid signature count")
	case !v.algorithms.Contains(h[0].Algorithm):
		return errors.Errorf("Invalid signature algorithm")
	}

	only := h[0]
	typ, present := only.ExtraHeaders[jose.HeaderType]
	if present && typ == "JWS" {
		return nil
	}
	return errors.Errorf("Invalid token type")
}

// validateClaims verifies j's signature with k, decodes its payload into
// Claims, and checks the claims. The checks are:
//   - "iss" must equal v.issuer;
//   - "aud" must include v.clientId.String();
//   - the time-based claims are checked by ValidateWithLeeway with
//     v.clockSkew of leeway;
//   - the token must carry an "iat" claim and be no older than v.maxLifetime
//     measured from it.
//
// It returns the decoded claims, or an error describing the first failed check.
func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) {
	claims := &Claims{}
	if err := j.Claims(k, claims); err != nil {
		return nil, errors.WithStack(err)
	}

	// Read the clock once so the library's time checks and the max-lifetime
	// check below agree on what "now" is.
	now := time.Now()

	exp := jwt.Expected{
		Issuer:   v.issuer,
		Audience: jwt.Audience{v.clientId.String()},
		Time:     now,
	}

	if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil {
		return nil, errors.WithStack(err)
	}

	// The lifetime bound depends entirely on "iat", so reject a missing or
	// nonsensical iat explicitly. Otherwise it falls through to a misleading
	// "exceeded max lifetime" error. Checking Unix() <= 0 covers a missing iat
	// whether it surfaces as the zero time.Time (negative Unix) or as the Unix
	// epoch, depending on the go-jose v2 release in use.
	issuedAt := claims.IssuedAt.Time()
	if issuedAt.Unix() <= 0 {
		return nil, errors.Errorf("Missing iat claim")
	}

	if issuedAt.Add(v.maxLifetime).Before(now) {
		return nil, errors.Errorf("Token exceeded max lifetime")
	}

	return claims, nil
}

// validateNonce reports whether the token's "nonce" claim (tokenNonce) equals
// Sha256Hex(nonce). Here nonce is the caller-supplied value (presumably the
// raw nonce kept client-side for this login attempt — confirm with callers).
//
// The error deliberately omits both the raw nonce and the expected hash. The
// raw value is the client's secret pre-image, and the expected hash is exactly
// what a forged or injected token would need to carry. Echoing either into
// logs or an HTTP response defeats the purpose of the nonce check. Only the
// attacker-visible token value is included, for diagnostics.
func (v *jwsValidator) validateNonce(nonce string, tokenNonce string) error {
	if tokenNonce != Sha256Hex(nonce) {
		return errors.Errorf("Invalid nonce: %q", tokenNonce)
	}

	return nil
}