diff options
Diffstat (limited to 'netbox/django-driver.py')
-rw-r--r-- | netbox/django-driver.py | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/netbox/django-driver.py b/netbox/django-driver.py new file mode 100644 index 0000000..65a9136 --- /dev/null +++ b/netbox/django-driver.py | |||
@@ -0,0 +1,69 @@ | |||
1 | import threading | ||
2 | from datetime import datetime, timedelta | ||
3 | |||
4 | from django.core.exceptions import ImproperlyConfigured | ||
5 | from django.contrib.vault_client import SimpleVaultClient, Credential | ||
6 | from django.db.backends.postgresql.base import DatabaseWrapper as OrigWrapper | ||
7 | |||
8 | |||
9 | def _is_affirmative(value): | ||
10 | value = "" if not value else value | ||
11 | return value.lower() in ["yes", "true", "on", "1"] | ||
12 | |||
13 | |||
14 | def _must_get(store, key): | ||
15 | value = store.get(key) | ||
16 | |||
17 | if not value: | ||
18 | raise ImproperlyConfigured( | ||
19 | f"Database parameter {key} is required but not set.") | ||
20 | |||
21 | return value | ||
22 | |||
23 | |||
24 | class DatabaseWrapper(OrigWrapper): | ||
25 | |||
26 | def __init__(self, *args, **kwargs): | ||
27 | super().__init__(*args, **kwargs) | ||
28 | self._vault_cache_lock = threading.Lock() | ||
29 | self._vault_cred_cache = Credential.empty() | ||
30 | |||
31 | def close_if_unusable_or_obsolete(self): | ||
32 | super().close_if_unusable_or_obsolete() | ||
33 | |||
34 | if self.connection is None: | ||
35 | return | ||
36 | |||
37 | with self._vault_cache_lock: | ||
38 | if not self._vault_cred_cache.is_valid: | ||
39 | self.close() | ||
40 | |||
41 | # All of this is done under lock | ||
42 | def _get_vault_cred(self): | ||
43 | print("Getting credentials from vault") | ||
44 | params = self.settings_dict | ||
45 | |||
46 | verify = not _is_affirmative(params.get("VAULT_SKIP_VERIFY")) | ||
47 | url = _must_get(params, "VAULT_ADDR") | ||
48 | token = params.get("VAULT_TOKEN") | ||
49 | db_role_name = _must_get(params, "VAULT_DB_ROLE_NAME") | ||
50 | role_id = _must_get(params, "VAULT_ROLE_ID") | ||
51 | role_secret = _must_get(params, "VAULT_SECRET_ID") | ||
52 | |||
53 | client = SimpleVaultClient(url, role_id, role_secret, verify) | ||
54 | |||
55 | self._vault_cred_cache = client.get_db_credential(db_role_name) | ||
56 | |||
57 | def get_connection_params(self): | ||
58 | conn_params = super().get_connection_params() | ||
59 | |||
60 | # Do the fetch under lock to prevent multiple threads from piling onto | ||
61 | # the vault server | ||
62 | with self._vault_cache_lock: | ||
63 | if not self._vault_cred_cache.is_valid: | ||
64 | self._get_vault_cred() | ||
65 | |||
66 | conn_params["user"] = self._vault_cred_cache.username | ||
67 | conn_params["password"] = self._vault_cred_cache.password | ||
68 | |||
69 | return conn_params | ||