summaryrefslogtreecommitdiff
path: root/jws_validator.go
diff options
context:
space:
mode:
Diffstat (limited to 'jws_validator.go')
-rw-r--r--jws_validator.go110
1 files changed, 70 insertions, 40 deletions
diff --git a/jws_validator.go b/jws_validator.go
index e77c026..0b2467f 100644
--- a/jws_validator.go
+++ b/jws_validator.go
@@ -1,103 +1,133 @@
1package main 1package main
2 2
3import ( 3import (
4 "crypto/sha256" 4 "github.com/pkg/errors"
5 "encoding/hex"
6 "fmt"
7 "gopkg.in/square/go-jose.v2" 5 "gopkg.in/square/go-jose.v2"
8 "gopkg.in/square/go-jose.v2/jwt" 6 "gopkg.in/square/go-jose.v2/jwt"
7 "net/url"
9 "time" 8 "time"
10) 9)
11 10
11// TODO
12// validate amr claim contains requested acr values (selective_mfa will be just mfa)
13// validate acr claim is the same as requested acr_values
14//
15// acr_values can be mfa or selective_mfa (mfa only for external users)
16// mfa amr values:
17// pas - password
18// otp - OTP code
19// u2f - U2F code
20// mfa - multi-factor
21// hrd - hardware OTP device used
22// sft - software OTP device used
23
12type Claims struct { 24type Claims struct {
13 Nonce string `json:"nonce,omitempty"` 25 Nonce string `json:"nonce,omitempty"`
14 jwt.Claims 26 jwt.Claims
15} 27}
16 28
29type JWSValidationContext struct {
30 KeyFetcher JWKSFetcher
31 Issuer string
32 ClientId *url.URL
33 ClockSkew time.Duration
34 MaxLiftetime time.Duration
35}
36
17type JWSValidator interface { 37type JWSValidator interface {
18 Validate(string, string) (*Claims, error) 38 Validate(string, string) (*Claims, error)
19} 39}
20 40
21type jwsValidator struct { 41type jwsValidator struct {
22 algorithms *stringSet 42 algorithms *stringSet
23 jwks map[string]jose.JSONWebKey 43 jwks JWKSFetcher
24 issuer string 44 issuer string
25 clientID string 45 clientId *url.URL
26 clockSkew time.Duration 46 clockSkew time.Duration
27 maxLifetime time.Duration 47 maxLifetime time.Duration
28} 48}
29 49
30// TODO 50func NewJWSValidator(c *JWSValidationContext) JWSValidator {
31// validate amr claim contains requested acr values (selective_mfa will be just mfa)
32// validate acr claim is the same as requested acr_values
33func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator {
34 return &jwsValidator{ 51 return &jwsValidator{
35 algorithms: NewStringSet("PS256", "PS385", "PS512"), 52 algorithms: NewStringSet("PS256", "PS385", "PS512"),
36 jwks: jwks, 53 jwks: c.KeyFetcher,
37 issuer: issuer, 54 issuer: c.Issuer,
38 clientID: client_id, 55 clientId: c.ClientId,
39 clockSkew: skew, 56 clockSkew: c.ClockSkew,
40 maxLifetime: max_life, 57 maxLifetime: c.MaxLiftetime,
41 } 58 }
42} 59}
43 60
44func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { 61func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
45 parsed_jwt, err := jwt.ParseSigned(j) 62 parsed, err := jwt.ParseSigned(j)
63 if err != nil {
64 return nil, errors.WithStack(err)
65 }
66
67 if err := v.validateHeaders(parsed.Headers); err != nil {
68 return nil, errors.WithStack(err)
69 }
70
71 kid := parsed.Headers[0].KeyID
72 key, err := v.jwks.GetKey(kid)
73 if err != nil {
74 return nil, errors.WithStack(err)
75 }
76
77 claims, err := v.validateClaims(parsed, key)
46 if err != nil { 78 if err != nil {
47 return nil, err 79 return nil, errors.WithStack(err)
48 } 80 }
49 81
50 if len(parsed_jwt.Headers) != 1 { 82 if err := v.validateNonce(nonce, claims.Nonce); err != nil {
51 return nil, fmt.Errorf("Invalid signature count") 83 return nil, errors.WithStack(err)
52 } 84 }
53 85
54 head := parsed_jwt.Headers[0] 86 return claims, nil
87}
55 88
56 if !v.algorithms.Contains(head.Algorithm) { 89func (v *jwsValidator) validateHeaders(h []jose.Header) error {
57 return nil, fmt.Errorf("Invalid signature algorithm") 90 if len(h) != 1 {
91 return errors.Errorf("Invalid signature count")
58 } 92 }
59 93
60 if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { 94 if !v.algorithms.Contains(h[0].Algorithm) {
61 return nil, fmt.Errorf("Invalid token type") 95 return errors.Errorf("Invalid signature algorithm")
62 } 96 }
63 97
64 key, ok := v.jwks[head.KeyID] 98 if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" {
65 if !ok { 99 return errors.Errorf("Invalid token type")
66 return nil, fmt.Errorf("No key found for key id")
67 } 100 }
68 101
102 return nil
103}
104
105func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) {
69 claims := &Claims{} 106 claims := &Claims{}
70 if err = parsed_jwt.Claims(key, claims); err != nil { 107 if err := j.Claims(k, claims); err != nil {
71 return nil, err 108 return nil, errors.WithStack(err)
72 } 109 }
73 110
74 exp := jwt.Expected{ 111 exp := jwt.Expected{
75 Issuer: v.issuer, 112 Issuer: v.issuer,
76 Audience: jwt.Audience{v.clientID}, 113 Audience: jwt.Audience{v.clientId.String()},
77 Time: time.Now(), 114 Time: time.Now(),
78 } 115 }
79 116
80 if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { 117 if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil {
81 return nil, err 118 return nil, errors.WithStack(err)
82 } 119 }
83 120
84 if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { 121 if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) {
85 return nil, fmt.Errorf("Token exceeded max lifetime") 122 return nil, errors.Errorf("Token exceeded max lifetime")
86 }
87
88 if err = v.validateNonce(nonce, claims.Nonce); err != nil {
89 return nil, err
90 } 123 }
91 124
92 return claims, nil 125 return claims, nil
93} 126}
94 127
95func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { 128func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error {
96 s256 := sha256.New() 129 if token_nonce != Sha256Hex(nonce) {
97 s256.Write([]byte(nonce)) 130 return errors.Errorf("Invalid nonce: %s = %q vs %q", nonce, token_nonce, Sha256Hex(nonce))
98 hashed_nonce := hex.EncodeToString(s256.Sum(nil))
99 if token_nonce != hashed_nonce {
100 return fmt.Errorf("Invalid nonce")
101 } 131 }
102 132
103 return nil 133 return nil