diff options
Diffstat (limited to 'secrets/vault_client.go')
-rw-r--r-- | secrets/vault_client.go | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/secrets/vault_client.go b/secrets/vault_client.go new file mode 100644 index 0000000..3466f48 --- /dev/null +++ b/secrets/vault_client.go | |||
@@ -0,0 +1,426 @@ | |||
1 | package secrets | ||
2 | |||
3 | import ( | ||
4 | "container/heap" | ||
5 | "context" | ||
6 | "encoding/json" | ||
7 | "errors" | ||
8 | "fmt" | ||
9 | "net/http" | ||
10 | "path" | ||
11 | "sync" | ||
12 | "time" | ||
13 | |||
14 | glos "code.crute.us/mcrute/golib/os" | ||
15 | |||
16 | "github.com/hashicorp/vault/api" | ||
17 | "github.com/hashicorp/vault/api/auth/approle" | ||
18 | "github.com/mitchellh/mapstructure" | ||
19 | ) | ||
20 | |||
// Tuning values for lease renewal and for calls made to Vault.
const (
	// renewalStartPercent is the fraction of a lease's duration that must
	// elapse before renewAfter considers the lease due.
	renewalStartPercent = 0.8

	// notificationChanLen is the buffer size of the Renewal notification
	// channel created by NewVaultClient.
	notificationChanLen = 100

	// defaultIncrement is the renewal increment, in seconds, used when the
	// configuration leaves Increment at zero (30 minutes).
	defaultIncrement = 30 * 60

	// defaultTimeout bounds the authentication and renewal calls to Vault.
	defaultTimeout = 10 * time.Second

	// renewalWindow is the threshold used by renewAttempt: a handle whose
	// renewAfter is below this value is renewed.
	renewalWindow = 30 * time.Second
)
28 | |||
29 | type vaultRenewalMinHeap []*VaultHandle | ||
30 | |||
31 | func (h vaultRenewalMinHeap) Len() int { return len(h) } | ||
32 | func (h vaultRenewalMinHeap) Less(i, j int) bool { return h[i].renewAfter() < h[j].renewAfter() } | ||
33 | func (h vaultRenewalMinHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } | ||
34 | func (h *vaultRenewalMinHeap) Push(x any) { *h = append(*h, x.(*VaultHandle)) } | ||
35 | func (h vaultRenewalMinHeap) Root() *VaultHandle { return h[0] } | ||
36 | |||
37 | // Convenience methods to hide the collections/heap stuff from users | ||
38 | func (h *vaultRenewalMinHeap) PopHeap() *VaultHandle { return heap.Pop(h).(*VaultHandle) } | ||
39 | func (h *vaultRenewalMinHeap) PushHeap(i *VaultHandle) { heap.Push(h, i) } | ||
40 | func (h *vaultRenewalMinHeap) Init() { heap.Init(h) } | ||
41 | |||
42 | func (h *vaultRenewalMinHeap) Pop() any { | ||
43 | old := *h | ||
44 | n := len(old) | ||
45 | item := old[n-1] | ||
46 | old[n-1] = nil // avoid memory leak | ||
47 | *h = old[0 : n-1] | ||
48 | return item | ||
49 | } | ||
50 | |||
51 | func (h *vaultRenewalMinHeap) FindRemove(handle *VaultHandle) bool { | ||
52 | for i, hnd := range *h { | ||
53 | if hnd.name == handle.name { | ||
54 | heap.Remove(h, i) | ||
55 | return true | ||
56 | } | ||
57 | } | ||
58 | return false | ||
59 | } | ||
60 | |||
61 | type VaultHandle struct { | ||
62 | name string | ||
63 | critical bool | ||
64 | acquired time.Time | ||
65 | secret *api.Secret | ||
66 | } | ||
67 | |||
68 | var _ Handle = (*VaultHandle)(nil) | ||
69 | |||
70 | func (h *VaultHandle) isAuthToken() bool { | ||
71 | return h.secret.Auth != nil | ||
72 | } | ||
73 | |||
74 | func (h *VaultHandle) renewAfter() time.Duration { | ||
75 | after := float64(h.leaseDuration().Nanoseconds()) * renewalStartPercent | ||
76 | afterTime := h.acquired.Add(time.Duration(after)) | ||
77 | return afterTime.Sub(time.Now()).Round(time.Second) | ||
78 | } | ||
79 | |||
80 | func (h *VaultHandle) leaseDuration() time.Duration { | ||
81 | duration := time.Duration(h.secret.LeaseDuration) * time.Second | ||
82 | if h.isAuthToken() { | ||
83 | return time.Duration(h.secret.Auth.LeaseDuration) * time.Second | ||
84 | } | ||
85 | return duration | ||
86 | } | ||
87 | |||
88 | func (h *VaultHandle) renew(ctx context.Context, c VaultServiceClient, inc int) (err error) { | ||
89 | var s *api.Secret | ||
90 | |||
91 | vctx, cancel := context.WithTimeout(ctx, defaultTimeout) | ||
92 | defer cancel() | ||
93 | |||
94 | if h.isAuthToken() { | ||
95 | s, err = c.Auth().Token().RenewTokenAsSelfWithContext(vctx, h.secret.Auth.ClientToken, inc) | ||
96 | if err != nil { | ||
97 | return err | ||
98 | } | ||
99 | } else { | ||
100 | s, err = c.Sys().RenewWithContext(vctx, h.secret.LeaseID, inc) | ||
101 | if err != nil { | ||
102 | return err | ||
103 | } | ||
104 | } | ||
105 | |||
106 | h.secret = s | ||
107 | h.acquired = time.Now() | ||
108 | |||
109 | return nil | ||
110 | } | ||
111 | |||
112 | func (h *VaultHandle) Reference() string { | ||
113 | return h.name | ||
114 | } | ||
115 | |||
116 | type VaultServiceClient interface { | ||
117 | Auth() *api.Auth | ||
118 | Sys() *api.Sys | ||
119 | Token() string | ||
120 | } | ||
121 | |||
122 | type VaultClient struct { | ||
123 | sync.Mutex | ||
124 | |||
125 | client VaultServiceClient | ||
126 | logical *api.Logical | ||
127 | auth *approle.AppRoleAuth | ||
128 | secrets vaultRenewalMinHeap | ||
129 | renewIncrement int | ||
130 | notifications chan Renewal | ||
131 | } | ||
132 | |||
133 | var _ Client = (*VaultClient)(nil) | ||
134 | var _ ClientManager = (*VaultClient)(nil) | ||
135 | |||
136 | type VaultClientConfig struct { | ||
137 | Host string `env:"VAULT_ADDR"` | ||
138 | Token string `env:"VAULT_TOKEN"` | ||
139 | RoleId string `env:"VAULT_ROLE_ID"` | ||
140 | RoleSecret string `env:"VAULT_SECRET_ID"` | ||
141 | Increment int `env:"VAULT_INCREMENT"` | ||
142 | AppRoleAuth *approle.AppRoleAuth | ||
143 | } | ||
144 | |||
145 | func (c *VaultClientConfig) Validate() error { | ||
146 | if c.Host == "" { | ||
147 | return fmt.Errorf("VaultClientConfig: Vault host is not specified") | ||
148 | } | ||
149 | |||
150 | // The presence of a token is always assumed to be valid, client errors | ||
151 | // will occur otherwise. | ||
152 | if c.Token != "" { | ||
153 | return nil | ||
154 | } | ||
155 | |||
156 | // This constructor does a bunch of validation internally so just let | ||
157 | // it do its thing and return any errors from that directly to the | ||
158 | // user. | ||
159 | ar, err := approle.NewAppRoleAuth(c.RoleId, &approle.SecretID{FromString: c.RoleSecret}) | ||
160 | if err != nil { | ||
161 | return fmt.Errorf("VaultClientConfig: AppRole credentials invalid: %w", err) | ||
162 | } | ||
163 | c.AppRoleAuth = ar | ||
164 | |||
165 | return nil | ||
166 | } | ||
167 | |||
168 | // NewVaultClient will attempt to create a secrets.Client from the | ||
169 | // passed config. Config can be nil, in which case an attempt will | ||
170 | // be made to load the configuration from environment variables. See | ||
171 | // VaultClientConfig for the expected names of those variables. | ||
172 | func NewVaultClient(cfg *VaultClientConfig) (ClientManager, error) { | ||
173 | if cfg == nil { | ||
174 | cfg = &VaultClientConfig{} | ||
175 | } | ||
176 | |||
177 | if err := glos.UnmarshalEnvironment(cfg); err != nil { | ||
178 | return nil, err | ||
179 | } | ||
180 | |||
181 | if err := cfg.Validate(); err != nil { | ||
182 | return nil, err | ||
183 | } | ||
184 | |||
185 | vc, err := api.NewClient(api.DefaultConfig()) | ||
186 | if err != nil { | ||
187 | return nil, fmt.Errorf("NewVaultClient: error building client config: %w", err) | ||
188 | } | ||
189 | vc.SetAddress(cfg.Host) | ||
190 | |||
191 | if cfg.Token != "" { | ||
192 | vc.SetToken(cfg.Token) | ||
193 | } | ||
194 | |||
195 | c := &VaultClient{ | ||
196 | client: vc, | ||
197 | logical: vc.Logical(), | ||
198 | secrets: vaultRenewalMinHeap{}, | ||
199 | notifications: make(chan Renewal, notificationChanLen), | ||
200 | auth: cfg.AppRoleAuth, | ||
201 | renewIncrement: cfg.Increment, | ||
202 | } | ||
203 | |||
204 | c.secrets.Init() | ||
205 | |||
206 | if c.renewIncrement == 0 { | ||
207 | c.renewIncrement = defaultIncrement | ||
208 | } | ||
209 | |||
210 | return c, nil | ||
211 | } | ||
212 | |||
213 | func (c *VaultClient) Notifications() <-chan Renewal { | ||
214 | return c.notifications | ||
215 | } | ||
216 | |||
217 | func (c *VaultClient) Authenticate(ctx context.Context) error { | ||
218 | if c.auth == nil { | ||
219 | return c.authToken(ctx) | ||
220 | } else { | ||
221 | return c.authAppRole(ctx) | ||
222 | } | ||
223 | } | ||
224 | |||
225 | func (c *VaultClient) authToken(ctx context.Context) error { | ||
226 | if c.client.Token() == "" { | ||
227 | return fmt.Errorf("Authenticate: unable to authenticate, neither token nor approle provided") | ||
228 | } | ||
229 | |||
230 | vctx, cancel := context.WithTimeout(ctx, defaultTimeout) | ||
231 | defer cancel() | ||
232 | |||
233 | secret, err := c.client.Auth().Token().LookupSelfWithContext(vctx) | ||
234 | if err != nil { | ||
235 | return err | ||
236 | } | ||
237 | |||
238 | // Looking up self does not return an auth token just a map of data | ||
239 | // about the current token. Convert this into a SecretAuth so that | ||
240 | // downstream renewal code does the right thing. | ||
241 | secret.Auth = &api.SecretAuth{} | ||
242 | if err := mapstructure.Decode(secret.Data, secret.Auth); err != nil { | ||
243 | return err | ||
244 | } | ||
245 | secret.Auth.ClientToken = c.client.Token() | ||
246 | |||
247 | c.makeHandle("login", secret) | ||
248 | |||
249 | return nil | ||
250 | } | ||
251 | |||
252 | func (c *VaultClient) authAppRole(ctx context.Context) error { | ||
253 | vctx, cancel := context.WithTimeout(ctx, defaultTimeout) | ||
254 | defer cancel() | ||
255 | |||
256 | s, err := c.client.Auth().Login(vctx, c.auth) | ||
257 | if err != nil { | ||
258 | return fmt.Errorf("Authenticate: error logging in to vault: %w", err) | ||
259 | } | ||
260 | c.makeHandle("login", s) | ||
261 | return err | ||
262 | } | ||
263 | |||
264 | // makeHandle creates a secret handle and schedules it for renewal if it | ||
265 | // is renewable. | ||
266 | func (c *VaultClient) makeHandle(name string, s *api.Secret) Handle { | ||
267 | h := &VaultHandle{ | ||
268 | name: name, | ||
269 | critical: true, // Everything is critical unless marked otherwise | ||
270 | acquired: time.Now(), | ||
271 | secret: s, | ||
272 | } | ||
273 | |||
274 | // If this is renewable then schedule it for renewal | ||
275 | if (s.Auth != nil && s.Auth.Renewable) || s.Renewable { | ||
276 | c.Lock() | ||
277 | c.secrets.PushHeap(h) | ||
278 | c.Unlock() | ||
279 | } | ||
280 | |||
281 | return h | ||
282 | } | ||
283 | |||
284 | func (c *VaultClient) read(ctx context.Context, prefix, suffix string) (Handle, error) { | ||
285 | key := path.Join(prefix, suffix) | ||
286 | |||
287 | s, err := c.logical.ReadWithContext(ctx, key) | ||
288 | if err != nil { | ||
289 | return nil, fmt.Errorf("read: error reading from Vault: %w", err) | ||
290 | } | ||
291 | |||
292 | return c.makeHandle(key, s), nil | ||
293 | } | ||
294 | |||
295 | func (c *VaultClient) isRecoverableDbError(err error) bool { | ||
296 | var apiErr *api.ResponseError | ||
297 | return errors.Is(api.ErrSecretNotFound, err) || | ||
298 | (errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusForbidden) | ||
299 | } | ||
300 | |||
301 | func (c *VaultClient) DatabaseCredential(ctx context.Context, suffix string) (*Credential, Handle, error) { | ||
302 | cred, hnd, err := c.databaseCredentialDynamic(ctx, suffix) | ||
303 | if err != nil { | ||
304 | if c.isRecoverableDbError(err) { | ||
305 | cred, hnd, err = c.databaseCredentialStatic(ctx, suffix) | ||
306 | } | ||
307 | } | ||
308 | return cred, hnd, err | ||
309 | } | ||
310 | |||
311 | func (c *VaultClient) databaseCredentialDynamic(ctx context.Context, suffix string) (*Credential, Handle, error) { | ||
312 | h, err := c.read(ctx, "database/creds", suffix) | ||
313 | if err != nil { | ||
314 | return nil, nil, err | ||
315 | } | ||
316 | vh := h.(*VaultHandle) | ||
317 | |||
318 | var d Credential | ||
319 | if err = mapstructure.Decode(vh.secret.Data, &d); err != nil { | ||
320 | return nil, nil, fmt.Errorf("databaseCredentialStatic: error decoding secret: %w", err) | ||
321 | } | ||
322 | |||
323 | return &d, h, nil | ||
324 | } | ||
325 | |||
326 | func (c *VaultClient) databaseCredentialStatic(ctx context.Context, suffix string) (*Credential, Handle, error) { | ||
327 | h, err := c.read(ctx, "database/static-creds", suffix) | ||
328 | if err != nil { | ||
329 | return nil, nil, err | ||
330 | } | ||
331 | |||
332 | var d Credential | ||
333 | if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil { | ||
334 | return nil, nil, fmt.Errorf("databaseCredentialStatic: error decoding secret: %w", err) | ||
335 | } | ||
336 | |||
337 | return &d, h, nil | ||
338 | } | ||
339 | |||
340 | func (c *VaultClient) Secret(ctx context.Context, suffix string, out any) (Handle, error) { | ||
341 | h, err := c.read(ctx, "kv/data", suffix) | ||
342 | if err != nil { | ||
343 | return nil, err | ||
344 | } | ||
345 | |||
346 | if err = mapstructure.Decode(h.(*VaultHandle).secret.Data["data"], out); err != nil { | ||
347 | return nil, err | ||
348 | } | ||
349 | |||
350 | return h, nil | ||
351 | } | ||
352 | |||
353 | func (c *VaultClient) WriteSecret(ctx context.Context, suffix string, in any) error { | ||
354 | inb, err := json.Marshal(in) | ||
355 | if err != nil { | ||
356 | return fmt.Errorf("WriteSecret: error encoding json: %w", err) | ||
357 | } | ||
358 | |||
359 | if _, err = c.logical.WriteBytesWithContext(ctx, path.Join("kv/data", suffix), inb); err != nil { | ||
360 | return fmt.Errorf("WriteSecret: error writing to vault: %w", err) | ||
361 | } | ||
362 | |||
363 | return nil | ||
364 | } | ||
365 | |||
366 | func (c *VaultClient) Destroy(h Handle) error { | ||
367 | c.Lock() | ||
368 | defer c.Unlock() | ||
369 | |||
370 | c.secrets.FindRemove(h.(*VaultHandle)) | ||
371 | |||
372 | return nil | ||
373 | } | ||
374 | |||
375 | func (c *VaultClient) MakeNonCritical(h Handle) error { | ||
376 | h.(*VaultHandle).critical = false | ||
377 | return nil | ||
378 | } | ||
379 | |||
380 | func (c *VaultClient) renewAttempt(ctx context.Context) (next time.Duration) { | ||
381 | c.Lock() | ||
382 | defer c.Unlock() | ||
383 | |||
384 | // In the absence of any other time, run once a second | ||
385 | next = 5 * time.Second | ||
386 | |||
387 | if c.secrets.Len() < 1 { | ||
388 | return | ||
389 | } | ||
390 | |||
391 | for { | ||
392 | s := c.secrets.PopHeap() | ||
393 | if s.renewAfter() < renewalWindow { | ||
394 | // Underlying client does backoff and retry | ||
395 | c.notifications <- Renewal{ | ||
396 | Name: s.name, | ||
397 | Critical: s.critical, | ||
398 | Time: time.Now(), | ||
399 | Error: s.renew(ctx, c.client, c.renewIncrement), | ||
400 | } | ||
401 | c.secrets.PushHeap(s) | ||
402 | } else { | ||
403 | c.secrets.PushHeap(s) | ||
404 | next = c.secrets.Root().renewAfter() | ||
405 | break | ||
406 | } | ||
407 | } | ||
408 | |||
409 | return next | ||
410 | } | ||
411 | |||
412 | func (c *VaultClient) Run(ctx context.Context, wg *sync.WaitGroup) error { | ||
413 | wg.Add(1) | ||
414 | defer wg.Done() | ||
415 | |||
416 | for { | ||
417 | sleepTime := c.renewAttempt(ctx) | ||
418 | |||
419 | select { | ||
420 | case <-time.After(sleepTime): | ||
421 | continue | ||
422 | case <-ctx.Done(): | ||
423 | return nil | ||
424 | } | ||
425 | } | ||
426 | } | ||