From d25729cef991e6136eede4931e3d46a76d473391 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Sun, 22 May 2022 00:59:40 -0700 Subject: db/mongodb: allow implying many params --- db/mongodb/client.go | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/db/mongodb/client.go b/db/mongodb/client.go index c423d99..1f3815d 100644 --- a/db/mongodb/client.go +++ b/db/mongodb/client.go @@ -3,7 +3,9 @@ package mongodb import ( "context" "fmt" + "net" "net/url" + "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -24,20 +26,46 @@ type Mongo struct { db *mongo.Database } -func Connect(ctx context.Context, uri, materialSet string, vc vault.VaultClient) (*Mongo, error) { +func Connect(ctx context.Context, uri string, vc vault.VaultClient) (*Mongo, error) { db := &Mongo{} - cred, err := vc.DbCredential(ctx, materialSet) + // Prefix uri with mongodb:// unless it already includes one of the + // standard prefixes (only these two are valid). Otherwise if scheme is + // omitted then url parsing will fail to capture the username for Vault + // lookup. + if !strings.HasPrefix(uri, "mongodb://") && !strings.HasPrefix(uri, "mongodb+srv://") { + uri = "mongodb://" + uri + } + + u, err := url.Parse(uri) if err != nil { return nil, err } - u, err := url.Parse(uri) + // The username provided by the user (there should be no + // password) will be a reference to a vault material with the + // prefix database/creds/. This needs to be replaced with the real + // username/password pair fetched from Vault before attempting to + // connect. + cred, err := vc.DbCredential(ctx, u.User.Username()) if err != nil { return nil, err } u.User = url.UserPassword(cred.Username, cred.Password) + // User may imply the default port + if u.Port() == "" { + u.Host = net.JoinHostPort(u.Host, "27017") + } + + // Users should generally authenticate against the admin collection so + // they should only specify it if they need to override that. + if u.Query().Get("authSource") == "" { + pq := u.Query() + pq.Add("authSource", "admin") + u.RawQuery = pq.Encode() + } + cs, err := connstring.ParseAndValidate(u.String()) if err != nil { return nil, err -- cgit v1.2.3