summaryrefslogtreecommitdiff
path: root/jwks_fetcher.go
diff options
context:
space:
mode:
Diffstat (limited to 'jwks_fetcher.go')
-rw-r--r--jwks_fetcher.go118
1 files changed, 118 insertions, 0 deletions
diff --git a/jwks_fetcher.go b/jwks_fetcher.go
new file mode 100644
index 0000000..9925430
--- /dev/null
+++ b/jwks_fetcher.go
@@ -0,0 +1,118 @@
1package main
2
3import (
4 "github.com/pkg/errors"
5 "gopkg.in/square/go-jose.v2"
6 "log"
7 "net/url"
8 "time"
9)
10
11const (
12 REQUEST_BUFFER_SIZE = 10
13 KEY_MAP_INITIAL_SIZE = 5
14 DEFAULT_REFRESH_INTERVAL = 15 * time.Minute
15 MIN_REFRESH_INTERVAL = 1 * time.Minute
16)
17
18type KeyRequest struct {
19 KeyId string
20 Response chan *jose.JSONWebKey
21}
22
23type JWKSFetcher interface {
24 Run()
25 Fetch() error
26 GetKey(string) (*jose.JSONWebKey, error)
27 Done()
28}
29
30type jwksFetcher struct {
31 keyMap map[string]jose.JSONWebKey
32 httpClient CautiousHTTPClient
33 validator KeyValidator
34 fetchTimer *time.Timer
35 url *url.URL
36 requests chan *KeyRequest
37 done chan bool
38}
39
40func NewJWKSFetcher(h CautiousHTTPClient, url *url.URL, issuer string, root string) JWKSFetcher {
41 val := NewKeyValidator(HostFromURL(issuer))
42 val.LoadRootPEM(root)
43
44 return &jwksFetcher{
45 httpClient: h,
46 validator: val,
47 url: url,
48 fetchTimer: time.NewTimer(DEFAULT_REFRESH_INTERVAL),
49 requests: make(chan *KeyRequest, REQUEST_BUFFER_SIZE),
50 keyMap: make(map[string]jose.JSONWebKey, KEY_MAP_INITIAL_SIZE),
51 done: make(chan bool),
52 }
53}
54
55func (f *jwksFetcher) Fetch() error {
56 var jwks jose.JSONWebKeySet
57 timeout, err := f.httpClient.GetJSONExpires(f.url.String(), &jwks)
58 if err != nil {
59 return errors.WithStack(err)
60 }
61
62 for _, k := range jwks.Keys {
63 err = f.validator.Validate(k)
64 if err == nil {
65 f.keyMap[k.KeyID] = k
66 } else {
67 log.Printf("Rejecting key %q because %q", k.KeyID, err)
68 }
69 }
70
71 if timeout < MIN_REFRESH_INTERVAL {
72 timeout = MIN_REFRESH_INTERVAL
73 }
74
75 success := f.fetchTimer.Reset(timeout)
76 if !success {
77 f.fetchTimer = time.NewTimer(timeout)
78 }
79
80 return nil
81}
82
83func (f *jwksFetcher) Run() {
84 for {
85 select {
86 // Incoming request for a key, return key or nil in no key
87 case r := <-f.requests:
88 if v, ok := f.keyMap[r.KeyId]; ok {
89 r.Response <- &v
90 } else {
91 r.Response <- nil
92 }
93 case <-f.fetchTimer.C:
94 f.Fetch()
95 case <-f.done:
96 return
97 }
98 }
99}
100
101func (f *jwksFetcher) Done() {
102 f.done <- true
103}
104
105func (f *jwksFetcher) GetKey(kid string) (*jose.JSONWebKey, error) {
106 r := &KeyRequest{
107 KeyId: kid,
108 Response: make(chan *jose.JSONWebKey),
109 }
110
111 f.requests <- r
112
113 if res := <-r.Response; res == nil {
114 return nil, errors.Errorf("Key not found for ID")
115 } else {
116 return res, nil
117 }
118}