diff options
Diffstat (limited to 'jws_validator.go')
-rw-r--r-- | jws_validator.go | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/jws_validator.go b/jws_validator.go new file mode 100644 index 0000000..e77c026 --- /dev/null +++ b/jws_validator.go | |||
@@ -0,0 +1,104 @@ | |||
1 | package main | ||
2 | |||
3 | import ( | ||
4 | "crypto/sha256" | ||
5 | "encoding/hex" | ||
6 | "fmt" | ||
7 | "gopkg.in/square/go-jose.v2" | ||
8 | "gopkg.in/square/go-jose.v2/jwt" | ||
9 | "time" | ||
10 | ) | ||
11 | |||
12 | type Claims struct { | ||
13 | Nonce string `json:"nonce,omitempty"` | ||
14 | jwt.Claims | ||
15 | } | ||
16 | |||
17 | type JWSValidator interface { | ||
18 | Validate(string, string) (*Claims, error) | ||
19 | } | ||
20 | |||
21 | type jwsValidator struct { | ||
22 | algorithms *stringSet | ||
23 | jwks map[string]jose.JSONWebKey | ||
24 | issuer string | ||
25 | clientID string | ||
26 | clockSkew time.Duration | ||
27 | maxLifetime time.Duration | ||
28 | } | ||
29 | |||
30 | // TODO | ||
31 | // validate amr claim contains requested acr values (selective_mfa will be just mfa) | ||
32 | // validate acr claim is the same as requested acr_values | ||
33 | func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator { | ||
34 | return &jwsValidator{ | ||
35 | algorithms: NewStringSet("PS256", "PS385", "PS512"), | ||
36 | jwks: jwks, | ||
37 | issuer: issuer, | ||
38 | clientID: client_id, | ||
39 | clockSkew: skew, | ||
40 | maxLifetime: max_life, | ||
41 | } | ||
42 | } | ||
43 | |||
44 | func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { | ||
45 | parsed_jwt, err := jwt.ParseSigned(j) | ||
46 | if err != nil { | ||
47 | return nil, err | ||
48 | } | ||
49 | |||
50 | if len(parsed_jwt.Headers) != 1 { | ||
51 | return nil, fmt.Errorf("Invalid signature count") | ||
52 | } | ||
53 | |||
54 | head := parsed_jwt.Headers[0] | ||
55 | |||
56 | if !v.algorithms.Contains(head.Algorithm) { | ||
57 | return nil, fmt.Errorf("Invalid signature algorithm") | ||
58 | } | ||
59 | |||
60 | if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { | ||
61 | return nil, fmt.Errorf("Invalid token type") | ||
62 | } | ||
63 | |||
64 | key, ok := v.jwks[head.KeyID] | ||
65 | if !ok { | ||
66 | return nil, fmt.Errorf("No key found for key id") | ||
67 | } | ||
68 | |||
69 | claims := &Claims{} | ||
70 | if err = parsed_jwt.Claims(key, claims); err != nil { | ||
71 | return nil, err | ||
72 | } | ||
73 | |||
74 | exp := jwt.Expected{ | ||
75 | Issuer: v.issuer, | ||
76 | Audience: jwt.Audience{v.clientID}, | ||
77 | Time: time.Now(), | ||
78 | } | ||
79 | |||
80 | if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { | ||
81 | return nil, err | ||
82 | } | ||
83 | |||
84 | if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { | ||
85 | return nil, fmt.Errorf("Token exceeded max lifetime") | ||
86 | } | ||
87 | |||
88 | if err = v.validateNonce(nonce, claims.Nonce); err != nil { | ||
89 | return nil, err | ||
90 | } | ||
91 | |||
92 | return claims, nil | ||
93 | } | ||
94 | |||
95 | func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { | ||
96 | s256 := sha256.New() | ||
97 | s256.Write([]byte(nonce)) | ||
98 | hashed_nonce := hex.EncodeToString(s256.Sum(nil)) | ||
99 | if token_nonce != hashed_nonce { | ||
100 | return fmt.Errorf("Invalid nonce") | ||
101 | } | ||
102 | |||
103 | return nil | ||
104 | } | ||