package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"time"

	"gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
)

// Claims is the claim set returned by a JWSValidator: the registered JWT
// claims from go-jose plus the "nonce" claim.
type Claims struct {
	Nonce string `json:"nonce,omitempty"`
	jwt.Claims
}

// JWSValidator validates a serialized, signed token against an expected
// nonce and returns its claims.
type JWSValidator interface {
	// Validate takes the serialized token and the caller's nonce, in that
	// order, and returns the token's claims when every check passes.
	Validate(string, string) (*Claims, error)
}

// jwsValidator is the JWSValidator implementation built by NewJWSValidator.
// It holds the accepted signature algorithms, the verification keys indexed
// by key ID, the expected issuer and audience (client ID), the leeway used
// for time-based claim checks, and the maximum token age measured from the
// "iat" claim.
type jwsValidator struct {
	algorithms  *stringSet
	jwks        map[string]jose.JSONWebKey
	issuer      string
	clientID    string
	clockSkew   time.Duration
	maxLifetime time.Duration
}

// TODO
// validate amr claim contains requested acr values (selective_mfa will be just mfa)
// validate acr claim is the same as requested acr_values

// NewJWSValidator returns a JWSValidator that accepts only tokens signed
// with PS256, PS384 or PS512.
//
// jwks maps key IDs to verification keys, issuer is the expected "iss"
// value, clientID is the expected audience, skew is the leeway passed to the
// claim validation, and maxLife is the maximum accepted age of a token
// counted from its "iat" claim.
func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, clientID string, skew time.Duration, maxLife time.Duration) JWSValidator {
	return &jwsValidator{
		// The middle entry used to read "PS385", which is not a JOSE
		// algorithm name, so PS384-signed tokens were always rejected.
		algorithms:  NewStringSet("PS256", "PS384", "PS512"),
		jwks:        jwks,
		issuer:      issuer,
		clientID:    clientID,
		clockSkew:   skew,
		maxLifetime: maxLife,
	}
}

// Validate parses the serialized token j, verifies its signature with the
// key selected by the token's key ID, and checks the claims against the
// validator's configuration and the supplied nonce.
//
// The checks run in this order: the token must parse; it must carry exactly
// one signature header; the header algorithm must be in the allow-list; the
// "typ" header must equal "JWS"; the key ID must be present in the key map;
// the signature must verify; the claims must pass ValidateWithLeeway for the
// expected issuer and audience; the token must not be older than the
// configured maximum lifetime; and the "nonce" claim must equal the
// hex-encoded SHA-256 digest of nonce.
//
// On any failure it returns a nil *Claims and a non-nil error. Errors coming
// from go-jose are returned unwrapped.
func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
	parsed, err := jwt.ParseSigned(j)
	if err != nil {
		return nil, err
	}

	if len(parsed.Headers) != 1 {
		return nil, fmt.Errorf("Invalid signature count")
	}
	head := parsed.Headers[0]

	if !v.algorithms.Contains(head.Algorithm) {
		return nil, fmt.Errorf("Invalid signature algorithm")
	}

	if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" {
		return nil, fmt.Errorf("Invalid token type")
	}

	key, ok := v.jwks[head.KeyID]
	if !ok {
		return nil, fmt.Errorf("No key found for key id")
	}

	// Claims verifies the signature with key before decoding the payload.
	claims := &Claims{}
	if err := parsed.Claims(key, claims); err != nil {
		return nil, err
	}

	// Capture the clock once so the claim validation and the lifetime check
	// below are evaluated against the same instant.
	now := time.Now()

	expected := jwt.Expected{
		Issuer:   v.issuer,
		Audience: jwt.Audience{v.clientID},
		Time:     now,
	}
	if err := claims.ValidateWithLeeway(expected, v.clockSkew); err != nil {
		return nil, err
	}

	if claims.IssuedAt.Time().Add(v.maxLifetime).Before(now) {
		return nil, fmt.Errorf("Token exceeded max lifetime")
	}

	if err := v.validateNonce(nonce, claims.Nonce); err != nil {
		return nil, err
	}

	return claims, nil
}

// validateNonce reports an error unless tokenNonce equals the lowercase
// hex-encoded SHA-256 digest of nonce.
func (v *jwsValidator) validateNonce(nonce string, tokenNonce string) error {
	digest := sha256.Sum256([]byte(nonce))
	if tokenNonce != hex.EncodeToString(digest[:]) {
		return fmt.Errorf("Invalid nonce")
	}
	return nil
}