1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
|
package main
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"time"
)
// Claims is the claim set decoded from a validated token: the standard
// claims embedded from jwt.Claims plus the "nonce" claim.
type Claims struct {
// Nonce is the token's "nonce" claim. Validate compares it against the
// hex-encoded SHA-256 digest of the nonce supplied by the caller.
Nonce string `json:"nonce,omitempty"`
jwt.Claims
}
// JWSValidator validates a serialized signed token and returns its claims.
type JWSValidator interface {
// Validate takes the serialized token and the nonce the token is expected
// to be bound to, in that order. It returns the decoded claims on success
// or a non-nil error if the token is rejected.
Validate(string, string) (*Claims, error)
}
// jwsValidator is the JWSValidator implementation returned by
// NewJWSValidator. It holds the configuration Validate checks tokens against.
type jwsValidator struct {
// algorithms is the set of signature algorithm names a token header may use.
algorithms *stringSet
// jwks holds the verification keys, looked up by the token header's key ID.
jwks map[string]jose.JSONWebKey
// issuer is the issuer value a token is expected to carry.
issuer string
// clientID is the audience value a token is expected to carry.
clientID string
// clockSkew is the leeway passed to the claims validation.
clockSkew time.Duration
// maxLifetime is the longest time after its issued-at claim that a token
// is still accepted.
maxLifetime time.Duration
}
// TODO:
//   - validate that the amr claim contains the requested acr values
//     (selective_mfa will be just mfa)
//   - validate that the acr claim is the same as the requested acr_values

// NewJWSValidator returns a JWSValidator configured with the given key set
// and expectations.
//
// jwks maps key IDs to the keys used to verify token signatures. issuer and
// clientID are the issuer and audience values a token must carry. skew is the
// leeway applied when validating the token's claims, and maxLife is the
// longest time after its issued-at claim that a token is still accepted.
//
// Only the PS256, PS384 and PS512 signature algorithms are accepted.
func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, clientID string, skew time.Duration, maxLife time.Duration) JWSValidator {
	return &jwsValidator{
		// "PS384" was previously misspelled "PS385", which is not a valid
		// algorithm name and caused every PS384-signed token to be rejected.
		algorithms:  NewStringSet("PS256", "PS384", "PS512"),
		jwks:        jwks,
		issuer:      issuer,
		clientID:    clientID,
		clockSkew:   skew,
		maxLifetime: maxLife,
	}
}
// Validate parses the serialized signed token j, verifies it against the
// configured key set and checks its claims.
//
// A token is rejected unless it carries exactly one signature header, that
// header's algorithm is in the allowed set, its type header
// (jose.HeaderType) equals "JWS", and its key ID names a key in the key set.
// The claims decoded with that key are then validated against the expected
// issuer, audience and the current time using the configured clock-skew
// leeway. Finally the token must not be older than the maximum lifetime,
// measured from its issued-at claim, and its nonce claim must match nonce
// (see validateNonce).
//
// It returns the decoded claims, or nil and the first error encountered.
func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
	token, err := jwt.ParseSigned(j)
	if err != nil {
		return nil, err
	}

	// Exactly one signature is supported.
	if len(token.Headers) != 1 {
		return nil, fmt.Errorf("Invalid signature count")
	}
	header := token.Headers[0]

	if !v.algorithms.Contains(header.Algorithm) {
		return nil, fmt.Errorf("Invalid signature algorithm")
	}

	typ, hasType := header.ExtraHeaders[jose.HeaderType]
	if !hasType || typ != "JWS" {
		return nil, fmt.Errorf("Invalid token type")
	}

	key, known := v.jwks[header.KeyID]
	if !known {
		return nil, fmt.Errorf("No key found for key id")
	}

	// Decoding the claims with the selected key is the step that involves
	// the signature key.
	var claims Claims
	if err = token.Claims(key, &claims); err != nil {
		return nil, err
	}

	expected := jwt.Expected{
		Issuer:   v.issuer,
		Audience: jwt.Audience{v.clientID},
		Time:     time.Now(),
	}
	if err = claims.ValidateWithLeeway(expected, v.clockSkew); err != nil {
		return nil, err
	}

	// Reject tokens issued longer ago than the configured maximum lifetime.
	deadline := claims.IssuedAt.Time().Add(v.maxLifetime)
	if time.Now().After(deadline) {
		return nil, fmt.Errorf("Token exceeded max lifetime")
	}

	if err = v.validateNonce(nonce, claims.Nonce); err != nil {
		return nil, err
	}
	return &claims, nil
}
// validateNonce reports whether tokenNonce equals the lowercase hex encoding
// of the SHA-256 digest of nonce. It returns nil on a match and an
// "Invalid nonce" error otherwise.
func (v *jwsValidator) validateNonce(nonce string, tokenNonce string) error {
	digest := sha256.Sum256([]byte(nonce))
	if hex.EncodeToString(digest[:]) == tokenNonce {
		return nil
	}
	return fmt.Errorf("Invalid nonce")
}
|