aboutsummaryrefslogtreecommitdiff
path: root/secrets/vault_client.go
diff options
context:
space:
mode:
Diffstat (limited to 'secrets/vault_client.go')
-rw-r--r--secrets/vault_client.go426
1 files changed, 426 insertions, 0 deletions
diff --git a/secrets/vault_client.go b/secrets/vault_client.go
new file mode 100644
index 0000000..3466f48
--- /dev/null
+++ b/secrets/vault_client.go
@@ -0,0 +1,426 @@
1package secrets
2
3import (
4 "container/heap"
5 "context"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "net/http"
10 "path"
11 "sync"
12 "time"
13
14 glos "code.crute.us/mcrute/golib/os"
15
16 "github.com/hashicorp/vault/api"
17 "github.com/hashicorp/vault/api/auth/approle"
18 "github.com/mitchellh/mapstructure"
19)
20
21const (
22 renewalStartPercent = 0.8
23 notificationChanLen = 100
24 defaultIncrement = 30 * 60 // 30 minutes (as seconds)
25 defaultTimeout = 10 * time.Second
26 renewalWindow = 30 * time.Second
27)
28
29type vaultRenewalMinHeap []*VaultHandle
30
31func (h vaultRenewalMinHeap) Len() int { return len(h) }
32func (h vaultRenewalMinHeap) Less(i, j int) bool { return h[i].renewAfter() < h[j].renewAfter() }
33func (h vaultRenewalMinHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
34func (h *vaultRenewalMinHeap) Push(x any) { *h = append(*h, x.(*VaultHandle)) }
35func (h vaultRenewalMinHeap) Root() *VaultHandle { return h[0] }
36
37// Convenience methods to hide the collections/heap stuff from users
38func (h *vaultRenewalMinHeap) PopHeap() *VaultHandle { return heap.Pop(h).(*VaultHandle) }
39func (h *vaultRenewalMinHeap) PushHeap(i *VaultHandle) { heap.Push(h, i) }
40func (h *vaultRenewalMinHeap) Init() { heap.Init(h) }
41
42func (h *vaultRenewalMinHeap) Pop() any {
43 old := *h
44 n := len(old)
45 item := old[n-1]
46 old[n-1] = nil // avoid memory leak
47 *h = old[0 : n-1]
48 return item
49}
50
51func (h *vaultRenewalMinHeap) FindRemove(handle *VaultHandle) bool {
52 for i, hnd := range *h {
53 if hnd.name == handle.name {
54 heap.Remove(h, i)
55 return true
56 }
57 }
58 return false
59}
60
61type VaultHandle struct {
62 name string
63 critical bool
64 acquired time.Time
65 secret *api.Secret
66}
67
68var _ Handle = (*VaultHandle)(nil)
69
70func (h *VaultHandle) isAuthToken() bool {
71 return h.secret.Auth != nil
72}
73
74func (h *VaultHandle) renewAfter() time.Duration {
75 after := float64(h.leaseDuration().Nanoseconds()) * renewalStartPercent
76 afterTime := h.acquired.Add(time.Duration(after))
77 return afterTime.Sub(time.Now()).Round(time.Second)
78}
79
80func (h *VaultHandle) leaseDuration() time.Duration {
81 duration := time.Duration(h.secret.LeaseDuration) * time.Second
82 if h.isAuthToken() {
83 return time.Duration(h.secret.Auth.LeaseDuration) * time.Second
84 }
85 return duration
86}
87
88func (h *VaultHandle) renew(ctx context.Context, c VaultServiceClient, inc int) (err error) {
89 var s *api.Secret
90
91 vctx, cancel := context.WithTimeout(ctx, defaultTimeout)
92 defer cancel()
93
94 if h.isAuthToken() {
95 s, err = c.Auth().Token().RenewTokenAsSelfWithContext(vctx, h.secret.Auth.ClientToken, inc)
96 if err != nil {
97 return err
98 }
99 } else {
100 s, err = c.Sys().RenewWithContext(vctx, h.secret.LeaseID, inc)
101 if err != nil {
102 return err
103 }
104 }
105
106 h.secret = s
107 h.acquired = time.Now()
108
109 return nil
110}
111
112func (h *VaultHandle) Reference() string {
113 return h.name
114}
115
116type VaultServiceClient interface {
117 Auth() *api.Auth
118 Sys() *api.Sys
119 Token() string
120}
121
122type VaultClient struct {
123 sync.Mutex
124
125 client VaultServiceClient
126 logical *api.Logical
127 auth *approle.AppRoleAuth
128 secrets vaultRenewalMinHeap
129 renewIncrement int
130 notifications chan Renewal
131}
132
133var _ Client = (*VaultClient)(nil)
134var _ ClientManager = (*VaultClient)(nil)
135
136type VaultClientConfig struct {
137 Host string `env:"VAULT_ADDR"`
138 Token string `env:"VAULT_TOKEN"`
139 RoleId string `env:"VAULT_ROLE_ID"`
140 RoleSecret string `env:"VAULT_SECRET_ID"`
141 Increment int `env:"VAULT_INCREMENT"`
142 AppRoleAuth *approle.AppRoleAuth
143}
144
145func (c *VaultClientConfig) Validate() error {
146 if c.Host == "" {
147 return fmt.Errorf("VaultClientConfig: Vault host is not specified")
148 }
149
150 // The presence of a token is always assumed to be valid, client errors
151 // will occur otherwise.
152 if c.Token != "" {
153 return nil
154 }
155
156 // This constructor does a bunch of validation internally so just let
157 // it do its thing and return any errors from that directly to the
158 // user.
159 ar, err := approle.NewAppRoleAuth(c.RoleId, &approle.SecretID{FromString: c.RoleSecret})
160 if err != nil {
161 return fmt.Errorf("VaultClientConfig: AppRole credentials invalid: %w", err)
162 }
163 c.AppRoleAuth = ar
164
165 return nil
166}
167
168// NewVaultClient will attempt to create a secrets.Client from the
169// passed config. Config can be nil, in which case an attempt will
170// be made to load the configuration from environment variables. See
171// VaultClientConfig for the expected names of those variables.
172func NewVaultClient(cfg *VaultClientConfig) (ClientManager, error) {
173 if cfg == nil {
174 cfg = &VaultClientConfig{}
175 }
176
177 if err := glos.UnmarshalEnvironment(cfg); err != nil {
178 return nil, err
179 }
180
181 if err := cfg.Validate(); err != nil {
182 return nil, err
183 }
184
185 vc, err := api.NewClient(api.DefaultConfig())
186 if err != nil {
187 return nil, fmt.Errorf("NewVaultClient: error building client config: %w", err)
188 }
189 vc.SetAddress(cfg.Host)
190
191 if cfg.Token != "" {
192 vc.SetToken(cfg.Token)
193 }
194
195 c := &VaultClient{
196 client: vc,
197 logical: vc.Logical(),
198 secrets: vaultRenewalMinHeap{},
199 notifications: make(chan Renewal, notificationChanLen),
200 auth: cfg.AppRoleAuth,
201 renewIncrement: cfg.Increment,
202 }
203
204 c.secrets.Init()
205
206 if c.renewIncrement == 0 {
207 c.renewIncrement = defaultIncrement
208 }
209
210 return c, nil
211}
212
213func (c *VaultClient) Notifications() <-chan Renewal {
214 return c.notifications
215}
216
217func (c *VaultClient) Authenticate(ctx context.Context) error {
218 if c.auth == nil {
219 return c.authToken(ctx)
220 } else {
221 return c.authAppRole(ctx)
222 }
223}
224
225func (c *VaultClient) authToken(ctx context.Context) error {
226 if c.client.Token() == "" {
227 return fmt.Errorf("Authenticate: unable to authenticate, neither token nor approle provided")
228 }
229
230 vctx, cancel := context.WithTimeout(ctx, defaultTimeout)
231 defer cancel()
232
233 secret, err := c.client.Auth().Token().LookupSelfWithContext(vctx)
234 if err != nil {
235 return err
236 }
237
238 // Looking up self does not return an auth token just a map of data
239 // about the current token. Convert this into a SecretAuth so that
240 // downstream renewal code does the right thing.
241 secret.Auth = &api.SecretAuth{}
242 if err := mapstructure.Decode(secret.Data, secret.Auth); err != nil {
243 return err
244 }
245 secret.Auth.ClientToken = c.client.Token()
246
247 c.makeHandle("login", secret)
248
249 return nil
250}
251
252func (c *VaultClient) authAppRole(ctx context.Context) error {
253 vctx, cancel := context.WithTimeout(ctx, defaultTimeout)
254 defer cancel()
255
256 s, err := c.client.Auth().Login(vctx, c.auth)
257 if err != nil {
258 return fmt.Errorf("Authenticate: error logging in to vault: %w", err)
259 }
260 c.makeHandle("login", s)
261 return err
262}
263
264// makeHandle creates a secret handle and schedules it for renewal if it
265// is renewable.
266func (c *VaultClient) makeHandle(name string, s *api.Secret) Handle {
267 h := &VaultHandle{
268 name: name,
269 critical: true, // Everything is critical unless marked otherwise
270 acquired: time.Now(),
271 secret: s,
272 }
273
274 // If this is renewable then schedule it for renewal
275 if (s.Auth != nil && s.Auth.Renewable) || s.Renewable {
276 c.Lock()
277 c.secrets.PushHeap(h)
278 c.Unlock()
279 }
280
281 return h
282}
283
284func (c *VaultClient) read(ctx context.Context, prefix, suffix string) (Handle, error) {
285 key := path.Join(prefix, suffix)
286
287 s, err := c.logical.ReadWithContext(ctx, key)
288 if err != nil {
289 return nil, fmt.Errorf("read: error reading from Vault: %w", err)
290 }
291
292 return c.makeHandle(key, s), nil
293}
294
295func (c *VaultClient) isRecoverableDbError(err error) bool {
296 var apiErr *api.ResponseError
297 return errors.Is(api.ErrSecretNotFound, err) ||
298 (errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusForbidden)
299}
300
301func (c *VaultClient) DatabaseCredential(ctx context.Context, suffix string) (*Credential, Handle, error) {
302 cred, hnd, err := c.databaseCredentialDynamic(ctx, suffix)
303 if err != nil {
304 if c.isRecoverableDbError(err) {
305 cred, hnd, err = c.databaseCredentialStatic(ctx, suffix)
306 }
307 }
308 return cred, hnd, err
309}
310
311func (c *VaultClient) databaseCredentialDynamic(ctx context.Context, suffix string) (*Credential, Handle, error) {
312 h, err := c.read(ctx, "database/creds", suffix)
313 if err != nil {
314 return nil, nil, err
315 }
316 vh := h.(*VaultHandle)
317
318 var d Credential
319 if err = mapstructure.Decode(vh.secret.Data, &d); err != nil {
320 return nil, nil, fmt.Errorf("databaseCredentialStatic: error decoding secret: %w", err)
321 }
322
323 return &d, h, nil
324}
325
326func (c *VaultClient) databaseCredentialStatic(ctx context.Context, suffix string) (*Credential, Handle, error) {
327 h, err := c.read(ctx, "database/static-creds", suffix)
328 if err != nil {
329 return nil, nil, err
330 }
331
332 var d Credential
333 if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil {
334 return nil, nil, fmt.Errorf("databaseCredentialStatic: error decoding secret: %w", err)
335 }
336
337 return &d, h, nil
338}
339
340func (c *VaultClient) Secret(ctx context.Context, suffix string, out any) (Handle, error) {
341 h, err := c.read(ctx, "kv/data", suffix)
342 if err != nil {
343 return nil, err
344 }
345
346 if err = mapstructure.Decode(h.(*VaultHandle).secret.Data["data"], out); err != nil {
347 return nil, err
348 }
349
350 return h, nil
351}
352
353func (c *VaultClient) WriteSecret(ctx context.Context, suffix string, in any) error {
354 inb, err := json.Marshal(in)
355 if err != nil {
356 return fmt.Errorf("WriteSecret: error encoding json: %w", err)
357 }
358
359 if _, err = c.logical.WriteBytesWithContext(ctx, path.Join("kv/data", suffix), inb); err != nil {
360 return fmt.Errorf("WriteSecret: error writing to vault: %w", err)
361 }
362
363 return nil
364}
365
366func (c *VaultClient) Destroy(h Handle) error {
367 c.Lock()
368 defer c.Unlock()
369
370 c.secrets.FindRemove(h.(*VaultHandle))
371
372 return nil
373}
374
375func (c *VaultClient) MakeNonCritical(h Handle) error {
376 h.(*VaultHandle).critical = false
377 return nil
378}
379
380func (c *VaultClient) renewAttempt(ctx context.Context) (next time.Duration) {
381 c.Lock()
382 defer c.Unlock()
383
384 // In the absence of any other time, run once a second
385 next = 5 * time.Second
386
387 if c.secrets.Len() < 1 {
388 return
389 }
390
391 for {
392 s := c.secrets.PopHeap()
393 if s.renewAfter() < renewalWindow {
394 // Underlying client does backoff and retry
395 c.notifications <- Renewal{
396 Name: s.name,
397 Critical: s.critical,
398 Time: time.Now(),
399 Error: s.renew(ctx, c.client, c.renewIncrement),
400 }
401 c.secrets.PushHeap(s)
402 } else {
403 c.secrets.PushHeap(s)
404 next = c.secrets.Root().renewAfter()
405 break
406 }
407 }
408
409 return next
410}
411
412func (c *VaultClient) Run(ctx context.Context, wg *sync.WaitGroup) error {
413 wg.Add(1)
414 defer wg.Done()
415
416 for {
417 sleepTime := c.renewAttempt(ctx)
418
419 select {
420 case <-time.After(sleepTime):
421 continue
422 case <-ctx.Done():
423 return nil
424 }
425 }
426}