package cli

import (
	"fmt"
	"log"
	"reflect"
	"strings"
	"time"

	"code.crute.us/mcrute/golib/vault"
	"github.com/spf13/cobra"
)

// MustGetConfig populates out from the flags of cmd exactly like GetConfig,
// but terminates the process through log.Fatal when GetConfig returns an
// error.
func MustGetConfig(cmd *cobra.Command, out interface{}) {
	if err := GetConfig(cmd, out); err != nil {
		log.Fatal(err)
	}
}

// derefStructPtr checks that v is a non-nil pointer to a struct and returns
// the struct value it points to. role names the argument in the error text.
func derefStructPtr(role string, v interface{}) (reflect.Value, error) {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr || rv.IsNil() || rv.Elem().Kind() != reflect.Struct {
		return reflect.Value{}, fmt.Errorf("%s must be a non-nil pointer to a struct, got %T", role, v)
	}
	return rv.Elem(), nil
}

// GetConfig copies flag values from cmd into the struct that out points to.
//
// Only fields carrying a non-empty `flag:"name"` tag are considered; the tag
// value is the flag name looked up on cmd.Flags(). A field tagged
// `flag-required:"true"` must have a non-empty value, but this is enforced
// only for string and VaultCredential fields. For a VaultCredential field the
// flag's string value is passed to vault.GetVaultKey and the field is set
// from the flag value plus the returned Username and Password.
//
// Supported field types are bool, string, the int/uint/float types handled
// below, time.Duration, slices of string/int/int32/int64/uint/float32/float64
// and VaultCredential. Any other kind yields an error. Fields must use these
// exact types: a user-defined type whose underlying type is supported is not
// assignable from the flag value and makes reflect panic.
//
// An error is returned when out is not a non-nil pointer to a struct, when a
// tagged field cannot be set, when a required flag is empty, when the vault
// lookup fails, or when a field type is unsupported.
func GetConfig(cmd *cobra.Command, out interface{}) error {
	o, err := derefStructPtr("out", out)
	if err != nil {
		return err
	}

	t := o.Type()
	flags := cmd.Flags()
	durationType := reflect.TypeOf(time.Duration(0))
	credentialType := reflect.TypeOf(VaultCredential{})

	for i := 0; i < t.NumField(); i++ {
		tf := t.Field(i)

		// Fields with no name are not considered flags
		name := tf.Tag.Get("flag")
		if name == "" {
			continue
		}

		f := o.Field(i)
		if !f.CanSet() {
			return fmt.Errorf("field %s is tagged as flag %s but cannot be set", tf.Name, name)
		}
		set := func(v interface{}) { f.Set(reflect.ValueOf(v)) }

		// Pretty much only string and struct can be tested
		req := tf.Tag.Get("flag-required") == "true"

		// The error from every flag getter below is intentionally discarded:
		// when the lookup fails the field is simply set to the zero value the
		// getter returned.
		// NOTE(review): presumably this tolerates flags that AddFlags bound
		// under a different scope and that are therefore absent from cmd;
		// confirm before turning these into hard errors.
		switch f.Kind() {
		case reflect.Bool:
			v, _ := flags.GetBool(name)
			set(v)
		case reflect.String:
			v, _ := flags.GetString(name)
			if req && v == "" {
				return fmt.Errorf("Flag %s is required but not provided", name)
			}
			set(v)
		case reflect.Int:
			v, _ := flags.GetInt(name)
			set(v)
		case reflect.Int32:
			v, _ := flags.GetInt32(name)
			set(v)
		case reflect.Int64:
			// time.Duration has kind Int64 but is stored as a duration flag.
			if tf.Type.AssignableTo(durationType) {
				v, _ := flags.GetDuration(name)
				set(v)
			} else {
				v, _ := flags.GetInt64(name)
				set(v)
			}
		case reflect.Uint:
			v, _ := flags.GetUint(name)
			set(v)
		case reflect.Uint32:
			v, _ := flags.GetUint32(name)
			set(v)
		case reflect.Uint64:
			v, _ := flags.GetUint64(name)
			set(v)
		case reflect.Float32:
			v, _ := flags.GetFloat32(name)
			set(v)
		case reflect.Float64:
			v, _ := flags.GetFloat64(name)
			set(v)
		case reflect.Slice:
			switch tf.Type.Elem().Kind() {
			case reflect.String:
				v, _ := flags.GetStringSlice(name)
				set(v)
			case reflect.Int:
				v, _ := flags.GetIntSlice(name)
				set(v)
			case reflect.Int32:
				v, _ := flags.GetInt32Slice(name)
				set(v)
			case reflect.Int64:
				v, _ := flags.GetInt64Slice(name)
				set(v)
			case reflect.Uint:
				v, _ := flags.GetUintSlice(name)
				set(v)
			case reflect.Float32:
				v, _ := flags.GetFloat32Slice(name)
				set(v)
			case reflect.Float64:
				v, _ := flags.GetFloat64Slice(name)
				set(v)
			default:
				return fmt.Errorf("type []%s is not supported for field %s", tf.Type.Elem(), tf.Name)
			}
		case reflect.Struct:
			// cli.VaultCredential is the only struct type that can be bound.
			if !tf.Type.AssignableTo(credentialType) {
				return fmt.Errorf("type %s is not supported for field %s", tf.Type, tf.Name)
			}
			v, _ := flags.GetString(name)
			if req && v == "" {
				return fmt.Errorf("Flag %s is required but not provided", name)
			}
			vk, err := vault.GetVaultKey(v)
			if err != nil {
				return fmt.Errorf("Error getting %s from vault: %w", name, err)
			}
			set(VaultCredential{v, vk.Username, vk.Password})
		default:
			return fmt.Errorf("type %s is not supported for field %s", tf.Type, tf.Name)
		}
	}

	return nil
}

// inScope reports whether desired equals one of the comma-separated entries
// of allowed, comparing after trimming surrounding whitespace from each
// entry. An empty allowed string splits into the single entry "", so it
// matches the root scope "".
func inScope(desired, allowed string) bool {
	for _, entry := range strings.Split(allowed, ",") {
		if strings.TrimSpace(entry) == desired {
			return true
		}
	}
	return false
}

// AddFlags registers a persistent flag on cmd for every field of the struct
// type that cfg points to which has a non-empty `flag:"name"` tag and whose
// `flag-scope` tag (a comma-separated list; the root scope is "") contains
// scope.
//
// Only the type of cfg is inspected, so cfg may be a nil pointer of the
// struct type. def must be a non-nil pointer to a struct; for each registered
// flag the default value is read from the field of def with the same name,
// and the help text comes from the `flag-help` tag. A VaultCredential field
// is registered as a string flag whose default is the credential's Path.
//
// The set of supported field types matches GetConfig. Fields must use those
// exact types: a user-defined type whose underlying type is supported fails
// the type assertion on its default value and panics.
//
// An error is returned when cfg or def has the wrong shape, when def has no
// readable field of the matching type, or when a field type is unsupported.
func AddFlags(cmd *cobra.Command, cfg interface{}, def interface{}, scope string) error {
	ct := reflect.TypeOf(cfg)
	if ct == nil || ct.Kind() != reflect.Ptr || ct.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("cfg must be a pointer to a struct, got %T", cfg)
	}
	t := ct.Elem()

	d, err := derefStructPtr("def", def)
	if err != nil {
		return err
	}

	pf := cmd.PersistentFlags()
	durationType := reflect.TypeOf(time.Duration(0))
	credentialType := reflect.TypeOf(VaultCredential{})

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)

		// Fields with no name are not considered flags
		name := f.Tag.Get("flag")
		if name == "" {
			continue
		}

		// Non-matching scopes should not bind here (note root is "")
		// Scopes can be a comma separated list
		if !inScope(scope, f.Tag.Get("flag-scope")) {
			continue
		}

		// The default must exist, be exported, and have the field's type so
		// that the type assertions below describe the same type as the switch.
		dv := d.FieldByName(f.Name)
		if !dv.IsValid() || !dv.CanInterface() || dv.Type() != f.Type {
			return fmt.Errorf("def has no usable default of type %s for field %s", f.Type, f.Name)
		}
		defV := dv.Interface()
		help := f.Tag.Get("flag-help")

		switch f.Type.Kind() {
		case reflect.Bool:
			pf.Bool(name, defV.(bool), help)
		case reflect.String:
			pf.String(name, defV.(string), help)
		case reflect.Int:
			pf.Int(name, defV.(int), help)
		case reflect.Int32:
			pf.Int32(name, defV.(int32), help)
		case reflect.Int64:
			// time.Duration has kind Int64 but gets a duration flag.
			if f.Type.AssignableTo(durationType) {
				pf.Duration(name, defV.(time.Duration), help)
			} else {
				pf.Int64(name, defV.(int64), help)
			}
		case reflect.Uint:
			pf.Uint(name, defV.(uint), help)
		case reflect.Uint32:
			pf.Uint32(name, defV.(uint32), help)
		case reflect.Uint64:
			pf.Uint64(name, defV.(uint64), help)
		case reflect.Float32:
			pf.Float32(name, defV.(float32), help)
		case reflect.Float64:
			pf.Float64(name, defV.(float64), help)
		case reflect.Slice:
			switch f.Type.Elem().Kind() {
			case reflect.String:
				pf.StringSlice(name, defV.([]string), help)
			case reflect.Int:
				pf.IntSlice(name, defV.([]int), help)
			case reflect.Int32:
				pf.Int32Slice(name, defV.([]int32), help)
			case reflect.Int64:
				pf.Int64Slice(name, defV.([]int64), help)
			case reflect.Uint:
				pf.UintSlice(name, defV.([]uint), help)
			case reflect.Float32:
				pf.Float32Slice(name, defV.([]float32), help)
			case reflect.Float64:
				pf.Float64Slice(name, defV.([]float64), help)
			default:
				return fmt.Errorf("type []%s is not supported for field %s", f.Type.Elem(), f.Name)
			}
		case reflect.Struct:
			// cli.VaultCredential is the only struct type that can be bound.
			if !f.Type.AssignableTo(credentialType) {
				return fmt.Errorf("type %s is not supported for field %s", f.Type, f.Name)
			}
			pf.String(name, defV.(VaultCredential).Path, help)
		default:
			return fmt.Errorf("type %s is not supported for field %s", f.Type, f.Name)
		}
	}

	return nil
}