diff options
Diffstat (limited to 'vault/client.go')
-rw-r--r-- | vault/client.go | 330 |
1 files changed, 330 insertions, 0 deletions
diff --git a/vault/client.go b/vault/client.go new file mode 100644 index 0000000..2f645d4 --- /dev/null +++ b/vault/client.go | |||
@@ -0,0 +1,330 @@ | |||
1 | package vault | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | "fmt" | ||
6 | "os" | ||
7 | "path" | ||
8 | "sync" | ||
9 | "time" | ||
10 | |||
11 | "github.com/hashicorp/vault/api" | ||
12 | "github.com/hashicorp/vault/api/auth/approle" | ||
13 | "github.com/mitchellh/mapstructure" | ||
14 | ) | ||
15 | |||
16 | type VaultClient interface { | ||
17 | LoginApprole(c context.Context, roleId string, secretId string) error | ||
18 | |||
19 | DbStaticCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) | ||
20 | DbCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) | ||
21 | |||
22 | KV(c context.Context, suffix string, out interface{}) (*VaultSecret, error) | ||
23 | KVApiKey(c context.Context, suffix string) (*VaultApiKey, error) | ||
24 | KVCredential(c context.Context, suffix string) (*VaultUsernamePassword, error) | ||
25 | |||
26 | Destroy(HasSecret) | ||
27 | Run(ctx context.Context, wg *sync.WaitGroup) error | ||
28 | } | ||
29 | |||
30 | type HasSecret interface { | ||
31 | VaultSecret() *VaultSecret | ||
32 | } | ||
33 | |||
34 | // VaultSecret is an opaque reference to a secret from Vault. It is | ||
35 | // meant to be given to the Destroy function to check-in and destroy | ||
36 | // unneeded credentials. Everything returned from the client has a | ||
37 | // VaultSecret and implements HasSecret for that purpose. If the | ||
38 | // credential is not renewable then destroying it is a no-op. | ||
39 | type VaultSecret struct { | ||
40 | s *api.Secret | ||
41 | n string | ||
42 | } | ||
43 | |||
44 | func (s *VaultSecret) VaultSecret() *VaultSecret { | ||
45 | return s | ||
46 | } | ||
47 | |||
48 | type VaultApiKey struct { | ||
49 | Key string `json:"key"` | ||
50 | s *VaultSecret | ||
51 | } | ||
52 | |||
53 | func (k *VaultApiKey) VaultSecret() *VaultSecret { | ||
54 | return k.s | ||
55 | } | ||
56 | |||
57 | type VaultUsernamePassword struct { | ||
58 | Username string `json:"username"` | ||
59 | Password string `json:"password"` | ||
60 | s *VaultSecret | ||
61 | } | ||
62 | |||
63 | func (k *VaultUsernamePassword) VaultSecret() *VaultSecret { | ||
64 | return k.s | ||
65 | } | ||
66 | |||
67 | type Renewal struct { | ||
68 | RenewedAt time.Time | ||
69 | Name string | ||
70 | } | ||
71 | |||
72 | type vaultClient struct { | ||
73 | sync.Mutex | ||
74 | c *api.Client | ||
75 | lc *api.Logical | ||
76 | wg *sync.WaitGroup | ||
77 | watcherDone chan error | ||
78 | watchers map[string]*api.LifetimeWatcher | ||
79 | renewInfo chan *Renewal | ||
80 | } | ||
81 | |||
82 | // NewApproleClientEnv is a convenience function to create a new | ||
83 | // VaultClient based on the environment, start it, and login using | ||
84 | // Approle authentication. | ||
85 | // | ||
86 | // The following environment variables are used and must be present: | ||
87 | // | ||
88 | // VAULT_ADDR - URL to Vault server (of form https://host:port/) | ||
89 | // VAULT_ROLE_ID - Role ID used for Approle authentication | ||
90 | // VAULT_SECRET_ID - Secret ID used for Approle authentication | ||
91 | // | ||
92 | func NewApproleClientEnv(ctx context.Context, wg *sync.WaitGroup, renewInfo chan *Renewal) (VaultClient, error) { | ||
93 | vaultHost := os.Getenv("VAULT_ADDR") | ||
94 | if vaultHost == "" { | ||
95 | return nil, fmt.Errorf("NewApproleClientEnv: VAULT_ADDR is not set in environment") | ||
96 | } | ||
97 | |||
98 | roleId := os.Getenv("VAULT_ROLE_ID") | ||
99 | if roleId == "" { | ||
100 | return nil, fmt.Errorf("NewApproleClientEnv: VAULT_ROLE_ID is not set in environment") | ||
101 | } | ||
102 | |||
103 | secretId := os.Getenv("VAULT_SECRET_ID") | ||
104 | if secretId == "" { | ||
105 | return nil, fmt.Errorf("NewApproleClientEnv: VAULT_SECRET_ID is not set in environment") | ||
106 | } | ||
107 | |||
108 | vc, err := NewVaultClient(vaultHost, renewInfo) | ||
109 | if err != nil { | ||
110 | return nil, fmt.Errorf("NewApproleClientEnv: error creating client %w", err) | ||
111 | } | ||
112 | |||
113 | go vc.Run(ctx, wg) | ||
114 | |||
115 | if err = vc.LoginApprole(ctx, roleId, secretId); err != nil { | ||
116 | return nil, fmt.Errorf("NewApproleClientEnv: error logging in to vault %w", err) | ||
117 | } | ||
118 | |||
119 | return vc, nil | ||
120 | } | ||
121 | |||
122 | func NewVaultClient(host string, renewInfo chan *Renewal) (VaultClient, error) { | ||
123 | cfg := api.DefaultConfig() | ||
124 | cfg.Address = host | ||
125 | |||
126 | c, err := api.NewClient(cfg) | ||
127 | if err != nil { | ||
128 | return nil, err | ||
129 | } | ||
130 | |||
131 | return &vaultClient{ | ||
132 | c: c, | ||
133 | lc: c.Logical(), | ||
134 | renewInfo: renewInfo, | ||
135 | watcherDone: make(chan error, 10), | ||
136 | watchers: map[string]*api.LifetimeWatcher{}, | ||
137 | }, nil | ||
138 | } | ||
139 | |||
140 | func (c *vaultClient) watchWatcher(w *api.LifetimeWatcher, name string) { | ||
141 | c.wg.Add(1) | ||
142 | defer c.wg.Done() | ||
143 | |||
144 | for { | ||
145 | select { | ||
146 | case err := <-w.DoneCh(): | ||
147 | if err != nil { | ||
148 | c.watcherDone <- err | ||
149 | } | ||
150 | return | ||
151 | case r := <-w.RenewCh(): | ||
152 | // Report this so consumers can do their own reporting, if not | ||
153 | // provided we just read this to drain the chan and throw it away. | ||
154 | if c.renewInfo != nil { | ||
155 | c.renewInfo <- &Renewal{ | ||
156 | Name: name, | ||
157 | RenewedAt: r.RenewedAt, | ||
158 | } | ||
159 | } | ||
160 | } | ||
161 | } | ||
162 | } | ||
163 | |||
164 | func (c *vaultClient) addWatcher(name string, s *api.Secret) error { | ||
165 | w, err := c.c.NewLifetimeWatcher(&api.LifetimeWatcherInput{ | ||
166 | Secret: s, | ||
167 | }) | ||
168 | if err != nil { | ||
169 | return err | ||
170 | } | ||
171 | |||
172 | c.Lock() | ||
173 | c.watchers[name] = w | ||
174 | c.Unlock() | ||
175 | |||
176 | go w.Start() | ||
177 | go c.watchWatcher(w, name) | ||
178 | |||
179 | return nil | ||
180 | } | ||
181 | |||
182 | func (c *vaultClient) read(ctx context.Context, prefix, suffix string) (*api.Secret, string, error) { | ||
183 | key := path.Join(prefix, suffix) | ||
184 | |||
185 | s, err := c.lc.ReadWithContext(ctx, key) | ||
186 | if err != nil { | ||
187 | return nil, "", err | ||
188 | } | ||
189 | |||
190 | if s.Renewable { | ||
191 | return s, key, c.addWatcher(key, s) | ||
192 | } | ||
193 | |||
194 | return s, key, nil | ||
195 | } | ||
196 | |||
197 | func (c *vaultClient) stop() { | ||
198 | c.Lock() | ||
199 | defer c.Unlock() | ||
200 | |||
201 | for _, w := range c.watchers { | ||
202 | w.Stop() | ||
203 | } | ||
204 | } | ||
205 | |||
206 | func (c *vaultClient) Run(ctx context.Context, wg *sync.WaitGroup) error { | ||
207 | c.Lock() | ||
208 | c.wg = wg | ||
209 | c.Unlock() | ||
210 | |||
211 | c.wg.Add(1) | ||
212 | defer c.wg.Done() | ||
213 | |||
214 | for { | ||
215 | select { | ||
216 | case <-ctx.Done(): | ||
217 | c.stop() | ||
218 | return nil | ||
219 | case err := <-c.watcherDone: | ||
220 | c.stop() | ||
221 | return err | ||
222 | } | ||
223 | } | ||
224 | } | ||
225 | |||
226 | func (c *vaultClient) Destroy(s HasSecret) { | ||
227 | vs := s.VaultSecret() | ||
228 | if vs == nil || vs.n == "" || vs.s == nil { | ||
229 | return | ||
230 | } | ||
231 | |||
232 | c.Lock() | ||
233 | defer c.Unlock() | ||
234 | |||
235 | if w, ok := c.watchers[vs.n]; ok { | ||
236 | delete(c.watchers, vs.n) | ||
237 | w.Stop() | ||
238 | } | ||
239 | |||
240 | // TODO: Delete dynamic credentials like DB sessions from Vault | ||
241 | |||
242 | // Drop references to the secret so that even if the client holds on to | ||
243 | // it we free the RAM. | ||
244 | vs.s = nil | ||
245 | vs.n = "" | ||
246 | } | ||
247 | |||
248 | func (c *vaultClient) LoginApprole(ctx context.Context, roleId string, secretId string) error { | ||
249 | a, err := approle.NewAppRoleAuth(roleId, &approle.SecretID{FromString: secretId}) | ||
250 | if err != nil { | ||
251 | return err | ||
252 | } | ||
253 | |||
254 | s, err := c.c.Auth().Login(ctx, a) | ||
255 | if err != nil { | ||
256 | return err | ||
257 | } | ||
258 | |||
259 | // This credential can not be destroyed like the others | ||
260 | return c.addWatcher("login", s) | ||
261 | } | ||
262 | |||
263 | func (c *vaultClient) DbStaticCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { | ||
264 | s, k, err := c.read(ctx, "database/static-creds", suffix) | ||
265 | if err != nil { | ||
266 | return nil, err | ||
267 | } | ||
268 | |||
269 | var d VaultUsernamePassword | ||
270 | if err = mapstructure.Decode(s.Data, &d); err != nil { | ||
271 | return nil, err | ||
272 | } | ||
273 | |||
274 | d.s = &VaultSecret{s: s, n: k} | ||
275 | |||
276 | return &d, nil | ||
277 | } | ||
278 | |||
279 | func (c *vaultClient) DbCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { | ||
280 | s, k, err := c.read(ctx, "database/creds", suffix) | ||
281 | if err != nil { | ||
282 | return nil, err | ||
283 | } | ||
284 | |||
285 | var d VaultUsernamePassword | ||
286 | if err = mapstructure.Decode(s.Data, &d); err != nil { | ||
287 | return nil, err | ||
288 | } | ||
289 | |||
290 | d.s = &VaultSecret{s: s, n: k} | ||
291 | |||
292 | return &d, nil | ||
293 | } | ||
294 | |||
295 | func (c *vaultClient) KV(ctx context.Context, suffix string, out interface{}) (*VaultSecret, error) { | ||
296 | s, k, err := c.read(ctx, "kv/data", suffix) | ||
297 | if err != nil { | ||
298 | return nil, err | ||
299 | } | ||
300 | |||
301 | if err = mapstructure.Decode(s.Data["data"], out); err != nil { | ||
302 | return nil, err | ||
303 | } | ||
304 | |||
305 | return &VaultSecret{s: s, n: k}, nil | ||
306 | } | ||
307 | |||
308 | func (c *vaultClient) KVApiKey(ctx context.Context, suffix string) (*VaultApiKey, error) { | ||
309 | var ak VaultApiKey | ||
310 | s, err := c.KV(ctx, suffix, &ak) | ||
311 | if err != nil { | ||
312 | return nil, err | ||
313 | } | ||
314 | |||
315 | ak.s = s | ||
316 | |||
317 | return &ak, nil | ||
318 | } | ||
319 | |||
320 | func (c *vaultClient) KVCredential(ctx context.Context, suffix string) (*VaultUsernamePassword, error) { | ||
321 | var ak VaultUsernamePassword | ||
322 | s, err := c.KV(ctx, suffix, &ak) | ||
323 | if err != nil { | ||
324 | return nil, err | ||
325 | } | ||
326 | |||
327 | ak.s = s | ||
328 | |||
329 | return &ak, nil | ||
330 | } | ||