diff --git a/cmd/config/config.go b/cmd/config/config.go index 858b7b04..9ed15df3 100644 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "errors" "fmt" "os" @@ -54,9 +55,44 @@ func run() func(*cobra.Command, []string) error { fmt.Fprint(cmd.OutOrStdout(), string(configJSON)+"\n") case viper.GetBool(SetFlag): - fmt.Fprintln(cmd.OutOrStdout(), "called --set flag") - case viper.GetBool(UnsetFlag): - fmt.Fprintln(cmd.OutOrStdout(), "called --unset flag") + // flag needs two arguments: a key and value + if len(args)%2 != 0 { + return errors.New("flag needs an argument: --set") + } + + rawConfig, v, err := getRawConfig() + if err != nil { + return err + } + + // add arg pairs to config + // where each argument is --set arg1 val1 --set arg2 val2 + for i, a := range args { + if i%2 == 0 { + rawConfig[a] = struct{}{} + } else { + rawConfig[args[i-1]] = a + } + } + + setKeyFn := func(key string, value interface{}, v *viper.Viper) { + v.Set(key, value) + } + + return writeConfig(config.NewConfig(rawConfig), v, setKeyFn) + case viper.IsSet(UnsetFlag): + config, v, err := getConfig() + if err != nil { + return err + } + + unsetKeyFn := func(key string, value interface{}, v *viper.Viper) { + if key != viper.GetString("unset") { + v.Set(key, value) + } + } + + return writeConfig(config, v, unsetKeyFn) } return nil @@ -84,6 +120,27 @@ func getConfig() (config.ConfigFile, *viper.Viper, error) { return c, v, nil } +// getRawConfig builds a map of the values in the config file. +func getRawConfig() (map[string]interface{}, *viper.Viper, error) { + v, err := getViperWithConfigFile() + if err != nil { + return nil, nil, err + } + + data, err := os.ReadFile(v.ConfigFileUsed()) + if err != nil { + return nil, nil, err + } + + config := make(map[string]interface{}, 0) + err = yaml.Unmarshal([]byte(data), &config) + if err != nil { + return nil, nil, err + } + + return config, v, nil +} + // getViperWithConfigFile ensures the viper instance has a config file written to the filesystem. // We want to write the file when someone runs a config command. func getViperWithConfigFile() (*viper.Viper, error) { @@ -105,3 +162,36 @@ func getViperWithConfigFile() (*viper.Viper, error) { return v, nil } + +// writeConfig writes the values in config to the config file based on the filter function. +func writeConfig( + conf config.ConfigFile, + v *viper.Viper, + filterFn func(key string, value interface{}, v *viper.Viper), +) error { + // create a new viper instance since the existing one has the command name and flags already set, + // and these will get written to the config file. + newViper := viper.New() + newViper.SetConfigFile(v.ConfigFileUsed()) + + configYAML, err := yaml.Marshal(conf) + if err != nil { + return err + } + rawConfig := make(map[string]interface{}) + err = yaml.Unmarshal(configYAML, &rawConfig) + if err != nil { + return err + } + + for key, value := range rawConfig { + filterFn(key, value, newViper) + } + + err = newViper.WriteConfig() + if err != nil { + return err + } + + return nil +} diff --git a/internal/config/config.go b/internal/config/config.go index f70286b2..13180881 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -43,6 +43,7 @@ func GetConfigPath() string { } configPath = filepath.Join(home, ".config") } + return filepath.Join(configPath, "ldcli") }