aboutsummaryrefslogtreecommitdiff
path: root/netbox/django-driver.py
blob: 65a91363b14d4e8adb6ce7524f89f2b3192354b5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import logging
import threading
from datetime import datetime, timedelta

from django.core.exceptions import ImproperlyConfigured
from django.contrib.vault_client import SimpleVaultClient, Credential
from django.db.backends.postgresql.base import DatabaseWrapper as OrigWrapper


def _is_affirmative(value):
    value = "" if not value else value
    return value.lower() in ["yes", "true", "on", "1"]


def _must_get(store, key):
    value = store.get(key)

    if not value:
        raise ImproperlyConfigured(
            f"Database parameter {key} is required but not set.")

    return value


class DatabaseWrapper(OrigWrapper):
    """PostgreSQL wrapper whose user name and password come from Vault.

    ``get_connection_params`` overrides the "user" and "password" entries
    with the values of a credential obtained through
    ``SimpleVaultClient.get_db_credential``.  The credential is cached on
    the wrapper and fetched again once its ``is_valid`` attribute is false.

    Entries read from ``settings_dict``:

    * ``VAULT_ADDR``, ``VAULT_DB_ROLE_NAME``, ``VAULT_ROLE_ID`` and
      ``VAULT_SECRET_ID`` -- required; ImproperlyConfigured is raised when
      one of them is missing or empty.
    * ``VAULT_SKIP_VERIFY`` -- optional; an affirmative value makes the
      ``verify`` argument passed to the client False.

    NOTE(review): the lock and the cache are attributes of the wrapper
    instance.  Django usually creates one wrapper per thread, so they may
    not be shared between threads -- confirm whether that is intended.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base wrapper, then the lock and credential cache."""
        super().__init__(*args, **kwargs)
        self._vault_cache_lock = threading.Lock()
        # Presumably a placeholder whose is_valid is false, so the first
        # connection attempt triggers a fetch -- verify against Credential.
        self._vault_cred_cache = Credential.empty()

    def close_if_unusable_or_obsolete(self):
        """Also close the connection when the cached credential is invalid.

        Runs the base implementation first.  If a connection is still open
        afterwards and the cached credential's ``is_valid`` is false, the
        connection is closed.
        """
        super().close_if_unusable_or_obsolete()

        if self.connection is None:
            return

        with self._vault_cache_lock:
            if not self._vault_cred_cache.is_valid:
                self.close()

    # All of this is done under lock
    def _get_vault_cred(self):
        """Fetch a database credential from Vault and store it in the cache.

        The caller must hold ``_vault_cache_lock``; this method does not
        acquire it.

        Raises:
            ImproperlyConfigured: when a required VAULT_* setting is missing
                or empty.
        """
        # Logged rather than printed so stdout stays clean for callers
        # that write their own output there.
        logging.getLogger(__name__).info("Getting credentials from vault")
        params = self.settings_dict

        verify = not _is_affirmative(params.get("VAULT_SKIP_VERIFY"))
        url = _must_get(params, "VAULT_ADDR")
        db_role_name = _must_get(params, "VAULT_DB_ROLE_NAME")
        role_id = _must_get(params, "VAULT_ROLE_ID")
        role_secret = _must_get(params, "VAULT_SECRET_ID")

        client = SimpleVaultClient(url, role_id, role_secret, verify)

        self._vault_cred_cache = client.get_db_credential(db_role_name)

    def get_connection_params(self):
        """Return the base connection parameters with Vault credentials.

        The "user" and "password" entries are replaced by the cached
        credential's ``username`` and ``password``; the credential is
        fetched first when the cached one is not valid.
        """
        conn_params = super().get_connection_params()

        # Do the fetch under lock to prevent multiple threads from piling onto
        # the vault server
        with self._vault_cache_lock:
            if not self._vault_cred_cache.is_valid:
                self._get_vault_cred()

            conn_params["user"] = self._vault_cred_cache.username
            conn_params["password"] = self._vault_cred_cache.password

        return conn_params