package main

import (
	"net/url"
	"time"

	"github.com/pkg/errors"
	"gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
)

// TODO:
//   - validate that the amr claim contains the requested acr values
//     (selective_mfa is checked as plain mfa)
//   - validate that the acr claim matches the requested acr_values
//
// acr_values can be mfa or selective_mfa (mfa only for external users).
// mfa amr values:
//
//	pas - password
//	otp - OTP code
//	u2f - U2F code
//	mfa - multi-factor
//	hrd - hardware OTP device used
//	sft - software OTP device used

// Claims is the claim set decoded from a token: the registered JWT claims
// from go-jose plus the "nonce" claim.
type Claims struct {
	Nonce string `json:"nonce,omitempty"`
	jwt.Claims
}

// Age returns the number of whole minutes elapsed since the token's iat
// claim.
//
// NOTE(review): if the token carried no iat, IssuedAt.Time() yields the zero
// time and the result is meaningless; claims returned by Validate always
// have an iat.
func (c *Claims) Age() int64 {
	return int64(time.Since(c.IssuedAt.Time()).Minutes())
}

// JWSValidationContext holds the configuration consumed by NewJWSValidator.
type JWSValidationContext struct {
	// KeyFetcher resolves the verification key for a token's "kid" header.
	KeyFetcher JWKSFetcher
	// Issuer is the expected "iss" claim.
	Issuer string
	// ClientId is the expected "aud" claim, compared by its String() form.
	// It must be non-nil.
	ClientId *url.URL
	// ClockSkew is the leeway passed to jwt.Claims.ValidateWithLeeway.
	ClockSkew time.Duration
	// MaxLiftetime is how long after its "iat" claim a token is still
	// accepted. (The name is misspelled; kept for compatibility.)
	MaxLiftetime time.Duration
}

// JWSValidator validates a compact-serialized signed token against an
// expected nonce and returns its claims.
type JWSValidator interface {
	Validate(string, string) (*Claims, error)
}

// jwsValidator is the JWSValidator built by NewJWSValidator.
type jwsValidator struct {
	algorithms  *stringSet
	jwks        JWKSFetcher
	issuer      string
	clientId    *url.URL
	clockSkew   time.Duration
	maxLifetime time.Duration
}

// NewJWSValidator builds a JWSValidator from c. Only the RSASSA-PSS
// signature algorithms PS256, PS384 and PS512 are accepted.
func NewJWSValidator(c *JWSValidationContext) JWSValidator {
	return &jwsValidator{
		// RFC 7518 identifiers; these must match the "alg" header exactly.
		algorithms:  NewStringSet("PS256", "PS384", "PS512"),
		jwks:        c.KeyFetcher,
		issuer:      c.Issuer,
		clientId:    c.ClientId,
		clockSkew:   c.ClockSkew,
		maxLifetime: c.MaxLiftetime,
	}
}

// Validate parses the signed token j, checks its header, fetches the key
// named by its "kid" header, verifies the signature and claims, and checks
// that the token's nonce claim matches nonce. On success it returns the
// decoded claims; on any failure it returns a nil *Claims and the error.
func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
	parsed, err := jwt.ParseSigned(j)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if err := v.validateHeaders(parsed.Headers); err != nil {
		return nil, errors.WithStack(err)
	}

	// validateHeaders guarantees exactly one header, so index 0 is safe.
	kid := parsed.Headers[0].KeyID
	key, err := v.jwks.GetKey(kid)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	claims, err := v.validateClaims(parsed, key)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if err := v.validateNonce(nonce, claims.Nonce); err != nil {
		return nil, errors.WithStack(err)
	}

	return claims, nil
}

// validateHeaders requires exactly one signature header, an allowed "alg",
// and a "typ" header equal to "JWS".
func (v *jwsValidator) validateHeaders(h []jose.Header) error {
	if len(h) != 1 {
		return errors.Errorf("Invalid signature count")
	}

	if !v.algorithms.Contains(h[0].Algorithm) {
		return errors.Errorf("Invalid signature algorithm")
	}

	if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" {
		return errors.Errorf("Invalid token type")
	}

	return nil
}

// validateClaims verifies the token signature with k and decodes its claims.
// It then checks iss/aud and the time-based claims via
// jwt.Claims.ValidateWithLeeway, and finally enforces the max lifetime
// measured from "iat".
func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) {
	// String() on a nil *url.URL panics; fail cleanly on a misconfigured
	// validator instead.
	if v.clientId == nil {
		return nil, errors.Errorf("No client id configured")
	}

	claims := &Claims{}
	if err := j.Claims(k, claims); err != nil {
		return nil, errors.WithStack(err)
	}

	// Use one instant for every time-based check so they agree.
	now := time.Now()
	exp := jwt.Expected{
		Issuer:   v.issuer,
		Audience: jwt.Audience{v.clientId.String()},
		Time:     now,
	}
	if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil {
		return nil, errors.WithStack(err)
	}

	// The lifetime rule is anchored on iat. A token without one cannot be
	// aged, so say so rather than reporting it as over-age.
	iat := claims.IssuedAt.Time()
	if iat.IsZero() {
		return nil, errors.Errorf("Token missing iat claim")
	}
	if iat.Add(v.maxLifetime).Before(now) {
		return nil, errors.Errorf("Token exceeded max lifetime")
	}

	return claims, nil
}

// validateNonce checks that the token's nonce claim equals Sha256Hex(nonce)
// (presumably the hex-encoded SHA-256 of the caller's nonce).
func (v *jwsValidator) validateNonce(nonce string, tokenNonce string) error {
	expected := Sha256Hex(nonce)
	if tokenNonce != expected {
		// Report only the token's value and the expected digest. The raw
		// nonce is deliberately left out of the error text.
		return errors.Errorf("Invalid nonce: got %q, want %q", tokenNonce, expected)
	}
	return nil
}