package vault import ( "context" "fmt" "os" "path" "sync" "time" "github.com/hashicorp/vault/api" "github.com/hashicorp/vault/api/auth/approle" "github.com/mitchellh/mapstructure" ) type VaultClient interface { LoginApproleEnv(c context.Context) error LoginApprole(c context.Context, roleId string, secretId string) error DbStaticCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) DbCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) KV(c context.Context, suffix string, out interface{}) (*VaultSecret, error) KVApiKey(c context.Context, suffix string) (*VaultApiKey, error) KVCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) Destroy(HasSecret) Run(ctx context.Context, wg *sync.WaitGroup) error } type HasSecret interface { VaultSecret() *VaultSecret } // VaultSecret is an opaque reference to a secret from Vault. It is // meant to be given to the Destroy function to check-in and destroy // unneeded credentials. Everything returned from the client has a // VaultSecret and implements HasSecret for that purpose. If the // credential is not renewable then destroying it is a no-op. type VaultSecret struct { s *api.Secret n string } func (s *VaultSecret) VaultSecret() *VaultSecret { return s } type VaultApiKey struct { Key string `json:"key"` s *VaultSecret } func (k *VaultApiKey) VaultSecret() *VaultSecret { return k.s } type VaultUsernamePassword struct { Username string `json:"username"` Password string `json:"password"` s *VaultSecret } func (k *VaultUsernamePassword) VaultSecret() *VaultSecret { return k.s } type Renewal struct { RenewedAt time.Time Name string } type vaultClient struct { sync.Mutex c *api.Client lc *api.Logical wg *sync.WaitGroup watcherDone chan error watchers map[string]*api.LifetimeWatcher renewInfo chan *Renewal } // NewClientEnv is a convenience function to create a new VaultClient // based on the environment. // // The following environment variables are used and must be present: // // VAULT_ADDR - URL to Vault server (of form https://host:port/) // func NewClientEnv(renewInfo chan *Renewal) (VaultClient, error) { vaultHost := os.Getenv("VAULT_ADDR") if vaultHost == "" { return nil, fmt.Errorf("NewClientEnv: VAULT_ADDR is not set in environment") } vc, err := NewVaultClient(vaultHost, renewInfo) if err != nil { return nil, fmt.Errorf("NewClientEnv: error creating client %w", err) } return vc, nil } func NewVaultClient(host string, renewInfo chan *Renewal) (VaultClient, error) { cfg := api.DefaultConfig() cfg.Address = host c, err := api.NewClient(cfg) if err != nil { return nil, err } return &vaultClient{ c: c, lc: c.Logical(), renewInfo: renewInfo, watcherDone: make(chan error, 10), watchers: map[string]*api.LifetimeWatcher{}, }, nil } func (c *vaultClient) watchWatcher(w *api.LifetimeWatcher, name string) { c.wg.Add(1) defer c.wg.Done() for { select { case err := <-w.DoneCh(): if err != nil { c.watcherDone <- err } return case r := <-w.RenewCh(): // Report this so consumers can do their own reporting, if not // provided we just read this to drain the chan and throw it away. if c.renewInfo != nil { c.renewInfo <- &Renewal{ Name: name, RenewedAt: r.RenewedAt, } } } } } func (c *vaultClient) addWatcher(name string, s *api.Secret) error { w, err := c.c.NewLifetimeWatcher(&api.LifetimeWatcherInput{ Secret: s, }) if err != nil { return err } c.Lock() c.watchers[name] = w c.Unlock() go w.Start() go c.watchWatcher(w, name) return nil } func (c *vaultClient) read(ctx context.Context, prefix, suffix string) (*api.Secret, string, error) { key := path.Join(prefix, suffix) s, err := c.lc.ReadWithContext(ctx, key) if err != nil { return nil, "", err } if s.Renewable { return s, key, c.addWatcher(key, s) } return s, key, nil } func (c *vaultClient) stop() { c.Lock() defer c.Unlock() for _, w := range c.watchers { w.Stop() } } func (c *vaultClient) Run(ctx context.Context, wg *sync.WaitGroup) error { c.Lock() c.wg = wg c.Unlock() c.wg.Add(1) defer c.wg.Done() for { select { case <-ctx.Done(): c.stop() return nil case err := <-c.watcherDone: c.stop() return err } } } func (c *vaultClient) Destroy(s HasSecret) { vs := s.VaultSecret() if vs == nil || vs.n == "" || vs.s == nil { return } c.Lock() defer c.Unlock() if w, ok := c.watchers[vs.n]; ok { delete(c.watchers, vs.n) w.Stop() } // TODO: Delete dynamic credentials like DB sessions from Vault // Drop references to the secret so that even if the client holds on to // it we free the RAM. vs.s = nil vs.n = "" } func (c *vaultClient) LoginApprole(ctx context.Context, roleId string, secretId string) error { a, err := approle.NewAppRoleAuth(roleId, &approle.SecretID{FromString: secretId}) if err != nil { return err } s, err := c.c.Auth().Login(ctx, a) if err != nil { return err } // This credential can not be destroyed like the others return c.addWatcher("login", s) } func (c *vaultClient) DbStaticCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { s, k, err := c.read(ctx, "database/static-creds", suffix) if err != nil { return nil, err } var d VaultUsernamePassword if err = mapstructure.Decode(s.Data, &d); err != nil { return nil, err } d.s = &VaultSecret{s: s, n: k} return &d, nil } func (c *vaultClient) DbCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { s, k, err := c.read(ctx, "database/creds", suffix) if err != nil { return nil, err } var d VaultUsernamePassword if err = mapstructure.Decode(s.Data, &d); err != nil { return nil, err } d.s = &VaultSecret{s: s, n: k} return &d, nil } func (c *vaultClient) KV(ctx context.Context, suffix string, out interface{}) (*VaultSecret, error) { s, k, err := c.read(ctx, "kv/data", suffix) if err != nil { return nil, err } if err = mapstructure.Decode(s.Data["data"], out); err != nil { return nil, err } return &VaultSecret{s: s, n: k}, nil } func (c *vaultClient) KVApiKey(ctx context.Context, suffix string) (*VaultApiKey, error) { var ak VaultApiKey s, err := c.KV(ctx, suffix, &ak) if err != nil { return nil, err } ak.s = s return &ak, nil } func (c *vaultClient) KVCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { var ak VaultUsernamePassword s, err := c.KV(ctx, suffix, &ak) if err != nil { return nil, err } ak.s = s return &ak, nil } // LoginApproleEnv is a convenience function to login using AppRole // authentication and fetching the role id and secret id from the // environment. // // The following environment variables are used and must be present: // // VAULT_ROLE_ID - Role ID used for Approle authentication // VAULT_SECRET_ID - Secret ID used for Approle authentication // func (c *vaultClient) LoginApproleEnv(ctx context.Context) error { roleId := os.Getenv("VAULT_ROLE_ID") if roleId == "" { return fmt.Errorf("NewApproleClientEnv: VAULT_ROLE_ID is not set in environment") } secretId := os.Getenv("VAULT_SECRET_ID") if secretId == "" { return fmt.Errorf("NewApproleClientEnv: VAULT_SECRET_ID is not set in environment") } if err := c.LoginApprole(ctx, roleId, secretId); err != nil { return fmt.Errorf("NewApproleClientEnv: error logging in to vault %w", err) } return nil }