From 9f7861ffe1397da514606b189f5b3e383f4e7ed7 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Tue, 19 Sep 2017 04:39:36 +0000 Subject: Finish out most of the proxy functionality --- cautious_http_client.go | 87 ++++++++-- jwks_fetcher.go | 118 ++++++++++++++ jws_validator.go | 110 ++++++++----- key_validator.go | 36 +++-- main.go | 409 ++++++++++++++++++++++++++++-------------------- oidc_proxy | Bin 6495319 -> 7509902 bytes util.go | 18 +++ 7 files changed, 534 insertions(+), 244 deletions(-) create mode 100644 jwks_fetcher.go diff --git a/cautious_http_client.go b/cautious_http_client.go index 2f33ae0..34b736f 100644 --- a/cautious_http_client.go +++ b/cautious_http_client.go @@ -2,24 +2,29 @@ package main import ( "encoding/json" - "fmt" + "github.com/lox/httpcache" + "github.com/pkg/errors" "net" "net/http" "net/url" + "strings" "time" ) type CautiousHTTPClient interface { Get(string) (*http.Response, error) GetJSON(string, interface{}) error + GetJSONExpires(string, interface{}) (time.Duration, error) } type cautiousHttpClient struct { - client *http.Client + allowHttp bool + client *http.Client } -func NewCautiousHTTPClient() CautiousHTTPClient { - // May Need: TLSClientConfig *tls.Config +// allowHttp is UNSAFE and technically validates the spec but it does make it +// easier to work in dev so leaving it in for now +func NewCautiousHTTPClient(allowHttp bool) (CautiousHTTPClient, error) { CautiousTransport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ @@ -36,44 +41,100 @@ func NewCautiousHTTPClient() CautiousHTTPClient { } return &cautiousHttpClient{ + allowHttp: allowHttp, client: &http.Client{ Transport: CautiousTransport, Timeout: 30 * time.Second, }, - } + }, nil } func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) { u, err := url.Parse(gurl) if err != nil { - return nil, err + return nil, errors.WithStack(err) } - // TODO - if u.Scheme != "https" && false { - return nil, fmt.Errorf("URL for GET must be secure") + if u.Scheme != "https" && !c.allowHttp { + return nil, errors.Errorf("URL for GET must be secure") } r, err := c.client.Get(u.String()) if err != nil { - return nil, err + return nil, errors.WithStack(err) } r.Body = http.MaxBytesReader(nil, r.Body, 1000000) - return r, err + + return r, nil } func (c *cautiousHttpClient) GetJSON(url string, rv interface{}) error { r, err := c.Get(url) if err != nil { - return err + return errors.WithStack(err) } defer r.Body.Close() d := json.NewDecoder(r.Body) err = d.Decode(rv) if err != nil { - return err + return errors.WithStack(err) } return nil } + +func (c *cautiousHttpClient) GetJSONExpires(url string, rv interface{}) (time.Duration, error) { + r, err := c.Get(url) + if err != nil { + return time.Duration(0), errors.WithStack(err) + } + defer r.Body.Close() + + res := httpcache.NewResource(r.StatusCode, nil, r.Header) + + d := json.NewDecoder(r.Body) + err = d.Decode(rv) + if err != nil { + return time.Duration(0), errors.WithStack(err) + } + + return refreshAfter(res), nil +} + +type JSONURL struct { + *url.URL +} + +func (u *JSONURL) AsURL() *url.URL { + return u.URL +} + +func (u *JSONURL) UnmarshalJSON(data []byte) error { + d := strings.Trim(string(data), "\"") + pu, err := url.Parse(d) + if err != nil { + return errors.WithStack(err) + } + + u.URL = pu + return nil +} + +func refreshAfter(res *httpcache.Resource) time.Duration { + maxAge, err := res.MaxAge(false) + if err != nil { + return time.Duration(0) + } + + age, err := res.Age() + if err != nil { + return time.Duration(0) + } + + if hFresh := res.HeuristicFreshness(); hFresh > maxAge { + maxAge = hFresh + } + + return maxAge - age +} diff --git a/jwks_fetcher.go b/jwks_fetcher.go new file mode 100644 index 0000000..9925430 --- /dev/null +++ b/jwks_fetcher.go @@ -0,0 +1,118 @@ +package main + +import ( + "github.com/pkg/errors" + "gopkg.in/square/go-jose.v2" + "log" + "net/url" + "time" +) + +const ( + REQUEST_BUFFER_SIZE = 10 + KEY_MAP_INITIAL_SIZE = 5 + DEFAULT_REFRESH_INTERVAL = 15 * time.Minute + MIN_REFRESH_INTERVAL = 1 * time.Minute +) + +type KeyRequest struct { + KeyId string + Response chan *jose.JSONWebKey +} + +type JWKSFetcher interface { + Run() + Fetch() error + GetKey(string) (*jose.JSONWebKey, error) + Done() +} + +type jwksFetcher struct { + keyMap map[string]jose.JSONWebKey + httpClient CautiousHTTPClient + validator KeyValidator + fetchTimer *time.Timer + url *url.URL + requests chan *KeyRequest + done chan bool +} + +func NewJWKSFetcher(h CautiousHTTPClient, url *url.URL, issuer string, root string) JWKSFetcher { + val := NewKeyValidator(HostFromURL(issuer)) + val.LoadRootPEM(root) + + return &jwksFetcher{ + httpClient: h, + validator: val, + url: url, + fetchTimer: time.NewTimer(DEFAULT_REFRESH_INTERVAL), + requests: make(chan *KeyRequest, REQUEST_BUFFER_SIZE), + keyMap: make(map[string]jose.JSONWebKey, KEY_MAP_INITIAL_SIZE), + done: make(chan bool), + } +} + +func (f *jwksFetcher) Fetch() error { + var jwks jose.JSONWebKeySet + timeout, err := f.httpClient.GetJSONExpires(f.url.String(), &jwks) + if err != nil { + return errors.WithStack(err) + } + + for _, k := range jwks.Keys { + err = f.validator.Validate(k) + if err == nil { + f.keyMap[k.KeyID] = k + } else { + log.Printf("Rejecting key %q because %q", k.KeyID, err) + } + } + + if timeout < MIN_REFRESH_INTERVAL { + timeout = MIN_REFRESH_INTERVAL + } + + success := f.fetchTimer.Reset(timeout) + if !success { + f.fetchTimer = time.NewTimer(timeout) + } + + return nil +} + +func (f *jwksFetcher) Run() { + for { + select { + // Incoming request for a key, return key or nil in no key + case r := <-f.requests: + if v, ok := f.keyMap[r.KeyId]; ok { + r.Response <- &v + } else { + r.Response <- nil + } + case <-f.fetchTimer.C: + f.Fetch() + case <-f.done: + return + } + } +} + +func (f *jwksFetcher) Done() { + f.done <- true +} + +func (f *jwksFetcher) GetKey(kid string) (*jose.JSONWebKey, error) { + r := &KeyRequest{ + KeyId: kid, + Response: make(chan *jose.JSONWebKey), + } + + f.requests <- r + + if res := <-r.Response; res == nil { + return nil, errors.Errorf("Key not found for ID") + } else { + return res, nil + } +} diff --git a/jws_validator.go b/jws_validator.go index e77c026..0b2467f 100644 --- a/jws_validator.go +++ b/jws_validator.go @@ -1,103 +1,133 @@ package main import ( - "crypto/sha256" - "encoding/hex" - "fmt" + "github.com/pkg/errors" "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" + "net/url" "time" ) +// TODO +// validate amr claim contains requested acr values (selective_mfa will be just mfa) +// validate acr claim is the same as requested acr_values +// +// acr_values can be mfa or selective_mfa (mfa only for external users) +// mfa amr values: +// pas - password +// otp - OTP code +// u2f - U2F code +// mfa - multi-factor +// hrd - hardware OTP device used +// sft - software OTP device used + type Claims struct { Nonce string `json:"nonce,omitempty"` jwt.Claims } +type JWSValidationContext struct { + KeyFetcher JWKSFetcher + Issuer string + ClientId *url.URL + ClockSkew time.Duration + MaxLiftetime time.Duration +} + type JWSValidator interface { Validate(string, string) (*Claims, error) } type jwsValidator struct { algorithms *stringSet - jwks map[string]jose.JSONWebKey + jwks JWKSFetcher issuer string - clientID string + clientId *url.URL clockSkew time.Duration maxLifetime time.Duration } -// TODO -// validate amr claim contains requested acr values (selective_mfa will be just mfa) -// validate acr claim is the same as requested acr_values -func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator { +func NewJWSValidator(c *JWSValidationContext) JWSValidator { return &jwsValidator{ algorithms: NewStringSet("PS256", "PS385", "PS512"), - jwks: jwks, - issuer: issuer, - clientID: client_id, - clockSkew: skew, - maxLifetime: max_life, + jwks: c.KeyFetcher, + issuer: c.Issuer, + clientId: c.ClientId, + clockSkew: c.ClockSkew, + maxLifetime: c.MaxLiftetime, } } func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { - parsed_jwt, err := jwt.ParseSigned(j) + parsed, err := jwt.ParseSigned(j) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := v.validateHeaders(parsed.Headers); err != nil { + return nil, errors.WithStack(err) + } + + kid := parsed.Headers[0].KeyID + key, err := v.jwks.GetKey(kid) + if err != nil { + return nil, errors.WithStack(err) + } + + claims, err := v.validateClaims(parsed, key) if err != nil { - return nil, err + return nil, errors.WithStack(err) } - if len(parsed_jwt.Headers) != 1 { - return nil, fmt.Errorf("Invalid signature count") + if err := v.validateNonce(nonce, claims.Nonce); err != nil { + return nil, errors.WithStack(err) } - head := parsed_jwt.Headers[0] + return claims, nil +} - if !v.algorithms.Contains(head.Algorithm) { - return nil, fmt.Errorf("Invalid signature algorithm") +func (v *jwsValidator) validateHeaders(h []jose.Header) error { + if len(h) != 1 { + return errors.Errorf("Invalid signature count") } - if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { - return nil, fmt.Errorf("Invalid token type") + if !v.algorithms.Contains(h[0].Algorithm) { + return errors.Errorf("Invalid signature algorithm") } - key, ok := v.jwks[head.KeyID] - if !ok { - return nil, fmt.Errorf("No key found for key id") + if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { + return errors.Errorf("Invalid token type") } + return nil +} + +func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) { claims := &Claims{} - if err = parsed_jwt.Claims(key, claims); err != nil { - return nil, err + if err := j.Claims(k, claims); err != nil { + return nil, errors.WithStack(err) } exp := jwt.Expected{ Issuer: v.issuer, - Audience: jwt.Audience{v.clientID}, + Audience: jwt.Audience{v.clientId.String()}, Time: time.Now(), } if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { - return nil, err + return nil, errors.WithStack(err) } if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { - return nil, fmt.Errorf("Token exceeded max lifetime") - } - - if err = v.validateNonce(nonce, claims.Nonce); err != nil { - return nil, err + return nil, errors.Errorf("Token exceeded max lifetime") } return claims, nil } func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { - s256 := sha256.New() - s256.Write([]byte(nonce)) - hashed_nonce := hex.EncodeToString(s256.Sum(nil)) - if token_nonce != hashed_nonce { - return fmt.Errorf("Invalid nonce") + if token_nonce != Sha256Hex(nonce) { + return errors.Errorf("Invalid nonce: %s = %q vs %q", nonce, token_nonce, Sha256Hex(nonce)) } return nil diff --git a/key_validator.go b/key_validator.go index fe6eb7b..062d78c 100644 --- a/key_validator.go +++ b/key_validator.go @@ -4,11 +4,13 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "fmt" + "github.com/pkg/errors" "gopkg.in/square/go-jose.v2" "io/ioutil" ) +// TODO: CRL validation + type KeyValidator interface { Validate(jose.JSONWebKey) error LoadRootPEM(string) error @@ -31,17 +33,17 @@ func NewKeyValidator(subject string) KeyValidator { func (v *keyValidator) LoadRootPEM(filename string) error { pem_data, err := ioutil.ReadFile(filename) if err != nil { - return err + return errors.WithStack(err) } pem_block, _ := pem.Decode(pem_data) if pem_block == nil { - return fmt.Errorf("PEM decode failed") + return errors.Errorf("PEM decode failed") } cert, err := x509.ParseCertificate(pem_block.Bytes) if err != nil { - return err + return errors.WithStack(err) } v.roots.AddCert(cert) @@ -52,40 +54,40 @@ func (v *keyValidator) LoadRootPEM(filename string) error { func (v *keyValidator) Validate(key jose.JSONWebKey) error { pk, ok := key.Key.(*rsa.PublicKey) if !ok { - return fmt.Errorf("Key type is not RSA") + return errors.Errorf("Key type is not RSA") } if !v.algorithms.Contains(key.Algorithm) { - return fmt.Errorf("Key algorithm is not supported") + return errors.Errorf("Key algorithm is not supported") } cert := key.Certificates[0] cpk, ok := cert.PublicKey.(*rsa.PublicKey) if !ok { - return fmt.Errorf("Public key is not RSA") + return errors.Errorf("Public key is not RSA") } if cpk.N.BitLen() < 2048 { - return fmt.Errorf("Key length less than 2048 bits") + return errors.Errorf("Key length less than 2048 bits") } if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 { - return fmt.Errorf("Certificate not valid for digital signatures") + return errors.Errorf("Certificate not valid for digital signatures") } err := v.validateCertificateChain(key.Certificates) if err != nil { - return err + return errors.WithStack(err) } err = v.validateCertificateCRL(cert) if err != nil { - return err + return errors.WithStack(err) } err = v.validatePublicKeyInCertificate(pk, cpk) if err != nil { - return err + return errors.WithStack(err) } return nil @@ -116,15 +118,15 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error chains, err := chain[0].Verify(vo) if err != nil { - return err + return errors.WithStack(err) } if len(chains) <= 0 { - return fmt.Errorf("No valid certificate chains found") + return errors.Errorf("No valid certificate chains found") } if chain[0].Subject.CommonName != v.pkiSubject { - return fmt.Errorf("Invalid certificate subject name") + return errors.Errorf("Invalid certificate subject name") } return nil @@ -133,11 +135,11 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error // validate first item of x5c matches n and e func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error { if cpk.E != pk.E { - return fmt.Errorf("E in key and E in cert do not match") + return errors.Errorf("E in key and E in cert do not match") } if pk.N.Cmp(cpk.N) != 0 { - return fmt.Errorf("N in key and N in cert do not match") + return errors.Errorf("N in key and N in cert do not match") } return nil diff --git a/main.go b/main.go index 965e72c..805c40d 100644 --- a/main.go +++ b/main.go @@ -4,53 +4,52 @@ import ( "context" "crypto/rand" "encoding/hex" - "fmt" - "gopkg.in/square/go-jose.v2" - "log" + "flag" + "github.com/golang/glog" + "github.com/gorilla/handlers" + "github.com/pkg/errors" "net/http" "net/http/httputil" "net/url" + "os" + "strconv" "strings" "time" ) const ( - NONCE_SIZE int = 16 - TOKEN_COOKIE_NAME string = "sso_token" - RFP_COOKIE_NAME string = "sso_rfp" + NONCE_SIZE = 16 + TOKEN_COOKIE_NAME = "sso_token" + RFP_COOKIE_NAME = "sso_rfp" + DEFAULT_CLOCK_SKEW = 5 * time.Minute + DEFAULT_MAX_LIFETIME = 24 * time.Hour + DEFAULT_COOKIE_EXP = 48 * time.Hour ) -// TODO: Enable https checks in HTTP client - -// acr_values can be mfa or selective_mfa (mfa only for external users) -// mfa amr values: -// pas - password -// otp - OTP code -// u2f - U2F code -// mfa - multi-factor -// hrd - hardware OTP device used -// sft - software OTP device used +// TODO: MFA support type ProxyConfig struct { - IDProviderURL string - ClientID string - UpstreamURL string - ListenOn string - TrustedCACert string - PKISubject string // TODO: Should be same as IDP w/out scheme and port - ClockSkew time.Duration - MaxLiftetime time.Duration - IsOptional bool - RequestMFA bool - AllowedMFAMethods []string // An OR set - RequiredMFAMethods []string // An AND set - reverseProxy *httputil.ReverseProxy + IdProviderURL *url.URL + IdProviderAuthEndpoint *url.URL + ClientId *url.URL + UpstreamURL string + ListenOn string + TrustedCACert string + ClockSkew time.Duration + MaxLiftetime time.Duration + IsOptional bool + IsBootstrap bool + RequestMFA bool + AllowedMFAMethods []string // An OR set + RequiredMFAMethods []string // An AND set + reverseProxy *httputil.ReverseProxy + jwsValidator JWSValidator } type IdPConfig struct { - AuthorizationEndpoint string `json:"authorization_endpoint"` + AuthorizationEndpoint *JSONURL `json:"authorization_endpoint"` + JwksUri *JSONURL `json:"jwks_uri"` Issuer string `json:"issuer"` - JwksUri string `json:"jwks_uri"` GrantTypes []string `json:"grant_types_supported"` IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"` ResponseModes []string `json:"response_modes_supported"` @@ -59,258 +58,311 @@ type IdPConfig struct { SubjectTypes []string `json:"subject_types_supported"` } -// TODO: Optimization to fetch only if expired (per http headers) -func FetchIdPConfig(h CautiousHTTPClient, idp_url string) (*IdPConfig, error) { - u, err := url.Parse(idp_url) - if err != nil { - return nil, err - } +func FetchIdPConfig(h CautiousHTTPClient, u *url.URL) (*IdPConfig, error) { + u = URLMustParse(u.String()) u.Path = "/.well-known/openid-configuration" var idpc IdPConfig - err = h.GetJSON(u.String(), &idpc) + err := h.GetJSON(u.String(), &idpc) if err != nil { - return nil, err + return nil, errors.WithStack(err) } return &idpc, nil } -// TODO: Optimization to fetch only if expired (per http headers) -func FetchJWKS(h CautiousHTTPClient, jwks_url string, val KeyValidator) (map[string]jose.JSONWebKey, error) { - var jwks jose.JSONWebKeySet - err := h.GetJSON(jwks_url, &jwks) - if err != nil { - return nil, err - } - - keys := make(map[string]jose.JSONWebKey, len(jwks.Keys)) - - for _, k := range jwks.Keys { - err = val.Validate(k) - if err == nil { - keys[k.KeyID] = k - } - } - - return keys, nil -} - func GenerateNonce() (string, error) { nonce := make([]byte, NONCE_SIZE) n, err := rand.Read(nonce) if n != NONCE_SIZE || err != nil { - return "", err + return "", errors.WithStack(err) } return hex.EncodeToString(nonce), nil } -// TODO -// Cookie rules -// Secure -// HttpOnly -// Path to / -// Expires to iat in JWT -func SetCookie() { +func SetSecureCookie(w http.ResponseWriter, name string, value string, exp time.Duration) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + Expires: time.Now().Add(exp), + HttpOnly: true, + Secure: true, + Path: "/", + }) } -// TODO -func MakeClientID(r *http.Request) string { - if strings.Contains(r.Host, ":") { - return r.Host - } - return "" +func ExpireCookie(w http.ResponseWriter, name string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + Expires: time.Now().Add(-1 * time.Hour), + HttpOnly: true, + Secure: true, + Path: "/", + MaxAge: 0, + }) } -// TODO -func RedirectToIDP(w http.ResponseWriter, r *http.Request) { - nonce, _ := GenerateNonce() - _ = nonce - nonceh := "" // SHA256 nonce +func RedirectToIdP(w http.ResponseWriter, r *http.Request, path string) { + ctx := r.Context().Value("ProxyConfig").(*ProxyConfig) - // Set nonce cookie + nonce, err := GenerateNonce() + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + SetSecureCookie(w, RFP_COOKIE_NAME, nonce, DEFAULT_COOKIE_EXP) + + rt := "" + rp := r.URL.Query().Get("redirect_uri") + if rp != "" { + rt = rp + } else { + ru := &url.URL{ + Scheme: "https", + Host: r.Host, + Path: path, + } + rt = ru.String() + } req := url.Values{} - req.Add("client_id", "") // fqdn + : + port - req.Add("nonce", nonceh) - req.Add("redirect_uri", "") // Requested URL + req.Add("client_id", ctx.ClientId.String()) + req.Add("nonce", Sha256Hex(nonce)) + req.Add("redirect_uri", rt) req.Add("scope", "openid") req.Add("response_type", "id_token") + + u := URLMustParse(ctx.IdProviderAuthEndpoint.String()) + u.RawQuery = req.Encode() + + http.Redirect(w, r, u.String(), http.StatusFound) } -// TODO: Remove id_token from URL, set cookie and redirect user to requested URL func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token string) { -} + SetSecureCookie(w, TOKEN_COOKIE_NAME, token, DEFAULT_COOKIE_EXP) -// TODO -func ValidateJWT(jwt, rfp string) bool { - return true -} + q := r.URL.Query() + q.Del("id_token") + + u := URLMustParse(r.URL.String()) + u.RawQuery = q.Encode() -// TODO -func GetJWTSubject(jwt string) string { - return "" + http.Redirect(w, r, u.String(), http.StatusFound) } func RequestHasForwardedUser(w http.ResponseWriter, r *http.Request) bool { - if _, ok := r.Header["X-Forwarded-User"]; ok { - log.Printf("ERROR: Request contains X-Forwarded-For header") - http.Error(w, "Bad Request", http.StatusBadRequest) - return true - } else { - return false - } + _, ok := r.Header["X-Forwarded-User"] + return ok } func RequestIsOverSecureChannel(w http.ResponseWriter, r *http.Request) bool { https, ok := r.Header["X-Forwarded-Proto"] if !ok || len(https) != 1 { - log.Printf("ERROR: Request does not contain X-Forwarded-Proto header") - http.Error(w, "Bad Request", http.StatusBadRequest) + glog.V(1).Infoln("Request does not contain X-Forwarded-Proto header") return false } if !CompareUpper(https[0], "HTTPS") { - log.Printf("ERROR: Request is not over HTTPS") - http.Error(w, "Bad Request", http.StatusBadRequest) + glog.V(1).Infoln("Request is not over HTTPS") return false } return true } -// TODO -// - Validate Hostname header == known hostname func AuthProxyController(w http.ResponseWriter, r *http.Request) { - proxy := r.Context().Value("ProxyConfig").(*ProxyConfig).reverseProxy + ctx := r.Context().Value("ProxyConfig").(*ProxyConfig) // Order matters in these checks! if RequestHasForwardedUser(w, r) { + glog.Errorln("Request already has X-Forwarded-User") + http.Error(w, "Bad Request", http.StatusBadRequest) return } if !RequestIsOverSecureChannel(w, r) { + http.Error(w, "Bad Request", http.StatusBadRequest) return } if CompareUpper(r.Method, "OPTIONS") { - proxy.ServeHTTP(w, r) + ctx.reverseProxy.ServeHTTP(w, r) return } rfpc, err := r.Cookie(RFP_COOKIE_NAME) if err != nil { - log.Printf("ERROR: No rfp cookie") - RedirectToIDP(w, r) - return + glog.V(1).Infoln("No rfp cookie") + if ctx.IsOptional { + ctx.reverseProxy.ServeHTTP(w, r) + return + } else { + RedirectToIdP(w, r, r.URL.Path) + return + } } - token := r.URL.Query().Get("id_token") - if token != "" && ValidateJWT(rfpc.Value, token) { - SetTokenCookieAndRedirect(w, r, token) - return + if token := r.URL.Query().Get("id_token"); token != "" { + if _, err := ctx.jwsValidator.Validate(token, rfpc.Value); err == nil { + SetTokenCookieAndRedirect(w, r, token) + return + } else { + glog.V(1).Infof("Querystring id_token invalid: %s", err) + } } tokenc, err := r.Cookie(TOKEN_COOKIE_NAME) if err != nil { - log.Printf("ERROR: No token cookie") - RedirectToIDP(w, r) - return + glog.V(1).Infoln("No token cookie") + if ctx.IsOptional { + ctx.reverseProxy.ServeHTTP(w, r) + return + } else { + RedirectToIdP(w, r, r.URL.Path) + return + } } - if !ValidateJWT(tokenc.Value, rfpc.Value) { - log.Printf("ERROR: Token is invalid") - RedirectToIDP(w, r) - return + claims, err := ctx.jwsValidator.Validate(tokenc.Value, rfpc.Value) + if err != nil { + glog.Errorln("Token is invalid", err) + if ctx.IsOptional { + ctx.reverseProxy.ServeHTTP(w, r) + return + } else { + RedirectToIdP(w, r, r.URL.Path) + return + } } - r.Header["X-Forwarded-User"] = []string{GetJWTSubject(tokenc.Value)} + r.Header["X-Forwarded-User"] = []string{claims.Subject} + r.Header["X-Forwarded-Token-Expires"] = []string{strconv.FormatInt(int64(claims.Expiry), 10)} + + age := time.Since(claims.IssuedAt.Time()).Minutes() + r.Header["X-Forwarded-Token-Age"] = []string{strconv.FormatInt(int64(age), 10)} - proxy.ServeHTTP(w, r) + ctx.reverseProxy.ServeHTTP(w, r) } -// Remove token and rfp cookies and redirect user to root of domain func LogoutController(w http.ResponseWriter, r *http.Request) { - http.SetCookie(w, &http.Cookie{ - Name: TOKEN_COOKIE_NAME, - Value: "", - MaxAge: 0, - }) - - http.SetCookie(w, &http.Cookie{ - Name: TOKEN_COOKIE_NAME, - Value: "", - MaxAge: 0, - }) - + ExpireCookie(w, RFP_COOKIE_NAME) + ExpireCookie(w, TOKEN_COOKIE_NAME) http.Redirect(w, r, "/", http.StatusFound) } -// TODO -func LoginController(w http.ResponseWriter, r *http.Request) { -} - -// TODO // Optional login allows for applications that can operate in anonymous mode or // authenticated mode. When in anonmyous mode the request is proxied through // without an X-Forwarded-User header. Upstream servers should either expose or // map a URL for /.oidc/login to allow users to login. On successful login the // user will be redirected back to the main page for the site (/) -func parseConfig() *ProxyConfig { - return &ProxyConfig{ - IDProviderURL: "http://mcrute-virt:9993", - ClientID: "test.crute.me:443", - UpstreamURL: "http://localhost:9991/", - ListenOn: ":9992", - TrustedCACert: "/home/mcrute/oidc_project/test_ca/ca_cert.pem", - IsOptional: false, - PKISubject: "Crute OpenID Signing 1", - MaxLiftetime: 24 * time.Hour, - ClockSkew: 5 * time.Minute, +func parseConfig() (*ProxyConfig, error) { + c := &ProxyConfig{} + + idpu := flag.String("idp", "", "URL for ID provider") + mfam := flag.String("allow-mfa-methods", "", "Comma seperated list of allowed mfa methods") + rmfa := flag.String("require-mfa-methods", "", "Comma seperated list of required mfa methods") + cids := flag.String("client-id", "", "Client ID for proxy with IdP") + + flag.BoolVar(&c.IsOptional, "optional", false, "Allow proxying of unauthenticated calls") + flag.BoolVar(&c.IsBootstrap, "bootstrap", false, "Allow running a proxy for the IdP itself") + flag.BoolVar(&c.RequestMFA, "mfa", false, "Request user MFA authentication from IdP") + + flag.DurationVar(&c.MaxLiftetime, "max-lifetime", DEFAULT_MAX_LIFETIME, "Maximum allowed time from token issuance") + flag.DurationVar(&c.ClockSkew, "clock-skew", DEFAULT_CLOCK_SKEW, "Allowable IdP clock skew relative to proxy") + + flag.StringVar(&c.UpstreamURL, "upstream", "", "URL of upstream service for which to proxy") + flag.StringVar(&c.ListenOn, "listen", ":9992", "Optional port and ip on which to listen") + flag.StringVar(&c.TrustedCACert, "ca", "", "Path to trusted CA certificate") + + flag.Parse() + + c.AllowedMFAMethods = strings.Split(*mfam, ",") + c.RequiredMFAMethods = strings.Split(*rmfa, ",") + + if c.IsBootstrap { + c.IsOptional = true + } + + if _, err := os.Stat(c.TrustedCACert); os.IsNotExist(err) { + return nil, errors.Errorf("CA certificate does not exist") + } + + if cids == nil { + return nil, errors.Errorf("Client ID is required") + } + + if client_id, err := url.Parse(*cids); err != nil || client_id.Host == "" { + return nil, errors.Errorf("Invalid client ID") + } else { + c.ClientId = client_id + } + + if c.UpstreamURL == "" { + return nil, errors.Errorf("Upstream URL is required") + } + + if idpu == nil { + return nil, errors.Errorf("IDP url is required") } + + if u, err := url.Parse(*idpu); err != nil { + return nil, errors.WithStack(err) + } else { + c.IdProviderURL = u + + if h := HostFromURL(u.String()); c.IsBootstrap && (h != "localhost" && h != "127.0.0.1") { + return nil, errors.Errorf("IdP must be set to localhost for bootstrap") + } + } + + return c, nil } func main() { - cfg := parseConfig() - h := NewCautiousHTTPClient() - - v := NewKeyValidator(cfg.PKISubject) - v.LoadRootPEM(cfg.TrustedCACert) + cfg, err := parseConfig() + if err != nil { + glog.Fatalln("ParseConfig", err) + return + } - idpc, err := FetchIdPConfig(h, cfg.IDProviderURL) + hidp, err := NewCautiousHTTPClient(cfg.IsBootstrap) if err != nil { - fmt.Printf("%s\n", err) + glog.Fatalln("Error building http client", err) return } - jwks, err := FetchJWKS(h, idpc.JwksUri, v) + idpc, err := FetchIdPConfig(hidp, cfg.IdProviderURL) if err != nil { - fmt.Printf("%s\n", err) + glog.Fatalln("FetchIdPConfig:", err) return } - jv := NewJWSValidator(jwks, idpc.Issuer, cfg.ClientID, cfg.ClockSkew, cfg.MaxLiftetime) + cfg.IdProviderAuthEndpoint = idpc.AuthorizationEndpoint.AsURL() - nonce := "ofspmfjuvoswhhde" - raw_jwt := "eyJ0eXAiOiJKV1MiLCJhbGciOiJQUzI1NiIsImtpZCI6IjEifQ.eyJub25jZSI6IjM0MjlhMjAyYzU4ZDkyYjQwNjNjOWM4MWM2MjQyNGRlNzBkMmIzZDQ4MmVlNDFhOTdjYmNhZjEwZDk5MWFiOTMiLCJpc3MiOiJpZHAuY3J1dGUubWU6NDQzIiwiaWF0IjoxNTA0NTc2Mzc0LCJuYmYiOjE1MDQ1NzYzNzQsImV4cCI6MTUwNDY2Mjc3NCwic3ViIjoibWNydXRlIiwiYXVkIjoidGVzdC5jcnV0ZS5tZTo0NDMifQ.iizlNfY1Vg7d-XRmgyYuhpNkNrOGaT9OOgO0HdjBozOWMvKzBTtATbIfoWOrNH6DiFY1as8uy3I1Pxnkrb8Ti8_cLDQeLxOv9klAbnebeuPI_wtZ0iwSUnSWaYzN6I6sqcEjHX3fibFvAQhO5dNDzSwONjw4AvcdpZKh579FO1sAvIw-1DmMyPSUun7rbC0Kf1Jtdlr3q7tOp3wdI_erkstxCNPwyuv7X1J7uetsu0BeJS25C2DxeB03BPEIUoo_C1xvcqikfSLLpoFcyToYiS-R9o-WpRjGid_yug65J5ALn2aM3vhe9rRbydKVm_omGL8-Etj06zbqM0Y6OrJUgA" - claims, err := jv.Validate(raw_jwt, nonce) + h, err := NewCautiousHTTPClient(false) if err != nil { - fmt.Printf("Error validating: %s\n", err) + glog.Fatalln("Error building http client", err) return } - fmt.Printf("Valid JWT for: %+v\n", claims.Subject) + kf := NewJWKSFetcher(h, idpc.JwksUri.AsURL(), idpc.Issuer, cfg.TrustedCACert) - return + cfg.jwsValidator = NewJWSValidator(&JWSValidationContext{ + KeyFetcher: kf, + Issuer: idpc.Issuer, + ClientId: cfg.ClientId, + ClockSkew: cfg.ClockSkew, + MaxLiftetime: cfg.MaxLiftetime, + }) cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL)) - if cfg.IsOptional { - http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) { - LoginController(w, - r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) - }) - } + http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) { + RedirectToIdP(w, + r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg)), "/") + }) http.HandleFunc("/.oidc/logout", func(w http.ResponseWriter, r *http.Request) { LogoutController(w, @@ -322,5 +374,14 @@ func main() { r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) }) - log.Fatal(http.ListenAndServe(cfg.ListenOn, nil)) + go http.ListenAndServe(cfg.ListenOn, + handlers.LoggingHandler(os.Stdout, http.DefaultServeMux)) + + // This has to happen last in-case we're boostrapping a proxy for the IdP itself + if err := kf.Fetch(); err != nil { + glog.Fatalln("FetchJWKS:", err) + return + } else { + kf.Run() + } } diff --git a/oidc_proxy b/oidc_proxy index e5df267..cc3c7be 100755 Binary files a/oidc_proxy and b/oidc_proxy differ diff --git a/util.go b/util.go index 10709e2..7385dfd 100644 --- a/util.go +++ b/util.go @@ -1,6 +1,8 @@ package main import ( + "crypto/sha256" + "encoding/hex" "net/url" "strings" ) @@ -41,3 +43,19 @@ func URLMustParse(u string) *url.URL { func CompareUpper(lhs, rhs string) bool { return strings.ToUpper(lhs) == strings.ToUpper(rhs) } + +func HostFromURL(u string) string { + o, err := url.Parse(u) + if err != nil { + return "" + } + + h := strings.Split(o.Host, ":") + return h[0] +} + +func Sha256Hex(v string) string { + s256 := sha256.New() + s256.Write([]byte(v)) + return hex.EncodeToString(s256.Sum(nil)) +} -- cgit v1.2.3