diff options
Diffstat (limited to 'jwks_fetcher.go')
-rw-r--r-- | jwks_fetcher.go | 118 |
1 files changed, 118 insertions, 0 deletions
diff --git a/jwks_fetcher.go b/jwks_fetcher.go new file mode 100644 index 0000000..9925430 --- /dev/null +++ b/jwks_fetcher.go | |||
@@ -0,0 +1,118 @@ | |||
1 | package main | ||
2 | |||
3 | import ( | ||
4 | "github.com/pkg/errors" | ||
5 | "gopkg.in/square/go-jose.v2" | ||
6 | "log" | ||
7 | "net/url" | ||
8 | "time" | ||
9 | ) | ||
10 | |||
11 | const ( | ||
12 | REQUEST_BUFFER_SIZE = 10 | ||
13 | KEY_MAP_INITIAL_SIZE = 5 | ||
14 | DEFAULT_REFRESH_INTERVAL = 15 * time.Minute | ||
15 | MIN_REFRESH_INTERVAL = 1 * time.Minute | ||
16 | ) | ||
17 | |||
18 | type KeyRequest struct { | ||
19 | KeyId string | ||
20 | Response chan *jose.JSONWebKey | ||
21 | } | ||
22 | |||
23 | type JWKSFetcher interface { | ||
24 | Run() | ||
25 | Fetch() error | ||
26 | GetKey(string) (*jose.JSONWebKey, error) | ||
27 | Done() | ||
28 | } | ||
29 | |||
30 | type jwksFetcher struct { | ||
31 | keyMap map[string]jose.JSONWebKey | ||
32 | httpClient CautiousHTTPClient | ||
33 | validator KeyValidator | ||
34 | fetchTimer *time.Timer | ||
35 | url *url.URL | ||
36 | requests chan *KeyRequest | ||
37 | done chan bool | ||
38 | } | ||
39 | |||
40 | func NewJWKSFetcher(h CautiousHTTPClient, url *url.URL, issuer string, root string) JWKSFetcher { | ||
41 | val := NewKeyValidator(HostFromURL(issuer)) | ||
42 | val.LoadRootPEM(root) | ||
43 | |||
44 | return &jwksFetcher{ | ||
45 | httpClient: h, | ||
46 | validator: val, | ||
47 | url: url, | ||
48 | fetchTimer: time.NewTimer(DEFAULT_REFRESH_INTERVAL), | ||
49 | requests: make(chan *KeyRequest, REQUEST_BUFFER_SIZE), | ||
50 | keyMap: make(map[string]jose.JSONWebKey, KEY_MAP_INITIAL_SIZE), | ||
51 | done: make(chan bool), | ||
52 | } | ||
53 | } | ||
54 | |||
55 | func (f *jwksFetcher) Fetch() error { | ||
56 | var jwks jose.JSONWebKeySet | ||
57 | timeout, err := f.httpClient.GetJSONExpires(f.url.String(), &jwks) | ||
58 | if err != nil { | ||
59 | return errors.WithStack(err) | ||
60 | } | ||
61 | |||
62 | for _, k := range jwks.Keys { | ||
63 | err = f.validator.Validate(k) | ||
64 | if err == nil { | ||
65 | f.keyMap[k.KeyID] = k | ||
66 | } else { | ||
67 | log.Printf("Rejecting key %q because %q", k.KeyID, err) | ||
68 | } | ||
69 | } | ||
70 | |||
71 | if timeout < MIN_REFRESH_INTERVAL { | ||
72 | timeout = MIN_REFRESH_INTERVAL | ||
73 | } | ||
74 | |||
75 | success := f.fetchTimer.Reset(timeout) | ||
76 | if !success { | ||
77 | f.fetchTimer = time.NewTimer(timeout) | ||
78 | } | ||
79 | |||
80 | return nil | ||
81 | } | ||
82 | |||
83 | func (f *jwksFetcher) Run() { | ||
84 | for { | ||
85 | select { | ||
86 | // Incoming request for a key, return key or nil in no key | ||
87 | case r := <-f.requests: | ||
88 | if v, ok := f.keyMap[r.KeyId]; ok { | ||
89 | r.Response <- &v | ||
90 | } else { | ||
91 | r.Response <- nil | ||
92 | } | ||
93 | case <-f.fetchTimer.C: | ||
94 | f.Fetch() | ||
95 | case <-f.done: | ||
96 | return | ||
97 | } | ||
98 | } | ||
99 | } | ||
100 | |||
101 | func (f *jwksFetcher) Done() { | ||
102 | f.done <- true | ||
103 | } | ||
104 | |||
105 | func (f *jwksFetcher) GetKey(kid string) (*jose.JSONWebKey, error) { | ||
106 | r := &KeyRequest{ | ||
107 | KeyId: kid, | ||
108 | Response: make(chan *jose.JSONWebKey), | ||
109 | } | ||
110 | |||
111 | f.requests <- r | ||
112 | |||
113 | if res := <-r.Response; res == nil { | ||
114 | return nil, errors.Errorf("Key not found for ID") | ||
115 | } else { | ||
116 | return res, nil | ||
117 | } | ||
118 | } | ||