From b7867d9cf5b0dd175b8167a552b830ebfe47d0ed Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Tue, 5 Sep 2017 03:52:50 +0000 Subject: Finish JWS and Cert validation --- cautious_http_client.go | 13 ++--- jws_validator.go | 104 ++++++++++++++++++++++++++++++++++ key_validator.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 118 +++++++++++++-------------------------- oidc_proxy | Bin 0 -> 6495319 bytes util.go | 43 +++++++++++++++ 6 files changed, 335 insertions(+), 87 deletions(-) create mode 100644 jws_validator.go create mode 100644 key_validator.go create mode 100755 oidc_proxy create mode 100644 util.go diff --git a/cautious_http_client.go b/cautious_http_client.go index 66179f2..2f33ae0 100644 --- a/cautious_http_client.go +++ b/cautious_http_client.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "net" "net/http" "net/url" @@ -28,9 +29,9 @@ func NewCautiousHTTPClient() CautiousHTTPClient { }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 3 * time.Second, + TLSHandshakeTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second, - ResponseHeaderTimeout: 5 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, MaxResponseHeaderBytes: 500000, // .5 MB } @@ -49,11 +50,9 @@ func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) { } // TODO - /* - if u.Scheme != "https" { - return nil, fmt.Errorf("URL for GET must be secure") - } - */ + if u.Scheme != "https" && false { + return nil, fmt.Errorf("URL for GET must be secure") + } r, err := c.client.Get(u.String()) if err != nil { diff --git a/jws_validator.go b/jws_validator.go new file mode 100644 index 0000000..e77c026 --- /dev/null +++ b/jws_validator.go @@ -0,0 +1,104 @@ +package main + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + "time" +) + +type Claims struct { + Nonce string `json:"nonce,omitempty"` + jwt.Claims +} + +type JWSValidator interface { + Validate(string, string) (*Claims, error) +} + +type jwsValidator struct { + algorithms *stringSet + jwks map[string]jose.JSONWebKey + issuer string + clientID string + clockSkew time.Duration + maxLifetime time.Duration +} + +// TODO +// validate amr claim contains requested acr values (selective_mfa will be just mfa) +// validate acr claim is the same as requested acr_values +func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator { + return &jwsValidator{ + algorithms: NewStringSet("PS256", "PS385", "PS512"), + jwks: jwks, + issuer: issuer, + clientID: client_id, + clockSkew: skew, + maxLifetime: max_life, + } +} + +func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { + parsed_jwt, err := jwt.ParseSigned(j) + if err != nil { + return nil, err + } + + if len(parsed_jwt.Headers) != 1 { + return nil, fmt.Errorf("Invalid signature count") + } + + head := parsed_jwt.Headers[0] + + if !v.algorithms.Contains(head.Algorithm) { + return nil, fmt.Errorf("Invalid signature algorithm") + } + + if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { + return nil, fmt.Errorf("Invalid token type") + } + + key, ok := v.jwks[head.KeyID] + if !ok { + return nil, fmt.Errorf("No key found for key id") + } + + claims := &Claims{} + if err = parsed_jwt.Claims(key, claims); err != nil { + return nil, err + } + + exp := jwt.Expected{ + Issuer: v.issuer, + Audience: jwt.Audience{v.clientID}, + Time: time.Now(), + } + + if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { + return nil, err + } + + if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { + return nil, fmt.Errorf("Token exceeded max lifetime") + } + + if err = v.validateNonce(nonce, claims.Nonce); err != nil { + return nil, err + } + + return claims, nil +} + +func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { + s256 := sha256.New() + s256.Write([]byte(nonce)) + hashed_nonce := hex.EncodeToString(s256.Sum(nil)) + if token_nonce != hashed_nonce { + return fmt.Errorf("Invalid nonce") + } + + return nil +} diff --git a/key_validator.go b/key_validator.go new file mode 100644 index 0000000..fe6eb7b --- /dev/null +++ b/key_validator.go @@ -0,0 +1,144 @@ +package main + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "gopkg.in/square/go-jose.v2" + "io/ioutil" +) + +type KeyValidator interface { + Validate(jose.JSONWebKey) error + LoadRootPEM(string) error +} + +type keyValidator struct { + pkiSubject string + algorithms *stringSet + roots *x509.CertPool +} + +func NewKeyValidator(subject string) KeyValidator { + return &keyValidator{ + pkiSubject: subject, + algorithms: NewStringSet("PS256", "PS385", "PS512"), + roots: x509.NewCertPool(), + } +} + +func (v *keyValidator) LoadRootPEM(filename string) error { + pem_data, err := ioutil.ReadFile(filename) + if err != nil { + return err + } + + pem_block, _ := pem.Decode(pem_data) + if pem_block == nil { + return fmt.Errorf("PEM decode failed") + } + + cert, err := x509.ParseCertificate(pem_block.Bytes) + if err != nil { + return err + } + + v.roots.AddCert(cert) + + return nil +} + +func (v *keyValidator) Validate(key jose.JSONWebKey) error { + pk, ok := key.Key.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("Key type is not RSA") + } + + if !v.algorithms.Contains(key.Algorithm) { + return fmt.Errorf("Key algorithm is not supported") + } + + cert := key.Certificates[0] + cpk, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return fmt.Errorf("Public key is not RSA") + } + + if cpk.N.BitLen() < 2048 { + return fmt.Errorf("Key length less than 2048 bits") + } + + if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 { + return fmt.Errorf("Certificate not valid for digital signatures") + } + + err := v.validateCertificateChain(key.Certificates) + if err != nil { + return err + } + + err = v.validateCertificateCRL(cert) + if err != nil { + return err + } + + err = v.validatePublicKeyInCertificate(pk, cpk) + if err != nil { + return err + } + + return nil +} + +// TODO +// Fetch CRL from distrubtion point in cert +// Validate CRL signed by trusted CA +// Validate cert not in CRL +func (v *keyValidator) validateCertificateCRL(cert *x509.Certificate) error { + return nil +} + +func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error { + vo := x509.VerifyOptions{ + Roots: v.roots, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + } + + if len(chain) > 1 { + ip := x509.NewCertPool() + for _, i := range chain[1:] { + ip.AddCert(i) + } + + vo.Intermediates = ip + } + + chains, err := chain[0].Verify(vo) + if err != nil { + return err + } + + if len(chains) <= 0 { + return fmt.Errorf("No valid certificate chains found") + } + + if chain[0].Subject.CommonName != v.pkiSubject { + return fmt.Errorf("Invalid certificate subject name") + } + + return nil +} + +// validate first item of x5c matches n and e +func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error { + if cpk.E != pk.E { + return fmt.Errorf("E in key and E in cert do not match") + } + + if pk.N.Cmp(cpk.N) != 0 { + return fmt.Errorf("N in key and N in cert do not match") + } + + return nil +} diff --git a/main.go b/main.go index 44501c0..965e72c 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "net/http/httputil" "net/url" "strings" + "time" ) const ( @@ -19,6 +20,8 @@ const ( RFP_COOKIE_NAME string = "sso_rfp" ) +// TODO: Enable https checks in HTTP client + // acr_values can be mfa or selective_mfa (mfa only for external users) // mfa amr values: // pas - password @@ -34,6 +37,9 @@ type ProxyConfig struct { UpstreamURL string ListenOn string TrustedCACert string + PKISubject string // TODO: Should be same as IDP w/out scheme and port + ClockSkew time.Duration + MaxLiftetime time.Duration IsOptional bool RequestMFA bool AllowedMFAMethods []string // An OR set @@ -45,7 +51,7 @@ type IdPConfig struct { AuthorizationEndpoint string `json:"authorization_endpoint"` Issuer string `json:"issuer"` JwksUri string `json:"jwks_uri"` - SupportedGrantTypes []string `json:"grant_types_supported"` + GrantTypes []string `json:"grant_types_supported"` IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"` ResponseModes []string `json:"response_modes_supported"` ResponseTypes []string `json:"response_types_supported"` @@ -53,6 +59,7 @@ type IdPConfig struct { SubjectTypes []string `json:"subject_types_supported"` } +// TODO: Optimization to fetch only if expired (per http headers) func FetchIdPConfig(h CautiousHTTPClient, idp_url string) (*IdPConfig, error) { u, err := url.Parse(idp_url) if err != nil { @@ -70,7 +77,7 @@ func FetchIdPConfig(h CautiousHTTPClient, idp_url string) (*IdPConfig, error) { } // TODO: Optimization to fetch only if expired (per http headers) -func FetchJWKS(h CautiousHTTPClient, jwks_url string) (map[string]jose.JSONWebKey, error) { +func FetchJWKS(h CautiousHTTPClient, jwks_url string, val KeyValidator) (map[string]jose.JSONWebKey, error) { var jwks jose.JSONWebKeySet err := h.GetJSON(jwks_url, &jwks) if err != nil { @@ -80,20 +87,15 @@ func FetchJWKS(h CautiousHTTPClient, jwks_url string) (map[string]jose.JSONWebKe keys := make(map[string]jose.JSONWebKey, len(jwks.Keys)) for _, k := range jwks.Keys { - keys[k.KeyID] = k + err = val.Validate(k) + if err == nil { + keys[k.KeyID] = k + } } return keys, nil } -func URLMustParse(u string) *url.URL { - o, err := url.Parse(u) - if err != nil { - panic(err) - } - return o -} - func GenerateNonce() (string, error) { nonce := make([]byte, NONCE_SIZE) n, err := rand.Read(nonce) @@ -103,10 +105,6 @@ func GenerateNonce() (string, error) { return hex.EncodeToString(nonce), nil } -func CompareUpper(lhs, rhs string) bool { - return strings.ToUpper(lhs) == strings.ToUpper(rhs) -} - // TODO // Cookie rules // Secure @@ -116,28 +114,6 @@ func CompareUpper(lhs, rhs string) bool { func SetCookie() { } -// TODO -// Fetch (connect timeout 1s, read timeout 30s, read size 1M) -func DownloadCertificate() { -} - -// TODO -// Fetch (connect timeout 1s, read timeout 30s, read size 1M) -func DownloadCRL() { -} - -// TODO -// Cert validation -// Validate cert not in CRL -// Validate cert chains to trusted CA cert (ship with proxy) -// Validate CRL signed by trusted CA -// Cert "Subject CN " must match exactly PKI setting (ex: foo-pki.foo.com) -// Current time bust be within "Validity Not Before" and "Validity Not After" in cert +- 5 minutes -// Cert Key length >= 2048 -// Certificate usage must include "digitalSignature" -func ValidateCertificate() { -} - // TODO func MakeClientID(r *http.Request) string { if strings.Contains(r.Host, ":") { @@ -167,42 +143,6 @@ func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token str } // TODO -// Occasionally refresh IDP config (per HTTP caching headers) -// -// Fetch ${IDP_HOST}/.well-known/openid-configuration -// - validate certificate chains to a trusted root -// - validate scopes_supported contains "openid" -// - validate response_types_supported contains "id_token" -// - validate grant_types_supported contains "implicit" -// - validate id_token_signing_alg_values_supported contains a supported signing type (see below) -// - Cache authorization_endpoint for redirecting users -// -// Fetch jwks_uri endpoint -// - Build key map indexed by kid for all keys that are suppored by our rules -// - kty == RSA -// - alg header must be one of [PS256, PS385, PS512] -// - pem decode x5c and validate the certificate chain as below -// - validate first item of x5c matches n and e -func RefreshIDPConfig() { -} - -// TODO -// If x5u exists in header -// Fetch cert from x5u URL -// Get CRL from cert, fetch (connect timeout 1s, read timeout 30s, read size 1M) -// -// exp claim has passed +- 5 minutes -// iat claim is greater than 24 hours +- 5 minutes -// aud claim is exact match for client_id -// iss claim is exact match for idp (ex: foo.example.com) -// if other aud claims validate that they are known -// nonce in JWT must be SHA256 of rfp cookie value -// Validate cert -// alg jwt header must be one of [PS256, PS385, PS512] -// typ jwt header must be JWS -// validate jwt signature -// validate amr claim contains requested acr values (selective_mfa will be just mfa) -// validate acr claim is the same as requested acr_values func ValidateJWT(jwt, rfp string) bool { return true } @@ -318,33 +258,51 @@ func LoginController(w http.ResponseWriter, r *http.Request) { // user will be redirected back to the main page for the site (/) func parseConfig() *ProxyConfig { return &ProxyConfig{ - IDProviderURL: "", - ClientID: "", + IDProviderURL: "http://mcrute-virt:9993", + ClientID: "test.crute.me:443", UpstreamURL: "http://localhost:9991/", ListenOn: ":9992", - TrustedCACert: "", + TrustedCACert: "/home/mcrute/oidc_project/test_ca/ca_cert.pem", IsOptional: false, + PKISubject: "Crute OpenID Signing 1", + MaxLiftetime: 24 * time.Hour, + ClockSkew: 5 * time.Minute, } } func main() { + cfg := parseConfig() h := NewCautiousHTTPClient() - idpc, err := FetchIdPConfig(h, "http://mcrute-virt:9993") + v := NewKeyValidator(cfg.PKISubject) + v.LoadRootPEM(cfg.TrustedCACert) + + idpc, err := FetchIdPConfig(h, cfg.IDProviderURL) if err != nil { fmt.Printf("%s\n", err) return } - jwks, err := FetchJWKS(h, idpc.JwksUri) + jwks, err := FetchJWKS(h, idpc.JwksUri, v) if err != nil { fmt.Printf("%s\n", err) return } - fmt.Printf("%+v\n", jwks) + + jv := NewJWSValidator(jwks, idpc.Issuer, cfg.ClientID, cfg.ClockSkew, cfg.MaxLiftetime) + + nonce := "ofspmfjuvoswhhde" + raw_jwt := "eyJ0eXAiOiJKV1MiLCJhbGciOiJQUzI1NiIsImtpZCI6IjEifQ.eyJub25jZSI6IjM0MjlhMjAyYzU4ZDkyYjQwNjNjOWM4MWM2MjQyNGRlNzBkMmIzZDQ4MmVlNDFhOTdjYmNhZjEwZDk5MWFiOTMiLCJpc3MiOiJpZHAuY3J1dGUubWU6NDQzIiwiaWF0IjoxNTA0NTc2Mzc0LCJuYmYiOjE1MDQ1NzYzNzQsImV4cCI6MTUwNDY2Mjc3NCwic3ViIjoibWNydXRlIiwiYXVkIjoidGVzdC5jcnV0ZS5tZTo0NDMifQ.iizlNfY1Vg7d-XRmgyYuhpNkNrOGaT9OOgO0HdjBozOWMvKzBTtATbIfoWOrNH6DiFY1as8uy3I1Pxnkrb8Ti8_cLDQeLxOv9klAbnebeuPI_wtZ0iwSUnSWaYzN6I6sqcEjHX3fibFvAQhO5dNDzSwONjw4AvcdpZKh579FO1sAvIw-1DmMyPSUun7rbC0Kf1Jtdlr3q7tOp3wdI_erkstxCNPwyuv7X1J7uetsu0BeJS25C2DxeB03BPEIUoo_C1xvcqikfSLLpoFcyToYiS-R9o-WpRjGid_yug65J5ALn2aM3vhe9rRbydKVm_omGL8-Etj06zbqM0Y6OrJUgA" + claims, err := jv.Validate(raw_jwt, nonce) + if err != nil { + fmt.Printf("Error validating: %s\n", err) + return + } + + fmt.Printf("Valid JWT for: %+v\n", claims.Subject) + return - cfg := parseConfig() cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL)) if cfg.IsOptional { diff --git a/oidc_proxy b/oidc_proxy new file mode 100755 index 0000000..e5df267 Binary files /dev/null and b/oidc_proxy differ diff --git a/util.go b/util.go new file mode 100644 index 0000000..10709e2 --- /dev/null +++ b/util.go @@ -0,0 +1,43 @@ +package main + +import ( + "net/url" + "strings" +) + +type stringSet struct { + values map[string]bool +} + +func NewStringSet(values ...string) *stringSet { + s := &stringSet{ + values: make(map[string]bool, len(values)), + } + + for _, v := range values { + s.Add(v) + } + + return s +} + +func (s *stringSet) Add(v string) { + s.values[v] = true +} + +func (s *stringSet) Contains(k string) bool { + _, ok := s.values[k] + return ok +} + +func URLMustParse(u string) *url.URL { + o, err := url.Parse(u) + if err != nil { + panic(err) + } + return o +} + +func CompareUpper(lhs, rhs string) bool { + return strings.ToUpper(lhs) == strings.ToUpper(rhs) +} -- cgit v1.2.3