From 763b810ca1a9d755205e49b1246025b83abb5132 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Wed, 27 Jan 2021 04:41:00 +0000 Subject: Add netbox --- netbox/django-driver.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 netbox/django-driver.py (limited to 'netbox/django-driver.py') diff --git a/netbox/django-driver.py b/netbox/django-driver.py new file mode 100644 index 0000000..65a9136 --- /dev/null +++ b/netbox/django-driver.py @@ -0,0 +1,69 @@ +import threading +from datetime import datetime, timedelta + +from django.core.exceptions import ImproperlyConfigured +from django.contrib.vault_client import SimpleVaultClient, Credential +from django.db.backends.postgresql.base import DatabaseWrapper as OrigWrapper + + +def _is_affirmative(value): + value = "" if not value else value + return value.lower() in ["yes", "true", "on", "1"] + + +def _must_get(store, key): + value = store.get(key) + + if not value: + raise ImproperlyConfigured( + f"Database parameter {key} is required but not set.") + + return value + + +class DatabaseWrapper(OrigWrapper): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._vault_cache_lock = threading.Lock() + self._vault_cred_cache = Credential.empty() + + def close_if_unusable_or_obsolete(self): + super().close_if_unusable_or_obsolete() + + if self.connection is None: + return + + with self._vault_cache_lock: + if not self._vault_cred_cache.is_valid: + self.close() + + # All of this is done under lock + def _get_vault_cred(self): + print("Getting credentials from vault") + params = self.settings_dict + + verify = not _is_affirmative(params.get("VAULT_SKIP_VERIFY")) + url = _must_get(params, "VAULT_ADDR") + token = params.get("VAULT_TOKEN") + db_role_name = _must_get(params, "VAULT_DB_ROLE_NAME") + role_id = _must_get(params, "VAULT_ROLE_ID") + role_secret = _must_get(params, "VAULT_SECRET_ID") + + client = SimpleVaultClient(url, role_id, role_secret, verify) + + self._vault_cred_cache = client.get_db_credential(db_role_name) + + def get_connection_params(self): + conn_params = super().get_connection_params() + + # Do the fetch under lock to prevent multiple threads from piling onto + # the vault server + with self._vault_cache_lock: + if not self._vault_cred_cache.is_valid: + self._get_vault_cred() + + conn_params["user"] = self._vault_cred_cache.username + conn_params["password"] = self._vault_cred_cache.password + + return conn_params -- cgit v1.2.3