diff options
Diffstat (limited to 'jws_validator.go')
-rw-r--r-- | jws_validator.go | 110 |
1 files changed, 70 insertions, 40 deletions
diff --git a/jws_validator.go b/jws_validator.go index e77c026..0b2467f 100644 --- a/jws_validator.go +++ b/jws_validator.go | |||
@@ -1,103 +1,133 @@ | |||
1 | package main | 1 | package main |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "crypto/sha256" | 4 | "github.com/pkg/errors" |
5 | "encoding/hex" | ||
6 | "fmt" | ||
7 | "gopkg.in/square/go-jose.v2" | 5 | "gopkg.in/square/go-jose.v2" |
8 | "gopkg.in/square/go-jose.v2/jwt" | 6 | "gopkg.in/square/go-jose.v2/jwt" |
7 | "net/url" | ||
9 | "time" | 8 | "time" |
10 | ) | 9 | ) |
11 | 10 | ||
11 | // TODO | ||
12 | // validate amr claim contains requested acr values (selective_mfa will be just mfa) | ||
13 | // validate acr claim is the same as requested acr_values | ||
14 | // | ||
15 | // acr_values can be mfa or selective_mfa (mfa only for external users) | ||
16 | // mfa amr values: | ||
17 | // pas - password | ||
18 | // otp - OTP code | ||
19 | // u2f - U2F code | ||
20 | // mfa - multi-factor | ||
21 | // hrd - hardware OTP device used | ||
22 | // sft - software OTP device used | ||
23 | |||
12 | type Claims struct { | 24 | type Claims struct { |
13 | Nonce string `json:"nonce,omitempty"` | 25 | Nonce string `json:"nonce,omitempty"` |
14 | jwt.Claims | 26 | jwt.Claims |
15 | } | 27 | } |
16 | 28 | ||
29 | type JWSValidationContext struct { | ||
30 | KeyFetcher JWKSFetcher | ||
31 | Issuer string | ||
32 | ClientId *url.URL | ||
33 | ClockSkew time.Duration | ||
34 | MaxLiftetime time.Duration | ||
35 | } | ||
36 | |||
17 | type JWSValidator interface { | 37 | type JWSValidator interface { |
18 | Validate(string, string) (*Claims, error) | 38 | Validate(string, string) (*Claims, error) |
19 | } | 39 | } |
20 | 40 | ||
21 | type jwsValidator struct { | 41 | type jwsValidator struct { |
22 | algorithms *stringSet | 42 | algorithms *stringSet |
23 | jwks map[string]jose.JSONWebKey | 43 | jwks JWKSFetcher |
24 | issuer string | 44 | issuer string |
25 | clientID string | 45 | clientId *url.URL |
26 | clockSkew time.Duration | 46 | clockSkew time.Duration |
27 | maxLifetime time.Duration | 47 | maxLifetime time.Duration |
28 | } | 48 | } |
29 | 49 | ||
30 | // TODO | 50 | func NewJWSValidator(c *JWSValidationContext) JWSValidator { |
31 | // validate amr claim contains requested acr values (selective_mfa will be just mfa) | ||
32 | // validate acr claim is the same as requested acr_values | ||
33 | func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator { | ||
34 | return &jwsValidator{ | 51 | return &jwsValidator{ |
35 | algorithms: NewStringSet("PS256", "PS385", "PS512"), | 52 | algorithms: NewStringSet("PS256", "PS385", "PS512"), |
36 | jwks: jwks, | 53 | jwks: c.KeyFetcher, |
37 | issuer: issuer, | 54 | issuer: c.Issuer, |
38 | clientID: client_id, | 55 | clientId: c.ClientId, |
39 | clockSkew: skew, | 56 | clockSkew: c.ClockSkew, |
40 | maxLifetime: max_life, | 57 | maxLifetime: c.MaxLiftetime, |
41 | } | 58 | } |
42 | } | 59 | } |
43 | 60 | ||
44 | func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { | 61 | func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { |
45 | parsed_jwt, err := jwt.ParseSigned(j) | 62 | parsed, err := jwt.ParseSigned(j) |
63 | if err != nil { | ||
64 | return nil, errors.WithStack(err) | ||
65 | } | ||
66 | |||
67 | if err := v.validateHeaders(parsed.Headers); err != nil { | ||
68 | return nil, errors.WithStack(err) | ||
69 | } | ||
70 | |||
71 | kid := parsed.Headers[0].KeyID | ||
72 | key, err := v.jwks.GetKey(kid) | ||
73 | if err != nil { | ||
74 | return nil, errors.WithStack(err) | ||
75 | } | ||
76 | |||
77 | claims, err := v.validateClaims(parsed, key) | ||
46 | if err != nil { | 78 | if err != nil { |
47 | return nil, err | 79 | return nil, errors.WithStack(err) |
48 | } | 80 | } |
49 | 81 | ||
50 | if len(parsed_jwt.Headers) != 1 { | 82 | if err := v.validateNonce(nonce, claims.Nonce); err != nil { |
51 | return nil, fmt.Errorf("Invalid signature count") | 83 | return nil, errors.WithStack(err) |
52 | } | 84 | } |
53 | 85 | ||
54 | head := parsed_jwt.Headers[0] | 86 | return claims, nil |
87 | } | ||
55 | 88 | ||
56 | if !v.algorithms.Contains(head.Algorithm) { | 89 | func (v *jwsValidator) validateHeaders(h []jose.Header) error { |
57 | return nil, fmt.Errorf("Invalid signature algorithm") | 90 | if len(h) != 1 { |
91 | return errors.Errorf("Invalid signature count") | ||
58 | } | 92 | } |
59 | 93 | ||
60 | if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { | 94 | if !v.algorithms.Contains(h[0].Algorithm) { |
61 | return nil, fmt.Errorf("Invalid token type") | 95 | return errors.Errorf("Invalid signature algorithm") |
62 | } | 96 | } |
63 | 97 | ||
64 | key, ok := v.jwks[head.KeyID] | 98 | if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { |
65 | if !ok { | 99 | return errors.Errorf("Invalid token type") |
66 | return nil, fmt.Errorf("No key found for key id") | ||
67 | } | 100 | } |
68 | 101 | ||
102 | return nil | ||
103 | } | ||
104 | |||
105 | func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) { | ||
69 | claims := &Claims{} | 106 | claims := &Claims{} |
70 | if err = parsed_jwt.Claims(key, claims); err != nil { | 107 | if err := j.Claims(k, claims); err != nil { |
71 | return nil, err | 108 | return nil, errors.WithStack(err) |
72 | } | 109 | } |
73 | 110 | ||
74 | exp := jwt.Expected{ | 111 | exp := jwt.Expected{ |
75 | Issuer: v.issuer, | 112 | Issuer: v.issuer, |
76 | Audience: jwt.Audience{v.clientID}, | 113 | Audience: jwt.Audience{v.clientId.String()}, |
77 | Time: time.Now(), | 114 | Time: time.Now(), |
78 | } | 115 | } |
79 | 116 | ||
80 | if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { | 117 | if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { |
81 | return nil, err | 118 | return nil, errors.WithStack(err) |
82 | } | 119 | } |
83 | 120 | ||
84 | if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { | 121 | if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { |
85 | return nil, fmt.Errorf("Token exceeded max lifetime") | 122 | return nil, errors.Errorf("Token exceeded max lifetime") |
86 | } | ||
87 | |||
88 | if err = v.validateNonce(nonce, claims.Nonce); err != nil { | ||
89 | return nil, err | ||
90 | } | 123 | } |
91 | 124 | ||
92 | return claims, nil | 125 | return claims, nil |
93 | } | 126 | } |
94 | 127 | ||
95 | func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { | 128 | func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { |
96 | s256 := sha256.New() | 129 | if token_nonce != Sha256Hex(nonce) { |
97 | s256.Write([]byte(nonce)) | 130 | return errors.Errorf("Invalid nonce: %s = %q vs %q", nonce, token_nonce, Sha256Hex(nonce)) |
98 | hashed_nonce := hex.EncodeToString(s256.Sum(nil)) | ||
99 | if token_nonce != hashed_nonce { | ||
100 | return fmt.Errorf("Invalid nonce") | ||
101 | } | 131 | } |
102 | 132 | ||
103 | return nil | 133 | return nil |