package mongodb

import (
	"context"
	"fmt"
	"net/url"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"

	"code.crute.us/mcrute/golib/vault"
)

// AnyInTopLevelArray builds the filter {k: {"$all": [v]}} so callers do not
// have to spell out the nested bson expression themselves.
func AnyInTopLevelArray(k string, v interface{}) bson.M {
	match := bson.M{"$all": bson.A{v}}
	return bson.M{k: match}
}

// Mongo pairs a driver client with the database handle that Connect selects
// from the connection URI. All helper methods operate on that database.
type Mongo struct {
	client *mongo.Client
	db     *mongo.Database
}

// Connect fetches the credential for materialSet from vc, installs it as the
// userinfo portion of uri, validates the resulting connection string, and
// opens a driver client with it. The returned Mongo is bound to the database
// named in the connection string.
//
// Any error from the credential lookup, URL parsing, connection-string
// validation, or the driver is returned unchanged.
func Connect(ctx context.Context, uri, materialSet string, vc vault.VaultClient) (*Mongo, error) {
	// The credential lookup runs first, before the URI is even parsed.
	cred, err := vc.DbCredential(ctx, materialSet)
	if err != nil {
		return nil, err
	}

	parsed, err := url.Parse(uri)
	if err != nil {
		return nil, err
	}
	parsed.User = url.UserPassword(cred.Username, cred.Password)
	dsn := parsed.String()

	// Parsed separately from the driver options so the database name can be
	// read back out of the URI.
	cs, err := connstring.ParseAndValidate(dsn)
	if err != nil {
		return nil, err
	}

	client, err := mongo.Connect(ctx, options.Client().ApplyURI(dsn))
	if err != nil {
		return nil, err
	}

	return &Mongo{
		client: client,
		db:     client.Database(cs.Database),
	}, nil
}

// Collection returns the handle for the collection called name in the
// connected database.
func (m *Mongo) Collection(name string) *mongo.Collection {
	return m.db.Collection(name)
}

// FindAllByFilter runs a Find with filter against collection cn and decodes
// every result into out using the cursor's All method.
func (m *Mongo) FindAllByFilter(ctx context.Context, cn string, filter interface{}, out interface{}) error {
	cursor, err := m.Collection(cn).Find(ctx, filter)
	if err != nil {
		return err
	}
	return cursor.All(ctx, out)
}

// FindAll decodes every document of collection cn into out by calling
// FindAllByFilter with an empty filter.
func (m *Mongo) FindAll(ctx context.Context, cn string, out interface{}) error {
	return m.FindAllByFilter(ctx, cn, bson.D{}, out)
}

// FindOneByFilter runs a FindOne with filter against collection cn and
// decodes the result into out. The driver's decode error is returned as is.
func (m *Mongo) FindOneByFilter(ctx context.Context, cn string, filter interface{}, out interface{}) error {
	return m.Collection(cn).FindOne(ctx, filter).Decode(out)
}

// FindOneById decodes into out the document of collection cn whose "_id"
// equals the string id.
func (m *Mongo) FindOneById(ctx context.Context, cn string, id string, out interface{}) error {
	return m.FindOneByFilter(ctx, cn, bson.M{"_id": id}, out)
}

// InsertOne inserts in into collection cn. The insert result is discarded;
// only the error is reported.
func (m *Mongo) InsertOne(ctx context.Context, cn string, in interface{}) error {
	_, err := m.Collection(cn).InsertOne(ctx, in)
	return err
}

// ReplaceOneByFilter issues a ReplaceOne on collection cn for filter with the
// replacement in, with the Upsert option set to true. The replace result is
// discarded; only the error is reported.
func (m *Mongo) ReplaceOneByFilter(ctx context.Context, cn string, filter interface{}, in interface{}) error {
	upsert := true
	_, err := m.Collection(cn).ReplaceOne(ctx, filter, in, &options.ReplaceOptions{Upsert: &upsert})
	return err
}

// ReplaceOneById calls ReplaceOneByFilter with a filter matching the document
// whose "_id" equals the string id.
func (m *Mongo) ReplaceOneById(ctx context.Context, cn string, id string, in interface{}) error {
	return m.ReplaceOneByFilter(ctx, cn, bson.M{"_id": id}, in)
}

// DeleteOneByFilter issues a DeleteOne on collection cn for filter. It
// returns an error when the driver fails or when the reported deletion count
// is anything other than exactly one.
func (m *Mongo) DeleteOneByFilter(ctx context.Context, cn string, filter interface{}) error {
	res, err := m.Collection(cn).DeleteOne(ctx, filter)
	if err != nil {
		return err
	}
	if n := res.DeletedCount; n != 1 {
		return fmt.Errorf("Invalid deletion record count %d not 1", n)
	}
	return nil
}

// DeleteOneById calls DeleteOneByFilter with a filter matching the document
// whose "_id" equals the string id.
func (m *Mongo) DeleteOneById(ctx context.Context, cn string, id string) error {
	return m.DeleteOneByFilter(ctx, cn, bson.M{"_id": id})
}