summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2017-09-19 04:39:36 +0000
committerMike Crute <mike@crute.us>2017-09-19 04:39:36 +0000
commit9f7861ffe1397da514606b189f5b3e383f4e7ed7 (patch)
tree2bd145745efba52ac136166e4f4535cfd59359ea
parentb7867d9cf5b0dd175b8167a552b830ebfe47d0ed (diff)
downloadoidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.tar.bz2
oidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.tar.xz
oidc_proxy-9f7861ffe1397da514606b189f5b3e383f4e7ed7.zip
Finish out most of the proxy functionality
-rw-r--r--cautious_http_client.go87
-rw-r--r--jwks_fetcher.go118
-rw-r--r--jws_validator.go110
-rw-r--r--key_validator.go36
-rw-r--r--main.go409
-rwxr-xr-xoidc_proxybin6495319 -> 7509902 bytes
-rw-r--r--util.go18
7 files changed, 534 insertions, 244 deletions
diff --git a/cautious_http_client.go b/cautious_http_client.go
index 2f33ae0..34b736f 100644
--- a/cautious_http_client.go
+++ b/cautious_http_client.go
@@ -2,24 +2,29 @@ package main
2 2
3import ( 3import (
4 "encoding/json" 4 "encoding/json"
5 "fmt" 5 "github.com/lox/httpcache"
6 "github.com/pkg/errors"
6 "net" 7 "net"
7 "net/http" 8 "net/http"
8 "net/url" 9 "net/url"
10 "strings"
9 "time" 11 "time"
10) 12)
11 13
12type CautiousHTTPClient interface { 14type CautiousHTTPClient interface {
13 Get(string) (*http.Response, error) 15 Get(string) (*http.Response, error)
14 GetJSON(string, interface{}) error 16 GetJSON(string, interface{}) error
17 GetJSONExpires(string, interface{}) (time.Duration, error)
15} 18}
16 19
17type cautiousHttpClient struct { 20type cautiousHttpClient struct {
18 client *http.Client 21 allowHttp bool
22 client *http.Client
19} 23}
20 24
21func NewCautiousHTTPClient() CautiousHTTPClient { 25// allowHttp is UNSAFE and technically validates the spec but it does make it
22 // May Need: TLSClientConfig *tls.Config 26// easier to work in dev so leaving it in for now
27func NewCautiousHTTPClient(allowHttp bool) (CautiousHTTPClient, error) {
23 CautiousTransport := &http.Transport{ 28 CautiousTransport := &http.Transport{
24 Proxy: http.ProxyFromEnvironment, 29 Proxy: http.ProxyFromEnvironment,
25 DialContext: (&net.Dialer{ 30 DialContext: (&net.Dialer{
@@ -36,44 +41,100 @@ func NewCautiousHTTPClient() CautiousHTTPClient {
36 } 41 }
37 42
38 return &cautiousHttpClient{ 43 return &cautiousHttpClient{
44 allowHttp: allowHttp,
39 client: &http.Client{ 45 client: &http.Client{
40 Transport: CautiousTransport, 46 Transport: CautiousTransport,
41 Timeout: 30 * time.Second, 47 Timeout: 30 * time.Second,
42 }, 48 },
43 } 49 }, nil
44} 50}
45 51
46func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) { 52func (c *cautiousHttpClient) Get(gurl string) (*http.Response, error) {
47 u, err := url.Parse(gurl) 53 u, err := url.Parse(gurl)
48 if err != nil { 54 if err != nil {
49 return nil, err 55 return nil, errors.WithStack(err)
50 } 56 }
51 57
52 // TODO 58 if u.Scheme != "https" && !c.allowHttp {
53 if u.Scheme != "https" && false { 59 return nil, errors.Errorf("URL for GET must be secure")
54 return nil, fmt.Errorf("URL for GET must be secure")
55 } 60 }
56 61
57 r, err := c.client.Get(u.String()) 62 r, err := c.client.Get(u.String())
58 if err != nil { 63 if err != nil {
59 return nil, err 64 return nil, errors.WithStack(err)
60 } 65 }
61 r.Body = http.MaxBytesReader(nil, r.Body, 1000000) 66 r.Body = http.MaxBytesReader(nil, r.Body, 1000000)
62 return r, err 67
68 return r, nil
63} 69}
64 70
65func (c *cautiousHttpClient) GetJSON(url string, rv interface{}) error { 71func (c *cautiousHttpClient) GetJSON(url string, rv interface{}) error {
66 r, err := c.Get(url) 72 r, err := c.Get(url)
67 if err != nil { 73 if err != nil {
68 return err 74 return errors.WithStack(err)
69 } 75 }
70 defer r.Body.Close() 76 defer r.Body.Close()
71 77
72 d := json.NewDecoder(r.Body) 78 d := json.NewDecoder(r.Body)
73 err = d.Decode(rv) 79 err = d.Decode(rv)
74 if err != nil { 80 if err != nil {
75 return err 81 return errors.WithStack(err)
76 } 82 }
77 83
78 return nil 84 return nil
79} 85}
86
87func (c *cautiousHttpClient) GetJSONExpires(url string, rv interface{}) (time.Duration, error) {
88 r, err := c.Get(url)
89 if err != nil {
90 return time.Duration(0), errors.WithStack(err)
91 }
92 defer r.Body.Close()
93
94 res := httpcache.NewResource(r.StatusCode, nil, r.Header)
95
96 d := json.NewDecoder(r.Body)
97 err = d.Decode(rv)
98 if err != nil {
99 return time.Duration(0), errors.WithStack(err)
100 }
101
102 return refreshAfter(res), nil
103}
104
105type JSONURL struct {
106 *url.URL
107}
108
109func (u *JSONURL) AsURL() *url.URL {
110 return u.URL
111}
112
113func (u *JSONURL) UnmarshalJSON(data []byte) error {
114 d := strings.Trim(string(data), "\"")
115 pu, err := url.Parse(d)
116 if err != nil {
117 return errors.WithStack(err)
118 }
119
120 u.URL = pu
121 return nil
122}
123
124func refreshAfter(res *httpcache.Resource) time.Duration {
125 maxAge, err := res.MaxAge(false)
126 if err != nil {
127 return time.Duration(0)
128 }
129
130 age, err := res.Age()
131 if err != nil {
132 return time.Duration(0)
133 }
134
135 if hFresh := res.HeuristicFreshness(); hFresh > maxAge {
136 maxAge = hFresh
137 }
138
139 return maxAge - age
140}
diff --git a/jwks_fetcher.go b/jwks_fetcher.go
new file mode 100644
index 0000000..9925430
--- /dev/null
+++ b/jwks_fetcher.go
@@ -0,0 +1,118 @@
1package main
2
3import (
4 "github.com/pkg/errors"
5 "gopkg.in/square/go-jose.v2"
6 "log"
7 "net/url"
8 "time"
9)
10
11const (
12 REQUEST_BUFFER_SIZE = 10
13 KEY_MAP_INITIAL_SIZE = 5
14 DEFAULT_REFRESH_INTERVAL = 15 * time.Minute
15 MIN_REFRESH_INTERVAL = 1 * time.Minute
16)
17
18type KeyRequest struct {
19 KeyId string
20 Response chan *jose.JSONWebKey
21}
22
23type JWKSFetcher interface {
24 Run()
25 Fetch() error
26 GetKey(string) (*jose.JSONWebKey, error)
27 Done()
28}
29
30type jwksFetcher struct {
31 keyMap map[string]jose.JSONWebKey
32 httpClient CautiousHTTPClient
33 validator KeyValidator
34 fetchTimer *time.Timer
35 url *url.URL
36 requests chan *KeyRequest
37 done chan bool
38}
39
40func NewJWKSFetcher(h CautiousHTTPClient, url *url.URL, issuer string, root string) JWKSFetcher {
41 val := NewKeyValidator(HostFromURL(issuer))
42 val.LoadRootPEM(root)
43
44 return &jwksFetcher{
45 httpClient: h,
46 validator: val,
47 url: url,
48 fetchTimer: time.NewTimer(DEFAULT_REFRESH_INTERVAL),
49 requests: make(chan *KeyRequest, REQUEST_BUFFER_SIZE),
50 keyMap: make(map[string]jose.JSONWebKey, KEY_MAP_INITIAL_SIZE),
51 done: make(chan bool),
52 }
53}
54
55func (f *jwksFetcher) Fetch() error {
56 var jwks jose.JSONWebKeySet
57 timeout, err := f.httpClient.GetJSONExpires(f.url.String(), &jwks)
58 if err != nil {
59 return errors.WithStack(err)
60 }
61
62 for _, k := range jwks.Keys {
63 err = f.validator.Validate(k)
64 if err == nil {
65 f.keyMap[k.KeyID] = k
66 } else {
67 log.Printf("Rejecting key %q because %q", k.KeyID, err)
68 }
69 }
70
71 if timeout < MIN_REFRESH_INTERVAL {
72 timeout = MIN_REFRESH_INTERVAL
73 }
74
75 success := f.fetchTimer.Reset(timeout)
76 if !success {
77 f.fetchTimer = time.NewTimer(timeout)
78 }
79
80 return nil
81}
82
83func (f *jwksFetcher) Run() {
84 for {
85 select {
86 // Incoming request for a key, return key or nil in no key
87 case r := <-f.requests:
88 if v, ok := f.keyMap[r.KeyId]; ok {
89 r.Response <- &v
90 } else {
91 r.Response <- nil
92 }
93 case <-f.fetchTimer.C:
94 f.Fetch()
95 case <-f.done:
96 return
97 }
98 }
99}
100
101func (f *jwksFetcher) Done() {
102 f.done <- true
103}
104
105func (f *jwksFetcher) GetKey(kid string) (*jose.JSONWebKey, error) {
106 r := &KeyRequest{
107 KeyId: kid,
108 Response: make(chan *jose.JSONWebKey),
109 }
110
111 f.requests <- r
112
113 if res := <-r.Response; res == nil {
114 return nil, errors.Errorf("Key not found for ID")
115 } else {
116 return res, nil
117 }
118}
diff --git a/jws_validator.go b/jws_validator.go
index e77c026..0b2467f 100644
--- a/jws_validator.go
+++ b/jws_validator.go
@@ -1,103 +1,133 @@
1package main 1package main
2 2
3import ( 3import (
4 "crypto/sha256" 4 "github.com/pkg/errors"
5 "encoding/hex"
6 "fmt"
7 "gopkg.in/square/go-jose.v2" 5 "gopkg.in/square/go-jose.v2"
8 "gopkg.in/square/go-jose.v2/jwt" 6 "gopkg.in/square/go-jose.v2/jwt"
7 "net/url"
9 "time" 8 "time"
10) 9)
11 10
11// TODO
12// validate amr claim contains requested acr values (selective_mfa will be just mfa)
13// validate acr claim is the same as requested acr_values
14//
15// acr_values can be mfa or selective_mfa (mfa only for external users)
16// mfa amr values:
17// pas - password
18// otp - OTP code
19// u2f - U2F code
20// mfa - multi-factor
21// hrd - hardware OTP device used
22// sft - software OTP device used
23
12type Claims struct { 24type Claims struct {
13 Nonce string `json:"nonce,omitempty"` 25 Nonce string `json:"nonce,omitempty"`
14 jwt.Claims 26 jwt.Claims
15} 27}
16 28
29type JWSValidationContext struct {
30 KeyFetcher JWKSFetcher
31 Issuer string
32 ClientId *url.URL
33 ClockSkew time.Duration
34 MaxLiftetime time.Duration
35}
36
17type JWSValidator interface { 37type JWSValidator interface {
18 Validate(string, string) (*Claims, error) 38 Validate(string, string) (*Claims, error)
19} 39}
20 40
21type jwsValidator struct { 41type jwsValidator struct {
22 algorithms *stringSet 42 algorithms *stringSet
23 jwks map[string]jose.JSONWebKey 43 jwks JWKSFetcher
24 issuer string 44 issuer string
25 clientID string 45 clientId *url.URL
26 clockSkew time.Duration 46 clockSkew time.Duration
27 maxLifetime time.Duration 47 maxLifetime time.Duration
28} 48}
29 49
30// TODO 50func NewJWSValidator(c *JWSValidationContext) JWSValidator {
31// validate amr claim contains requested acr values (selective_mfa will be just mfa)
32// validate acr claim is the same as requested acr_values
33func NewJWSValidator(jwks map[string]jose.JSONWebKey, issuer string, client_id string, skew time.Duration, max_life time.Duration) JWSValidator {
34 return &jwsValidator{ 51 return &jwsValidator{
35 algorithms: NewStringSet("PS256", "PS385", "PS512"), 52 algorithms: NewStringSet("PS256", "PS385", "PS512"),
36 jwks: jwks, 53 jwks: c.KeyFetcher,
37 issuer: issuer, 54 issuer: c.Issuer,
38 clientID: client_id, 55 clientId: c.ClientId,
39 clockSkew: skew, 56 clockSkew: c.ClockSkew,
40 maxLifetime: max_life, 57 maxLifetime: c.MaxLiftetime,
41 } 58 }
42} 59}
43 60
44func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) { 61func (v *jwsValidator) Validate(j string, nonce string) (*Claims, error) {
45 parsed_jwt, err := jwt.ParseSigned(j) 62 parsed, err := jwt.ParseSigned(j)
63 if err != nil {
64 return nil, errors.WithStack(err)
65 }
66
67 if err := v.validateHeaders(parsed.Headers); err != nil {
68 return nil, errors.WithStack(err)
69 }
70
71 kid := parsed.Headers[0].KeyID
72 key, err := v.jwks.GetKey(kid)
73 if err != nil {
74 return nil, errors.WithStack(err)
75 }
76
77 claims, err := v.validateClaims(parsed, key)
46 if err != nil { 78 if err != nil {
47 return nil, err 79 return nil, errors.WithStack(err)
48 } 80 }
49 81
50 if len(parsed_jwt.Headers) != 1 { 82 if err := v.validateNonce(nonce, claims.Nonce); err != nil {
51 return nil, fmt.Errorf("Invalid signature count") 83 return nil, errors.WithStack(err)
52 } 84 }
53 85
54 head := parsed_jwt.Headers[0] 86 return claims, nil
87}
55 88
56 if !v.algorithms.Contains(head.Algorithm) { 89func (v *jwsValidator) validateHeaders(h []jose.Header) error {
57 return nil, fmt.Errorf("Invalid signature algorithm") 90 if len(h) != 1 {
91 return errors.Errorf("Invalid signature count")
58 } 92 }
59 93
60 if typ, ok := head.ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" { 94 if !v.algorithms.Contains(h[0].Algorithm) {
61 return nil, fmt.Errorf("Invalid token type") 95 return errors.Errorf("Invalid signature algorithm")
62 } 96 }
63 97
64 key, ok := v.jwks[head.KeyID] 98 if typ, ok := h[0].ExtraHeaders[jose.HeaderType]; !ok || typ != "JWS" {
65 if !ok { 99 return errors.Errorf("Invalid token type")
66 return nil, fmt.Errorf("No key found for key id")
67 } 100 }
68 101
102 return nil
103}
104
105func (v *jwsValidator) validateClaims(j *jwt.JSONWebToken, k *jose.JSONWebKey) (*Claims, error) {
69 claims := &Claims{} 106 claims := &Claims{}
70 if err = parsed_jwt.Claims(key, claims); err != nil { 107 if err := j.Claims(k, claims); err != nil {
71 return nil, err 108 return nil, errors.WithStack(err)
72 } 109 }
73 110
74 exp := jwt.Expected{ 111 exp := jwt.Expected{
75 Issuer: v.issuer, 112 Issuer: v.issuer,
76 Audience: jwt.Audience{v.clientID}, 113 Audience: jwt.Audience{v.clientId.String()},
77 Time: time.Now(), 114 Time: time.Now(),
78 } 115 }
79 116
80 if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil { 117 if err := claims.ValidateWithLeeway(exp, v.clockSkew); err != nil {
81 return nil, err 118 return nil, errors.WithStack(err)
82 } 119 }
83 120
84 if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) { 121 if claims.IssuedAt.Time().Add(v.maxLifetime).Before(time.Now()) {
85 return nil, fmt.Errorf("Token exceeded max lifetime") 122 return nil, errors.Errorf("Token exceeded max lifetime")
86 }
87
88 if err = v.validateNonce(nonce, claims.Nonce); err != nil {
89 return nil, err
90 } 123 }
91 124
92 return claims, nil 125 return claims, nil
93} 126}
94 127
95func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error { 128func (v *jwsValidator) validateNonce(nonce string, token_nonce string) error {
96 s256 := sha256.New() 129 if token_nonce != Sha256Hex(nonce) {
97 s256.Write([]byte(nonce)) 130 return errors.Errorf("Invalid nonce: %s = %q vs %q", nonce, token_nonce, Sha256Hex(nonce))
98 hashed_nonce := hex.EncodeToString(s256.Sum(nil))
99 if token_nonce != hashed_nonce {
100 return fmt.Errorf("Invalid nonce")
101 } 131 }
102 132
103 return nil 133 return nil
diff --git a/key_validator.go b/key_validator.go
index fe6eb7b..062d78c 100644
--- a/key_validator.go
+++ b/key_validator.go
@@ -4,11 +4,13 @@ import (
4 "crypto/rsa" 4 "crypto/rsa"
5 "crypto/x509" 5 "crypto/x509"
6 "encoding/pem" 6 "encoding/pem"
7 "fmt" 7 "github.com/pkg/errors"
8 "gopkg.in/square/go-jose.v2" 8 "gopkg.in/square/go-jose.v2"
9 "io/ioutil" 9 "io/ioutil"
10) 10)
11 11
12// TODO: CRL validation
13
12type KeyValidator interface { 14type KeyValidator interface {
13 Validate(jose.JSONWebKey) error 15 Validate(jose.JSONWebKey) error
14 LoadRootPEM(string) error 16 LoadRootPEM(string) error
@@ -31,17 +33,17 @@ func NewKeyValidator(subject string) KeyValidator {
31func (v *keyValidator) LoadRootPEM(filename string) error { 33func (v *keyValidator) LoadRootPEM(filename string) error {
32 pem_data, err := ioutil.ReadFile(filename) 34 pem_data, err := ioutil.ReadFile(filename)
33 if err != nil { 35 if err != nil {
34 return err 36 return errors.WithStack(err)
35 } 37 }
36 38
37 pem_block, _ := pem.Decode(pem_data) 39 pem_block, _ := pem.Decode(pem_data)
38 if pem_block == nil { 40 if pem_block == nil {
39 return fmt.Errorf("PEM decode failed") 41 return errors.Errorf("PEM decode failed")
40 } 42 }
41 43
42 cert, err := x509.ParseCertificate(pem_block.Bytes) 44 cert, err := x509.ParseCertificate(pem_block.Bytes)
43 if err != nil { 45 if err != nil {
44 return err 46 return errors.WithStack(err)
45 } 47 }
46 48
47 v.roots.AddCert(cert) 49 v.roots.AddCert(cert)
@@ -52,40 +54,40 @@ func (v *keyValidator) LoadRootPEM(filename string) error {
52func (v *keyValidator) Validate(key jose.JSONWebKey) error { 54func (v *keyValidator) Validate(key jose.JSONWebKey) error {
53 pk, ok := key.Key.(*rsa.PublicKey) 55 pk, ok := key.Key.(*rsa.PublicKey)
54 if !ok { 56 if !ok {
55 return fmt.Errorf("Key type is not RSA") 57 return errors.Errorf("Key type is not RSA")
56 } 58 }
57 59
58 if !v.algorithms.Contains(key.Algorithm) { 60 if !v.algorithms.Contains(key.Algorithm) {
59 return fmt.Errorf("Key algorithm is not supported") 61 return errors.Errorf("Key algorithm is not supported")
60 } 62 }
61 63
62 cert := key.Certificates[0] 64 cert := key.Certificates[0]
63 cpk, ok := cert.PublicKey.(*rsa.PublicKey) 65 cpk, ok := cert.PublicKey.(*rsa.PublicKey)
64 if !ok { 66 if !ok {
65 return fmt.Errorf("Public key is not RSA") 67 return errors.Errorf("Public key is not RSA")
66 } 68 }
67 69
68 if cpk.N.BitLen() < 2048 { 70 if cpk.N.BitLen() < 2048 {
69 return fmt.Errorf("Key length less than 2048 bits") 71 return errors.Errorf("Key length less than 2048 bits")
70 } 72 }
71 73
72 if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 { 74 if cert.KeyUsage&x509.KeyUsageDigitalSignature != 1 {
73 return fmt.Errorf("Certificate not valid for digital signatures") 75 return errors.Errorf("Certificate not valid for digital signatures")
74 } 76 }
75 77
76 err := v.validateCertificateChain(key.Certificates) 78 err := v.validateCertificateChain(key.Certificates)
77 if err != nil { 79 if err != nil {
78 return err 80 return errors.WithStack(err)
79 } 81 }
80 82
81 err = v.validateCertificateCRL(cert) 83 err = v.validateCertificateCRL(cert)
82 if err != nil { 84 if err != nil {
83 return err 85 return errors.WithStack(err)
84 } 86 }
85 87
86 err = v.validatePublicKeyInCertificate(pk, cpk) 88 err = v.validatePublicKeyInCertificate(pk, cpk)
87 if err != nil { 89 if err != nil {
88 return err 90 return errors.WithStack(err)
89 } 91 }
90 92
91 return nil 93 return nil
@@ -116,15 +118,15 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error
116 118
117 chains, err := chain[0].Verify(vo) 119 chains, err := chain[0].Verify(vo)
118 if err != nil { 120 if err != nil {
119 return err 121 return errors.WithStack(err)
120 } 122 }
121 123
122 if len(chains) <= 0 { 124 if len(chains) <= 0 {
123 return fmt.Errorf("No valid certificate chains found") 125 return errors.Errorf("No valid certificate chains found")
124 } 126 }
125 127
126 if chain[0].Subject.CommonName != v.pkiSubject { 128 if chain[0].Subject.CommonName != v.pkiSubject {
127 return fmt.Errorf("Invalid certificate subject name") 129 return errors.Errorf("Invalid certificate subject name")
128 } 130 }
129 131
130 return nil 132 return nil
@@ -133,11 +135,11 @@ func (v *keyValidator) validateCertificateChain(chain []*x509.Certificate) error
133// validate first item of x5c matches n and e 135// validate first item of x5c matches n and e
134func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error { 136func (v *keyValidator) validatePublicKeyInCertificate(pk *rsa.PublicKey, cpk *rsa.PublicKey) error {
135 if cpk.E != pk.E { 137 if cpk.E != pk.E {
136 return fmt.Errorf("E in key and E in cert do not match") 138 return errors.Errorf("E in key and E in cert do not match")
137 } 139 }
138 140
139 if pk.N.Cmp(cpk.N) != 0 { 141 if pk.N.Cmp(cpk.N) != 0 {
140 return fmt.Errorf("N in key and N in cert do not match") 142 return errors.Errorf("N in key and N in cert do not match")
141 } 143 }
142 144
143 return nil 145 return nil
diff --git a/main.go b/main.go
index 965e72c..805c40d 100644
--- a/main.go
+++ b/main.go
@@ -4,53 +4,52 @@ import (
4 "context" 4 "context"
5 "crypto/rand" 5 "crypto/rand"
6 "encoding/hex" 6 "encoding/hex"
7 "fmt" 7 "flag"
8 "gopkg.in/square/go-jose.v2" 8 "github.com/golang/glog"
9 "log" 9 "github.com/gorilla/handlers"
10 "github.com/pkg/errors"
10 "net/http" 11 "net/http"
11 "net/http/httputil" 12 "net/http/httputil"
12 "net/url" 13 "net/url"
14 "os"
15 "strconv"
13 "strings" 16 "strings"
14 "time" 17 "time"
15) 18)
16 19
17const ( 20const (
18 NONCE_SIZE int = 16 21 NONCE_SIZE = 16
19 TOKEN_COOKIE_NAME string = "sso_token" 22 TOKEN_COOKIE_NAME = "sso_token"
20 RFP_COOKIE_NAME string = "sso_rfp" 23 RFP_COOKIE_NAME = "sso_rfp"
24 DEFAULT_CLOCK_SKEW = 5 * time.Minute
25 DEFAULT_MAX_LIFETIME = 24 * time.Hour
26 DEFAULT_COOKIE_EXP = 48 * time.Hour
21) 27)
22 28
23// TODO: Enable https checks in HTTP client 29// TODO: MFA support
24
25// acr_values can be mfa or selective_mfa (mfa only for external users)
26// mfa amr values:
27// pas - password
28// otp - OTP code
29// u2f - U2F code
30// mfa - multi-factor
31// hrd - hardware OTP device used
32// sft - software OTP device used
33 30
34type ProxyConfig struct { 31type ProxyConfig struct {
35 IDProviderURL string 32 IdProviderURL *url.URL
36 ClientID string 33 IdProviderAuthEndpoint *url.URL
37 UpstreamURL string 34 ClientId *url.URL
38 ListenOn string 35 UpstreamURL string
39 TrustedCACert string 36 ListenOn string
40 PKISubject string // TODO: Should be same as IDP w/out scheme and port 37 TrustedCACert string
41 ClockSkew time.Duration 38 ClockSkew time.Duration
42 MaxLiftetime time.Duration 39 MaxLiftetime time.Duration
43 IsOptional bool 40 IsOptional bool
44 RequestMFA bool 41 IsBootstrap bool
45 AllowedMFAMethods []string // An OR set 42 RequestMFA bool
46 RequiredMFAMethods []string // An AND set 43 AllowedMFAMethods []string // An OR set
47 reverseProxy *httputil.ReverseProxy 44 RequiredMFAMethods []string // An AND set
45 reverseProxy *httputil.ReverseProxy
46 jwsValidator JWSValidator
48} 47}
49 48
50type IdPConfig struct { 49type IdPConfig struct {
51 AuthorizationEndpoint string `json:"authorization_endpoint"` 50 AuthorizationEndpoint *JSONURL `json:"authorization_endpoint"`
51 JwksUri *JSONURL `json:"jwks_uri"`
52 Issuer string `json:"issuer"` 52 Issuer string `json:"issuer"`
53 JwksUri string `json:"jwks_uri"`
54 GrantTypes []string `json:"grant_types_supported"` 53 GrantTypes []string `json:"grant_types_supported"`
55 IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"` 54 IdTokenSigningAlgs []string `json:"id_token_signing_alg_values_supported"`
56 ResponseModes []string `json:"response_modes_supported"` 55 ResponseModes []string `json:"response_modes_supported"`
@@ -59,258 +58,311 @@ type IdPConfig struct {
59 SubjectTypes []string `json:"subject_types_supported"` 58 SubjectTypes []string `json:"subject_types_supported"`
60} 59}
61 60
62// TODO: Optimization to fetch only if expired (per http headers) 61func FetchIdPConfig(h CautiousHTTPClient, u *url.URL) (*IdPConfig, error) {
63func FetchIdPConfig(h CautiousHTTPClient, idp_url string) (*IdPConfig, error) { 62 u = URLMustParse(u.String())
64 u, err := url.Parse(idp_url)
65 if err != nil {
66 return nil, err
67 }
68 u.Path = "/.well-known/openid-configuration" 63 u.Path = "/.well-known/openid-configuration"
69 64
70 var idpc IdPConfig 65 var idpc IdPConfig
71 err = h.GetJSON(u.String(), &idpc) 66 err := h.GetJSON(u.String(), &idpc)
72 if err != nil { 67 if err != nil {
73 return nil, err 68 return nil, errors.WithStack(err)
74 } 69 }
75 70
76 return &idpc, nil 71 return &idpc, nil
77} 72}
78 73
79// TODO: Optimization to fetch only if expired (per http headers)
80func FetchJWKS(h CautiousHTTPClient, jwks_url string, val KeyValidator) (map[string]jose.JSONWebKey, error) {
81 var jwks jose.JSONWebKeySet
82 err := h.GetJSON(jwks_url, &jwks)
83 if err != nil {
84 return nil, err
85 }
86
87 keys := make(map[string]jose.JSONWebKey, len(jwks.Keys))
88
89 for _, k := range jwks.Keys {
90 err = val.Validate(k)
91 if err == nil {
92 keys[k.KeyID] = k
93 }
94 }
95
96 return keys, nil
97}
98
99func GenerateNonce() (string, error) { 74func GenerateNonce() (string, error) {
100 nonce := make([]byte, NONCE_SIZE) 75 nonce := make([]byte, NONCE_SIZE)
101 n, err := rand.Read(nonce) 76 n, err := rand.Read(nonce)
102 if n != NONCE_SIZE || err != nil { 77 if n != NONCE_SIZE || err != nil {
103 return "", err 78 return "", errors.WithStack(err)
104 } 79 }
105 return hex.EncodeToString(nonce), nil 80 return hex.EncodeToString(nonce), nil
106} 81}
107 82
108// TODO 83func SetSecureCookie(w http.ResponseWriter, name string, value string, exp time.Duration) {
109// Cookie rules 84 http.SetCookie(w, &http.Cookie{
110// Secure 85 Name: name,
111// HttpOnly 86 Value: value,
112// Path to / 87 Expires: time.Now().Add(exp),
113// Expires to iat in JWT 88 HttpOnly: true,
114func SetCookie() { 89 Secure: true,
90 Path: "/",
91 })
115} 92}
116 93
117// TODO 94func ExpireCookie(w http.ResponseWriter, name string) {
118func MakeClientID(r *http.Request) string { 95 http.SetCookie(w, &http.Cookie{
119 if strings.Contains(r.Host, ":") { 96 Name: name,
120 return r.Host 97 Value: "",
121 } 98 Expires: time.Now().Add(-1 * time.Hour),
122 return "" 99 HttpOnly: true,
100 Secure: true,
101 Path: "/",
102 MaxAge: 0,
103 })
123} 104}
124 105
125// TODO 106func RedirectToIdP(w http.ResponseWriter, r *http.Request, path string) {
126func RedirectToIDP(w http.ResponseWriter, r *http.Request) { 107 ctx := r.Context().Value("ProxyConfig").(*ProxyConfig)
127 nonce, _ := GenerateNonce()
128 _ = nonce
129 nonceh := "" // SHA256 nonce
130 108
131 // Set nonce cookie 109 nonce, err := GenerateNonce()
110 if err != nil {
111 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
112 return
113 }
114
115 SetSecureCookie(w, RFP_COOKIE_NAME, nonce, DEFAULT_COOKIE_EXP)
116
117 rt := ""
118 rp := r.URL.Query().Get("redirect_uri")
119 if rp != "" {
120 rt = rp
121 } else {
122 ru := &url.URL{
123 Scheme: "https",
124 Host: r.Host,
125 Path: path,
126 }
127 rt = ru.String()
128 }
132 129
133 req := url.Values{} 130 req := url.Values{}
134 req.Add("client_id", "") // fqdn + : + port 131 req.Add("client_id", ctx.ClientId.String())
135 req.Add("nonce", nonceh) 132 req.Add("nonce", Sha256Hex(nonce))
136 req.Add("redirect_uri", "") // Requested URL 133 req.Add("redirect_uri", rt)
137 req.Add("scope", "openid") 134 req.Add("scope", "openid")
138 req.Add("response_type", "id_token") 135 req.Add("response_type", "id_token")
136
137 u := URLMustParse(ctx.IdProviderAuthEndpoint.String())
138 u.RawQuery = req.Encode()
139
140 http.Redirect(w, r, u.String(), http.StatusFound)
139} 141}
140 142
141// TODO: Remove id_token from URL, set cookie and redirect user to requested URL
142func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token string) { 143func SetTokenCookieAndRedirect(w http.ResponseWriter, r *http.Request, token string) {
143} 144 SetSecureCookie(w, TOKEN_COOKIE_NAME, token, DEFAULT_COOKIE_EXP)
144 145
145// TODO 146 q := r.URL.Query()
146func ValidateJWT(jwt, rfp string) bool { 147 q.Del("id_token")
147 return true 148
148} 149 u := URLMustParse(r.URL.String())
150 u.RawQuery = q.Encode()
149 151
150// TODO 152 http.Redirect(w, r, u.String(), http.StatusFound)
151func GetJWTSubject(jwt string) string {
152 return ""
153} 153}
154 154
155func RequestHasForwardedUser(w http.ResponseWriter, r *http.Request) bool { 155func RequestHasForwardedUser(w http.ResponseWriter, r *http.Request) bool {
156 if _, ok := r.Header["X-Forwarded-User"]; ok { 156 _, ok := r.Header["X-Forwarded-User"]
157 log.Printf("ERROR: Request contains X-Forwarded-For header") 157 return ok
158 http.Error(w, "Bad Request", http.StatusBadRequest)
159 return true
160 } else {
161 return false
162 }
163} 158}
164 159
165func RequestIsOverSecureChannel(w http.ResponseWriter, r *http.Request) bool { 160func RequestIsOverSecureChannel(w http.ResponseWriter, r *http.Request) bool {
166 https, ok := r.Header["X-Forwarded-Proto"] 161 https, ok := r.Header["X-Forwarded-Proto"]
167 if !ok || len(https) != 1 { 162 if !ok || len(https) != 1 {
168 log.Printf("ERROR: Request does not contain X-Forwarded-Proto header") 163 glog.V(1).Infoln("Request does not contain X-Forwarded-Proto header")
169 http.Error(w, "Bad Request", http.StatusBadRequest)
170 return false 164 return false
171 } 165 }
172 166
173 if !CompareUpper(https[0], "HTTPS") { 167 if !CompareUpper(https[0], "HTTPS") {
174 log.Printf("ERROR: Request is not over HTTPS") 168 glog.V(1).Infoln("Request is not over HTTPS")
175 http.Error(w, "Bad Request", http.StatusBadRequest)
176 return false 169 return false
177 } 170 }
178 171
179 return true 172 return true
180} 173}
181 174
182// TODO
183// - Validate Hostname header == known hostname
184func AuthProxyController(w http.ResponseWriter, r *http.Request) { 175func AuthProxyController(w http.ResponseWriter, r *http.Request) {
185 proxy := r.Context().Value("ProxyConfig").(*ProxyConfig).reverseProxy 176 ctx := r.Context().Value("ProxyConfig").(*ProxyConfig)
186 177
187 // Order matters in these checks! 178 // Order matters in these checks!
188 if RequestHasForwardedUser(w, r) { 179 if RequestHasForwardedUser(w, r) {
180 glog.Errorln("Request already has X-Forwarded-User")
181 http.Error(w, "Bad Request", http.StatusBadRequest)
189 return 182 return
190 } 183 }
191 184
192 if !RequestIsOverSecureChannel(w, r) { 185 if !RequestIsOverSecureChannel(w, r) {
186 http.Error(w, "Bad Request", http.StatusBadRequest)
193 return 187 return
194 } 188 }
195 189
196 if CompareUpper(r.Method, "OPTIONS") { 190 if CompareUpper(r.Method, "OPTIONS") {
197 proxy.ServeHTTP(w, r) 191 ctx.reverseProxy.ServeHTTP(w, r)
198 return 192 return
199 } 193 }
200 194
201 rfpc, err := r.Cookie(RFP_COOKIE_NAME) 195 rfpc, err := r.Cookie(RFP_COOKIE_NAME)
202 if err != nil { 196 if err != nil {
203 log.Printf("ERROR: No rfp cookie") 197 glog.V(1).Infoln("No rfp cookie")
204 RedirectToIDP(w, r) 198 if ctx.IsOptional {
205 return 199 ctx.reverseProxy.ServeHTTP(w, r)
200 return
201 } else {
202 RedirectToIdP(w, r, r.URL.Path)
203 return
204 }
206 } 205 }
207 206
208 token := r.URL.Query().Get("id_token") 207 if token := r.URL.Query().Get("id_token"); token != "" {
209 if token != "" && ValidateJWT(rfpc.Value, token) { 208 if _, err := ctx.jwsValidator.Validate(token, rfpc.Value); err == nil {
210 SetTokenCookieAndRedirect(w, r, token) 209 SetTokenCookieAndRedirect(w, r, token)
211 return 210 return
211 } else {
212 glog.V(1).Infof("Querystring id_token invalid: %s", err)
213 }
212 } 214 }
213 215
214 tokenc, err := r.Cookie(TOKEN_COOKIE_NAME) 216 tokenc, err := r.Cookie(TOKEN_COOKIE_NAME)
215 if err != nil { 217 if err != nil {
216 log.Printf("ERROR: No token cookie") 218 glog.V(1).Infoln("No token cookie")
217 RedirectToIDP(w, r) 219 if ctx.IsOptional {
218 return 220 ctx.reverseProxy.ServeHTTP(w, r)
221 return
222 } else {
223 RedirectToIdP(w, r, r.URL.Path)
224 return
225 }
219 } 226 }
220 227
221 if !ValidateJWT(tokenc.Value, rfpc.Value) { 228 claims, err := ctx.jwsValidator.Validate(tokenc.Value, rfpc.Value)
222 log.Printf("ERROR: Token is invalid") 229 if err != nil {
223 RedirectToIDP(w, r) 230 glog.Errorln("Token is invalid", err)
224 return 231 if ctx.IsOptional {
232 ctx.reverseProxy.ServeHTTP(w, r)
233 return
234 } else {
235 RedirectToIdP(w, r, r.URL.Path)
236 return
237 }
225 } 238 }
226 239
227 r.Header["X-Forwarded-User"] = []string{GetJWTSubject(tokenc.Value)} 240 r.Header["X-Forwarded-User"] = []string{claims.Subject}
241 r.Header["X-Forwarded-Token-Expires"] = []string{strconv.FormatInt(int64(claims.Expiry), 10)}
242
243 age := time.Since(claims.IssuedAt.Time()).Minutes()
244 r.Header["X-Forwarded-Token-Age"] = []string{strconv.FormatInt(int64(age), 10)}
228 245
229 proxy.ServeHTTP(w, r) 246 ctx.reverseProxy.ServeHTTP(w, r)
230} 247}
231 248
232// Remove token and rfp cookies and redirect user to root of domain
233func LogoutController(w http.ResponseWriter, r *http.Request) { 249func LogoutController(w http.ResponseWriter, r *http.Request) {
234 http.SetCookie(w, &http.Cookie{ 250 ExpireCookie(w, RFP_COOKIE_NAME)
235 Name: TOKEN_COOKIE_NAME, 251 ExpireCookie(w, TOKEN_COOKIE_NAME)
236 Value: "",
237 MaxAge: 0,
238 })
239
240 http.SetCookie(w, &http.Cookie{
241 Name: TOKEN_COOKIE_NAME,
242 Value: "",
243 MaxAge: 0,
244 })
245
246 http.Redirect(w, r, "/", http.StatusFound) 252 http.Redirect(w, r, "/", http.StatusFound)
247} 253}
248 254
249// TODO
250func LoginController(w http.ResponseWriter, r *http.Request) {
251}
252
253// TODO
254// Optional login allows for applications that can operate in anonymous mode or 255// Optional login allows for applications that can operate in anonymous mode or
255// authenticated mode. When in anonmyous mode the request is proxied through 256// authenticated mode. When in anonmyous mode the request is proxied through
256// without an X-Forwarded-User header. Upstream servers should either expose or 257// without an X-Forwarded-User header. Upstream servers should either expose or
257// map a URL for /.oidc/login to allow users to login. On successful login the 258// map a URL for /.oidc/login to allow users to login. On successful login the
258// user will be redirected back to the main page for the site (/) 259// user will be redirected back to the main page for the site (/)
259func parseConfig() *ProxyConfig { 260func parseConfig() (*ProxyConfig, error) {
260 return &ProxyConfig{ 261 c := &ProxyConfig{}
261 IDProviderURL: "http://mcrute-virt:9993", 262
262 ClientID: "test.crute.me:443", 263 idpu := flag.String("idp", "", "URL for ID provider")
263 UpstreamURL: "http://localhost:9991/", 264 mfam := flag.String("allow-mfa-methods", "", "Comma seperated list of allowed mfa methods")
264 ListenOn: ":9992", 265 rmfa := flag.String("require-mfa-methods", "", "Comma seperated list of required mfa methods")
265 TrustedCACert: "/home/mcrute/oidc_project/test_ca/ca_cert.pem", 266 cids := flag.String("client-id", "", "Client ID for proxy with IdP")
266 IsOptional: false, 267
267 PKISubject: "Crute OpenID Signing 1", 268 flag.BoolVar(&c.IsOptional, "optional", false, "Allow proxying of unauthenticated calls")
268 MaxLiftetime: 24 * time.Hour, 269 flag.BoolVar(&c.IsBootstrap, "bootstrap", false, "Allow running a proxy for the IdP itself")
269 ClockSkew: 5 * time.Minute, 270 flag.BoolVar(&c.RequestMFA, "mfa", false, "Request user MFA authentication from IdP")
271
272 flag.DurationVar(&c.MaxLiftetime, "max-lifetime", DEFAULT_MAX_LIFETIME, "Maximum allowed time from token issuance")
273 flag.DurationVar(&c.ClockSkew, "clock-skew", DEFAULT_CLOCK_SKEW, "Allowable IdP clock skew relative to proxy")
274
275 flag.StringVar(&c.UpstreamURL, "upstream", "", "URL of upstream service for which to proxy")
276 flag.StringVar(&c.ListenOn, "listen", ":9992", "Optional port and ip on which to listen")
277 flag.StringVar(&c.TrustedCACert, "ca", "", "Path to trusted CA certificate")
278
279 flag.Parse()
280
281 c.AllowedMFAMethods = strings.Split(*mfam, ",")
282 c.RequiredMFAMethods = strings.Split(*rmfa, ",")
283
284 if c.IsBootstrap {
285 c.IsOptional = true
286 }
287
288 if _, err := os.Stat(c.TrustedCACert); os.IsNotExist(err) {
289 return nil, errors.Errorf("CA certificate does not exist")
290 }
291
292 if cids == nil {
293 return nil, errors.Errorf("Client ID is required")
294 }
295
296 if client_id, err := url.Parse(*cids); err != nil || client_id.Host == "" {
297 return nil, errors.Errorf("Invalid client ID")
298 } else {
299 c.ClientId = client_id
300 }
301
302 if c.UpstreamURL == "" {
303 return nil, errors.Errorf("Upstream URL is required")
304 }
305
306 if idpu == nil {
307 return nil, errors.Errorf("IDP url is required")
270 } 308 }
309
310 if u, err := url.Parse(*idpu); err != nil {
311 return nil, errors.WithStack(err)
312 } else {
313 c.IdProviderURL = u
314
315 if h := HostFromURL(u.String()); c.IsBootstrap && (h != "localhost" && h != "127.0.0.1") {
316 return nil, errors.Errorf("IdP must be set to localhost for bootstrap")
317 }
318 }
319
320 return c, nil
271} 321}
272 322
273func main() { 323func main() {
274 cfg := parseConfig() 324 cfg, err := parseConfig()
275 h := NewCautiousHTTPClient() 325 if err != nil {
276 326 glog.Fatalln("ParseConfig", err)
277 v := NewKeyValidator(cfg.PKISubject) 327 return
278 v.LoadRootPEM(cfg.TrustedCACert) 328 }
279 329
280 idpc, err := FetchIdPConfig(h, cfg.IDProviderURL) 330 hidp, err := NewCautiousHTTPClient(cfg.IsBootstrap)
281 if err != nil { 331 if err != nil {
282 fmt.Printf("%s\n", err) 332 glog.Fatalln("Error building http client", err)
283 return 333 return
284 } 334 }
285 335
286 jwks, err := FetchJWKS(h, idpc.JwksUri, v) 336 idpc, err := FetchIdPConfig(hidp, cfg.IdProviderURL)
287 if err != nil { 337 if err != nil {
288 fmt.Printf("%s\n", err) 338 glog.Fatalln("FetchIdPConfig:", err)
289 return 339 return
290 } 340 }
291 341
292 jv := NewJWSValidator(jwks, idpc.Issuer, cfg.ClientID, cfg.ClockSkew, cfg.MaxLiftetime) 342 cfg.IdProviderAuthEndpoint = idpc.AuthorizationEndpoint.AsURL()
293 343
294 nonce := "ofspmfjuvoswhhde" 344 h, err := NewCautiousHTTPClient(false)
295 raw_jwt := "eyJ0eXAiOiJKV1MiLCJhbGciOiJQUzI1NiIsImtpZCI6IjEifQ.eyJub25jZSI6IjM0MjlhMjAyYzU4ZDkyYjQwNjNjOWM4MWM2MjQyNGRlNzBkMmIzZDQ4MmVlNDFhOTdjYmNhZjEwZDk5MWFiOTMiLCJpc3MiOiJpZHAuY3J1dGUubWU6NDQzIiwiaWF0IjoxNTA0NTc2Mzc0LCJuYmYiOjE1MDQ1NzYzNzQsImV4cCI6MTUwNDY2Mjc3NCwic3ViIjoibWNydXRlIiwiYXVkIjoidGVzdC5jcnV0ZS5tZTo0NDMifQ.iizlNfY1Vg7d-XRmgyYuhpNkNrOGaT9OOgO0HdjBozOWMvKzBTtATbIfoWOrNH6DiFY1as8uy3I1Pxnkrb8Ti8_cLDQeLxOv9klAbnebeuPI_wtZ0iwSUnSWaYzN6I6sqcEjHX3fibFvAQhO5dNDzSwONjw4AvcdpZKh579FO1sAvIw-1DmMyPSUun7rbC0Kf1Jtdlr3q7tOp3wdI_erkstxCNPwyuv7X1J7uetsu0BeJS25C2DxeB03BPEIUoo_C1xvcqikfSLLpoFcyToYiS-R9o-WpRjGid_yug65J5ALn2aM3vhe9rRbydKVm_omGL8-Etj06zbqM0Y6OrJUgA"
296 claims, err := jv.Validate(raw_jwt, nonce)
297 if err != nil { 345 if err != nil {
298 fmt.Printf("Error validating: %s\n", err) 346 glog.Fatalln("Error building http client", err)
299 return 347 return
300 } 348 }
301 349
302 fmt.Printf("Valid JWT for: %+v\n", claims.Subject) 350 kf := NewJWKSFetcher(h, idpc.JwksUri.AsURL(), idpc.Issuer, cfg.TrustedCACert)
303 351
304 return 352 cfg.jwsValidator = NewJWSValidator(&JWSValidationContext{
353 KeyFetcher: kf,
354 Issuer: idpc.Issuer,
355 ClientId: cfg.ClientId,
356 ClockSkew: cfg.ClockSkew,
357 MaxLiftetime: cfg.MaxLiftetime,
358 })
305 359
306 cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL)) 360 cfg.reverseProxy = httputil.NewSingleHostReverseProxy(URLMustParse(cfg.UpstreamURL))
307 361
308 if cfg.IsOptional { 362 http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) {
309 http.HandleFunc("/.oidc/login", func(w http.ResponseWriter, r *http.Request) { 363 RedirectToIdP(w,
310 LoginController(w, 364 r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg)), "/")
311 r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) 365 })
312 })
313 }
314 366
315 http.HandleFunc("/.oidc/logout", func(w http.ResponseWriter, r *http.Request) { 367 http.HandleFunc("/.oidc/logout", func(w http.ResponseWriter, r *http.Request) {
316 LogoutController(w, 368 LogoutController(w,
@@ -322,5 +374,14 @@ func main() {
322 r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg))) 374 r.WithContext(context.WithValue(r.Context(), "ProxyConfig", cfg)))
323 }) 375 })
324 376
325 log.Fatal(http.ListenAndServe(cfg.ListenOn, nil)) 377 go http.ListenAndServe(cfg.ListenOn,
378 handlers.LoggingHandler(os.Stdout, http.DefaultServeMux))
379
380 // This has to happen last in-case we're boostrapping a proxy for the IdP itself
381 if err := kf.Fetch(); err != nil {
382 glog.Fatalln("FetchJWKS:", err)
383 return
384 } else {
385 kf.Run()
386 }
326} 387}
diff --git a/oidc_proxy b/oidc_proxy
index e5df267..cc3c7be 100755
--- a/oidc_proxy
+++ b/oidc_proxy
Binary files differ
diff --git a/util.go b/util.go
index 10709e2..7385dfd 100644
--- a/util.go
+++ b/util.go
@@ -1,6 +1,8 @@
1package main 1package main
2 2
3import ( 3import (
4 "crypto/sha256"
5 "encoding/hex"
4 "net/url" 6 "net/url"
5 "strings" 7 "strings"
6) 8)
@@ -41,3 +43,19 @@ func URLMustParse(u string) *url.URL {
41func CompareUpper(lhs, rhs string) bool { 43func CompareUpper(lhs, rhs string) bool {
42 return strings.ToUpper(lhs) == strings.ToUpper(rhs) 44 return strings.ToUpper(lhs) == strings.ToUpper(rhs)
43} 45}
46
47func HostFromURL(u string) string {
48 o, err := url.Parse(u)
49 if err != nil {
50 return ""
51 }
52
53 h := strings.Split(o.Host, ":")
54 return h[0]
55}
56
57func Sha256Hex(v string) string {
58 s256 := sha256.New()
59 s256.Write([]byte(v))
60 return hex.EncodeToString(s256.Sum(nil))
61}