summaryrefslogtreecommitdiff
path: root/jws_validator.go
diff options
context:
space:
mode:
Diffstat (limited to 'jws_validator.go')
-rw-r--r--jws_validator.go104
1 files changed, 104 insertions, 0 deletions
diff --git a/jws_validator.go b/jws_validator.go
new file mode 100644
index 0000000..e77c026
--- /dev/null
+++ b/jws_validator.go
@@ -0,0 +1,104 @@
1package main
2
3import (
4 "crypto/sha256"
5 "encoding/hex"
6 "fmt"
7 "gopkg.in/square/go-jose.v2"
8 "gopkg.in/square/go-jose.v2/jwt"
9 "time"
10)
11
12type Claims struct {
13 Nonce string `json:"nonce,omitempty"`
14 jwt.Claims
15}
16
17type JWSValidator interface {
18 Validate(string, string) (*Claims, error)
19}
20
21type jwsValidator struct {
22 algorithms *stringSet
23 jwks map[string]jose.JSONWebKey
24 issuer string
25 clientID string
26 clockSkew time.Duration
27 maxLifetime time.Duration
28}
29
30// TODO
31// validate amr claim contains requested acr values (selective_mfa will be just mfa)
32// validate acr claim is the same as requested acr_values
33func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator {
34 return &jwsValidator{
35 algorithms: NewStringSet("PS256", "PS385", "PS512"),
36 jwks: jwks,
37 issuer: issuer,
38 clientID: client_id,
39 clockSkew: skew,
40 maxLifetime: max_life,
41 }
42}
43
44func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
45 parsed_jwt, err := jwt.ParseSigned(j)
46 if err != nil {
47 return nil, err
48 }
49
50 if len(parsed_jwt.Headers) != 1 {
51 return nil, fmt.Errorf("Invalid signature count")
52 }
53
54 head := parsed_jwt.Headers[0]
55
56 if !v.algorithms.Contains(head.Algorithm) {
57 return nil, fmt.Errorf("Invalid signature algorithm")
58 }
59
60 if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" {
61 return nil, fmt.Errorf("Invalid token type")
62 }
63
64 key, ok := v.jwks[head.KeyID]
65 if !ok {
66 return nil, fmt.Errorf("No key found for key id")
67 }
68
69 claims := &Claims{}
70 if err = parsed_jwt.Claims(key, claims); err != nil {
71 return nil, err
72 }
73
74 exp := jwt.Expected{
75 Issuer: v.issuer,
76 Audience: jwt.Audience{v.clientID},
77 Time: time.Now(),
78 }
79
80 if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil {
81 return nil, err
82 }
83
84 if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) {
85 return nil, fmt.Errorf("Token exceeded max lifetime")
86 }
87
88 if err = v.validateNonce(nonce, claims.Nonce); err != nil {
89 return nil, err
90 }
91
92 return claims, nil
93}
94
95func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error {
96 s256 := sha256.New()
97 s256.Write([]byte(nonce))
98 hashed_nonce := hex.EncodeToString(s256.Sum(nil))
99 if token_nonce != hashed_nonce {
100 return fmt.Errorf("Invalid nonce")
101 }
102
103 return nil
104}