package secrets import ( "container/heap" "context" "encoding/base64" "encoding/json" "errors" "fmt" "net/http" "path" "sync" "time" glos "code.crute.us/mcrute/golib/os" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api/auth/approle" "github.com/mitchellh/mapstructure" ) const ( renewalStartPercent = 0.8 notificationChanLen = 100 defaultIncrement = 30 * 60 // 30 minutes (as seconds) defaultTimeout = 10 * time.Second renewalWindow = 30 * time.Second ) type vaultRenewalMinHeap []*VaultHandle func (h vaultRenewalMinHeap) Len() int { return len(h) } func (h vaultRenewalMinHeap) Less(i, j int) bool { return h[i].renewAfter() < h[j].renewAfter() } func (h vaultRenewalMinHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *vaultRenewalMinHeap) Push(x any) { *h = append(*h, x.(*VaultHandle)) } func (h vaultRenewalMinHeap) Root() *VaultHandle { return h[0] } // Convenience methods to hide the collections/heap stuff from users func (h *vaultRenewalMinHeap) PopHeap() *VaultHandle { return heap.Pop(h).(*VaultHandle) } func (h *vaultRenewalMinHeap) PushHeap(i *VaultHandle) { heap.Push(h, i) } func (h *vaultRenewalMinHeap) Init() { heap.Init(h) } func (h *vaultRenewalMinHeap) Pop() any { old := *h n := len(old) item := old[n-1] old[n-1] = nil // avoid memory leak *h = old[0 : n-1] return item } func (h *vaultRenewalMinHeap) FindRemove(handle *VaultHandle) bool { for i, hnd := range *h { if hnd.name == handle.name { heap.Remove(h, i) return true } } return false } type VaultHandle struct { name string critical bool acquired time.Time secret *api.Secret } var _ Handle = (*VaultHandle)(nil) func (h *VaultHandle) isAuthToken() bool { return h.secret.Auth != nil } func (h *VaultHandle) renewAfter() time.Duration { after := float64(h.leaseDuration().Nanoseconds()) * renewalStartPercent afterTime := h.acquired.Add(time.Duration(after)) return afterTime.Sub(time.Now()).Round(time.Second) } func (h *VaultHandle) leaseDuration() time.Duration { duration := 
time.Duration(h.secret.LeaseDuration) * time.Second if h.isAuthToken() { return time.Duration(h.secret.Auth.LeaseDuration) * time.Second } return duration } func (h *VaultHandle) renew(ctx context.Context, c VaultServiceClient, inc int) (err error) { var s *api.Secret vctx, cancel := context.WithTimeout(ctx, defaultTimeout) defer cancel() if h.isAuthToken() { s, err = c.Auth().Token().RenewTokenAsSelfWithContext(vctx, h.secret.Auth.ClientToken, inc) if err != nil { return err } } else { s, err = c.Sys().RenewWithContext(vctx, h.secret.LeaseID, inc) if err != nil { return err } } h.secret = s h.acquired = time.Now() return nil } func (h *VaultHandle) Reference() string { return h.name } type VaultServiceClient interface { Auth() *api.Auth Sys() *api.Sys Token() string } type VaultClient struct { sync.Mutex client VaultServiceClient logical *api.Logical auth *approle.AppRoleAuth secrets vaultRenewalMinHeap renewIncrement int notifications chan Renewal } var _ Client = (*VaultClient)(nil) var _ ClientManager = (*VaultClient)(nil) type VaultClientConfig struct { Host string `env:"VAULT_ADDR"` Token string `env:"VAULT_TOKEN"` RoleId string `env:"VAULT_ROLE_ID"` RoleSecret string `env:"VAULT_SECRET_ID"` Increment int `env:"VAULT_INCREMENT"` AppRoleAuth *approle.AppRoleAuth } func (c *VaultClientConfig) Validate() error { if c.Host == "" { return fmt.Errorf("VaultClientConfig: Vault host is not specified") } // The presence of a token is always assumed to be valid, client errors // will occur otherwise. if c.Token != "" { return nil } // This constructor does a bunch of validation internally so just let // it do its thing and return any errors from that directly to the // user. 
ar, err := approle.NewAppRoleAuth(c.RoleId, &approle.SecretID{FromString: c.RoleSecret}) if err != nil { return fmt.Errorf("VaultClientConfig: AppRole credentials invalid: %w", err) } c.AppRoleAuth = ar return nil } // NewVaultClient will attempt to create a secrets.Client from the // passed config. Config can be nil, in which case an attempt will // be made to load the configuration from environment variables. See // VaultClientConfig for the expected names of those variables. func NewVaultClient(cfg *VaultClientConfig) (ClientManager, error) { if cfg == nil { cfg = &VaultClientConfig{} } if err := glos.UnmarshalEnvironment(cfg); err != nil { return nil, err } if err := cfg.Validate(); err != nil { return nil, err } vc, err := api.NewClient(api.DefaultConfig()) if err != nil { return nil, fmt.Errorf("NewVaultClient: error building client config: %w", err) } vc.SetAddress(cfg.Host) if cfg.Token != "" { vc.SetToken(cfg.Token) } c := &VaultClient{ client: vc, logical: vc.Logical(), secrets: vaultRenewalMinHeap{}, notifications: make(chan Renewal, notificationChanLen), auth: cfg.AppRoleAuth, renewIncrement: cfg.Increment, } c.secrets.Init() if c.renewIncrement == 0 { c.renewIncrement = defaultIncrement } return c, nil } func (c *VaultClient) Notifications() <-chan Renewal { return c.notifications } func (c *VaultClient) Authenticate(ctx context.Context) error { if c.auth == nil { return c.authToken(ctx) } else { return c.authAppRole(ctx) } } // VaultToken is not part of the official API but is exposed for // clients that need to gain access to this for some reason. There are // no compatibility guarantees with this method and it's use limits // portability. 
func (c *VaultClient) VaultToken() string {
	return c.client.Token()
}

// authToken looks up the token already configured on the client and
// registers it as the "login" handle, which makeHandle schedules for
// renewal when it is renewable.
func (c *VaultClient) authToken(ctx context.Context) error {
	if c.client.Token() == "" {
		return errors.New("Authenticate: unable to authenticate, neither token nor approle provided")
	}

	vctx, cancel := context.WithTimeout(ctx, defaultTimeout)
	defer cancel()

	secret, err := c.client.Auth().Token().LookupSelfWithContext(vctx)
	if err != nil {
		return fmt.Errorf("Authenticate: error looking up token: %w", err)
	}
	// Assigning secret.Auth below would panic on a nil response.
	if secret == nil {
		return errors.New("Authenticate: token lookup returned no data")
	}

	// Looking up self does not return an auth block, just a map of data
	// about the current token. Convert it into a SecretAuth so downstream
	// renewal code treats this handle as a token.
	// NOTE(review): the lookup data probably has no key that decodes into
	// LeaseDuration, leaving it 0 so the handle is renewed on the first
	// renewal pass — confirm this is intended.
	secret.Auth = &api.SecretAuth{}
	if err := mapstructure.Decode(secret.Data, secret.Auth); err != nil {
		return fmt.Errorf("Authenticate: error decoding token data: %w", err)
	}
	secret.Auth.ClientToken = c.client.Token()

	c.makeHandle("login", secret)
	return nil
}

// authAppRole logs in with the configured AppRole credentials and registers
// the resulting secret as the "login" handle.
func (c *VaultClient) authAppRole(ctx context.Context) error {
	vctx, cancel := context.WithTimeout(ctx, defaultTimeout)
	defer cancel()

	s, err := c.client.Auth().Login(vctx, c.auth)
	if err != nil {
		return fmt.Errorf("Authenticate: error logging in to vault: %w", err)
	}
	// makeHandle dereferences the secret, so refuse a nil one explicitly.
	if s == nil {
		return errors.New("Authenticate: login returned no secret")
	}

	c.makeHandle("login", s)
	return nil
}

// makeHandle creates a secret handle and schedules it for renewal if it
// is renewable.
// Handles start out critical; MakeNonCritical changes that.
func (c *VaultClient) makeHandle(name string, s *api.Secret) Handle {
	h := &VaultHandle{
		name:     name,
		critical: true, // Everything is critical unless marked otherwise
		acquired: time.Now(),
		secret:   s,
	}

	// If this is renewable then schedule it for renewal
	if (s.Auth != nil && s.Auth.Renewable) || s.Renewable {
		c.Lock()
		c.secrets.PushHeap(h)
		c.Unlock()
	}

	return h
}

// writeHandle writes data to prefix/suffix (just suffix when prefix is
// empty) and wraps Vault's response in a handle.
func (c *VaultClient) writeHandle(ctx context.Context, prefix, suffix string, data map[string]any) (Handle, error) {
	key := suffix
	if prefix != "" {
		key = path.Join(prefix, suffix)
	}

	s, err := c.logical.WriteWithContext(ctx, key, data)
	if err != nil {
		return nil, fmt.Errorf("writeHandle: error writing to Vault: %w", err)
	}
	// A write can succeed without returning a secret; makeHandle would
	// dereference nil.
	if s == nil {
		return nil, fmt.Errorf("writeHandle: Vault returned no secret for %q", key)
	}

	return c.makeHandle(key, s), nil
}

// read reads prefix/suffix (just suffix when prefix is empty) and wraps the
// result in a handle. A path with no secret is reported as an error wrapping
// api.ErrSecretNotFound.
func (c *VaultClient) read(ctx context.Context, prefix, suffix string) (Handle, error) {
	key := suffix
	if prefix != "" {
		key = path.Join(prefix, suffix)
	}

	s, err := c.logical.ReadWithContext(ctx, key)
	if err != nil {
		return nil, fmt.Errorf("read: error reading from Vault: %w", err)
	}
	// The Vault client reports a missing path as (nil, nil) rather than an
	// error; surface it so callers (e.g. DatabaseCredential's fallback) can
	// react instead of makeHandle panicking on a nil secret.
	if s == nil {
		return nil, fmt.Errorf("read: no secret at %q: %w", key, api.ErrSecretNotFound)
	}

	return c.makeHandle(key, s), nil
}

// isRecoverableDbError reports whether a dynamic credential lookup failed in
// a way that warrants trying the static-creds path instead: the secret was
// not found, or Vault answered 403 Forbidden.
func (c *VaultClient) isRecoverableDbError(err error) bool {
	if errors.Is(err, api.ErrSecretNotFound) {
		return true
	}
	var apiErr *api.ResponseError
	return errors.As(err, &apiErr) && apiErr.StatusCode == http.StatusForbidden
}

// DatabaseCredential fetches dynamic database credentials for suffix,
// falling back to static credentials when the dynamic lookup fails with a
// recoverable error.
func (c *VaultClient) DatabaseCredential(ctx context.Context, suffix string) (*Credential, Handle, error) {
	cred, hnd, err := c.databaseCredentialDynamic(ctx, suffix)
	if err != nil && c.isRecoverableDbError(err) {
		return c.databaseCredentialStatic(ctx, suffix)
	}
	return cred, hnd, err
}

// databaseCredentialDynamic reads database/creds/<suffix>.
func (c *VaultClient) databaseCredentialDynamic(ctx context.Context, suffix string) (*Credential, Handle, error) {
	h, err := c.read(ctx, "database/creds", suffix)
	if err != nil {
		return nil, nil, err
	}

	var d Credential
	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, nil, fmt.Errorf("databaseCredentialDynamic: error decoding secret: %w", err)
	}

	return &d, h, nil
}

// databaseCredentialStatic reads database/static-creds/<suffix>.
func (c *VaultClient) databaseCredentialStatic(ctx context.Context, suffix string) (*Credential, Handle, error) {
	h, err := c.read(ctx, "database/static-creds", suffix)
	if err != nil {
		return nil, nil, err
	}

	var d Credential
	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, nil, fmt.Errorf("databaseCredentialStatic: error decoding secret: %w", err)
	}

	return &d, h, nil
}

// Secret reads kv/data/<suffix> and decodes its "data" field into out.
func (c *VaultClient) Secret(ctx context.Context, suffix string, out any) (Handle, error) {
	h, err := c.read(ctx, "kv/data", suffix)
	if err != nil {
		return nil, err
	}

	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data["data"], out); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, err
	}

	return h, nil
}

// RawSecret reads the literal path and decodes its "data" field into out.
func (c *VaultClient) RawSecret(ctx context.Context, path string, out any) (Handle, error) {
	h, err := c.read(ctx, "", path)
	if err != nil {
		return nil, err
	}

	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data["data"], out); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, err
	}

	return h, nil
}

// AWSAssumeRoleSimple calls AWSAssumeRole with a unique session name derived
// from name and a one hour TTL.
func (c *VaultClient) AWSAssumeRoleSimple(ctx context.Context, name string) (*AWSCredential, Handle, error) {
	return c.AWSAssumeRole(ctx, name, fmt.Sprintf("%s-%d", name, time.Now().UnixNano()), time.Hour)
}

// AWSAssumeRole requests STS credentials for role name via aws/sts.
func (c *VaultClient) AWSAssumeRole(ctx context.Context, name string, sessionName string, ttl time.Duration) (*AWSCredential, Handle, error) {
	h, err := c.writeHandle(ctx, "aws/sts", name, map[string]any{
		"role_session_name": sessionName,
		"ttl":               ttl.String(),
	})
	if err != nil {
		return nil, nil, err
	}

	var d AWSCredential
	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, nil, fmt.Errorf("AWSAssumeRole: error decoding secret: %w", err)
	}

	return &d, h, nil
}

// AWSIAMUser reads IAM user credentials for name from aws/creds.
func (c *VaultClient) AWSIAMUser(ctx context.Context, name string) (*AWSCredential, Handle, error) {
	h, err := c.read(ctx, "aws/creds", name)
	if err != nil {
		return nil, nil, err
	}

	var d AWSCredential
	if err = mapstructure.Decode(h.(*VaultHandle).secret.Data, &d); err != nil {
		_ = c.Destroy(h) // the caller never receives h, so stop renewing it
		return nil, nil, fmt.Errorf("AWSIAMUser: error decoding secret: %w", err)
	}

	return &d, h, nil
}

// WriteSecret JSON-encodes in and writes it to kv/data/<suffix>.
// NOTE(review): in is sent as the whole request body while Secret reads the
// nested "data" field; callers presumably wrap their payload as
// {"data": ...} — confirm.
func (c *VaultClient) WriteSecret(ctx context.Context, suffix string, in any) error {
	inb, err := json.Marshal(in)
	if err != nil {
		return fmt.Errorf("WriteSecret: error encoding json: %w", err)
	}

	if _, err = c.logical.WriteBytesWithContext(ctx, path.Join("kv/data", suffix), inb); err != nil {
		return fmt.Errorf("WriteSecret: error writing to vault: %w", err)
	}

	return nil
}

// Encrypt base64-encodes data and encrypts it with transit key suffix,
// returning the ciphertext string from Vault's response.
func (c *VaultClient) Encrypt(ctx context.Context, suffix string, data []byte) (string, error) {
	s, err := c.logical.WriteWithContext(
		ctx,
		path.Join("transit/encrypt", suffix),
		map[string]any{"plaintext": base64.StdEncoding.EncodeToString(data)},
	)
	if err != nil {
		return "", fmt.Errorf("Encrypt: unable to write to vault: %w", err)
	}

	// Checked assertion: an empty or unexpected response must not panic.
	ct, ok := "", false
	if s != nil {
		ct, ok = s.Data["ciphertext"].(string)
	}
	if !ok {
		return "", errors.New("Encrypt: vault response contained no ciphertext")
	}

	return ct, nil
}

// Decrypt decrypts data with transit key suffix and base64-decodes the
// resulting plaintext.
func (c *VaultClient) Decrypt(ctx context.Context, suffix, data string) ([]byte, error) {
	s, err := c.logical.WriteWithContext(
		ctx,
		path.Join("transit/decrypt", suffix),
		map[string]any{"ciphertext": data},
	)
	if err != nil {
		return nil, fmt.Errorf("Decrypt: unable to write to vault: %w", err)
	}

	// Checked assertion: an empty or unexpected response must not panic.
	pt, ok := "", false
	if s != nil {
		pt, ok = s.Data["plaintext"].(string)
	}
	if !ok {
		return nil, errors.New("Decrypt: vault response contained no plaintext")
	}

	d, err := base64.StdEncoding.DecodeString(pt)
	if err != nil {
		return nil, fmt.Errorf("Decrypt: unable to base64 decode plaintext: %w", err)
	}

	return d, nil
}

// Destroy stops renewing h. It always returns nil.
func (c *VaultClient) Destroy(h Handle) error {
	c.Lock()
	defer c.Unlock()

	c.secrets.FindRemove(h.(*VaultHandle))
	return nil
}

// MakeNonCritical marks h as non-critical in future renewal notifications.
// The lock is taken because renewAttempt reads the flag while holding it.
func (c *VaultClient) MakeNonCritical(h Handle) error {
	c.Lock()
	defer c.Unlock()

	h.(*VaultHandle).critical = false
	return nil
}

// renewAttempt renews every handle currently inside the renewal window (each
// at most once), publishes one Renewal per attempt, and returns how long to
// wait before the next attempt.
func (c *VaultClient) renewAttempt(ctx context.Context) time.Duration {
	const (
		idleInterval = 5 * time.Second // nothing is scheduled for renewal
		minInterval  = time.Second     // floor so handles that stay due cannot spin
	)

	c.Lock()
	if c.secrets.Len() < 1 {
		c.Unlock()
		return idleInterval
	}

	// Collect the due handles first. A handle whose renewal fails (or whose
	// new lease is still inside the window) stays at the root and would
	// otherwise be popped and renewed again without end.
	var due []*VaultHandle
	for c.secrets.Len() > 0 && c.secrets.Root().renewAfter() < renewalWindow {
		due = append(due, c.secrets.PopHeap())
	}

	results := make([]Renewal, 0, len(due))
	for _, s := range due {
		// Underlying client does backoff and retry
		results = append(results, Renewal{
			Name:     s.name,
			Critical: s.critical,
			Time:     time.Now(),
			Error:    s.renew(ctx, c.client, c.renewIncrement),
		})
		c.secrets.PushHeap(s)
	}

	// Every popped handle was pushed back, so the heap is non-empty here.
	next := c.secrets.Root().renewAfter()
	c.Unlock()

	// Publish outside the lock: if the consumer falls behind and the channel
	// fills, only this loop waits rather than makeHandle/Destroy as well.
	for _, r := range results {
		select {
		case c.notifications <- r:
		case <-ctx.Done():
			return minInterval
		}
	}

	if next < minInterval {
		next = minInterval
	}
	return next
}

// Run performs renewal attempts until ctx is cancelled, then returns nil.
// NOTE(review): wg.Add(1) happens inside Run, so if Run is started with
// `go`, a concurrent wg.Wait may return before the Add — callers should
// confirm their startup ordering.
func (c *VaultClient) Run(ctx context.Context, wg *sync.WaitGroup) error {
	wg.Add(1)
	defer wg.Done()

	for {
		timer := time.NewTimer(c.renewAttempt(ctx))
		select {
		case <-timer.C:
		case <-ctx.Done():
			timer.Stop()
			return nil
		}
	}
}