Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix shorthand combination edge case #2189

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 59 additions & 17 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -643,13 +643,35 @@ func shortHasNoOptDefVal(name string, fs *flag.FlagSet) bool {
return flag.NoOptDefVal != ""
}

func stripFlags(args []string, c *Command) []string {
func shorthandCombinationNeedsNextArg(combination string, flags *flag.FlagSet) bool {
lastPos := len(combination) - 1
for i, shorthand := range combination {
if !shortHasNoOptDefVal(string(shorthand), flags) {
// This shorthand needs a value.
//
// If we're at the end of the shorthand combination, this means that the
// value for the shorthand is given in the next argument. (e.g. '-xyzf arg',
// where -x, -y, -z are boolean flags, and -f is a flag that needs a value).
//
// Otherwise, if the shorthand combination doesn't end here, this means that the value
// for the shorthand is given in the same argument, meaning we don't have to consume the
// next one. (e.g. '-xyzfarg', where -x, -y, -z are boolean flags, and -f is a flag that
// needs a value).
return i == lastPos
}
}

return false
}

func stripFlags(args []string, c *Command) ([]string, []string) {
if len(args) == 0 {
return args
return args, nil
}
c.mergePersistentFlags()

commands := []string{}
flagsThatConsumeNextArg := []string{} // We use this to avoid repeating the same lengthy logic for parsing shorthand combinations in argsMinusFirstX
flags := c.Flags()

Loop:
Expand All @@ -665,31 +687,48 @@ Loop:
// delete arg from args.
fallthrough // (do the same as below)
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
// If '-f arg' then
// delete 'arg' from args or break the loop if len(args) <= 1.
if len(args) <= 1 {
break Loop
} else {
args = args[1:]
continue
}
case strings.HasPrefix(s, "-") && !strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && len(s) > 2:
shorthandCombination := s[1:] // Skip leading "-"
if shorthandCombinationNeedsNextArg(shorthandCombination, flags) {
flagsThatConsumeNextArg = append(flagsThatConsumeNextArg, s)
if len(args) <= 1 {
break Loop
} else {
args = args[1:]
}
}
case s != "" && !strings.HasPrefix(s, "-"):
commands = append(commands, s)
}
}

return commands
return commands, flagsThatConsumeNextArg
}

// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
// Special care needs to be taken not to remove a flag value.
func (c *Command) argsMinusFirstX(args []string, x string) []string {
func (c *Command) argsMinusFirstX(args, flagsThatConsumeNextArg []string, x string) []string {
if len(args) == 0 {
return args
}
c.mergePersistentFlags()
flags := c.Flags()

consumesNextArg := func(flag string) bool {
for _, f := range flagsThatConsumeNextArg {
if flag == f {
return true
}
}
return false
}

Loop:
for pos := 0; pos < len(args); pos++ {
Expand All @@ -698,13 +737,8 @@ Loop:
case s == "--":
// -- means we have reached the end of the parseable args. Break out of the loop now.
break Loop
case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags):
fallthrough
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
// over the next arg, because that is the value of this flag.
case consumesNextArg(s):
pos++
continue
case !strings.HasPrefix(s, "-"):
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
// return the args, excluding the one at this position.
Expand All @@ -730,22 +764,23 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
var innerfind func(*Command, []string) (*Command, []string)

innerfind = func(c *Command, innerArgs []string) (*Command, []string) {
argsWOflags := stripFlags(innerArgs, c)
argsWOflags, flagsThatConsumeNextArg := stripFlags(innerArgs, c)
if len(argsWOflags) == 0 {
return c, innerArgs
}
nextSubCmd := argsWOflags[0]

cmd := c.findNext(nextSubCmd)
if cmd != nil {
return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
return innerfind(cmd, c.argsMinusFirstX(innerArgs, flagsThatConsumeNextArg, nextSubCmd))
}
return c, innerArgs
}

commandFound, a := innerfind(c, args)
if commandFound.Args == nil {
return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound))
argsWOflags, _ := stripFlags(a, commandFound)
return commandFound, a, legacyArgs(commandFound, argsWOflags)
}
return commandFound, a, nil
}
Expand Down Expand Up @@ -812,9 +847,16 @@ func (c *Command) Traverse(args []string) (*Command, []string, error) {
inFlag = false
flags = append(flags, arg)
continue
// A flag without a value, or with an `=` separated value
// A flag with an `=` separated value, or a shorthand combination, possibly with a value
case isFlagArg(arg):
flags = append(flags, arg)

if strings.HasPrefix(arg, "--") || strings.Contains(arg, "=") || len(arg) <= 2 {
continue // Not a shorthand combination, so nothing more to do.
}

shorthandCombination := arg[1:] // Skip leading "-"
inFlag = shorthandCombinationNeedsNextArg(shorthandCombination, c.Flags())
continue
}

Expand Down
120 changes: 117 additions & 3 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,30 @@ func TestStripFlags(t *testing.T) {
[]string{"-p", "bar"},
[]string{"bar"},
},
{
[]string{"-s", "value", "bar"},
[]string{"bar"},
},
{
[]string{"-s=value", "bar"},
[]string{"bar"},
},
{
[]string{"-svalue", "bar"},
[]string{"bar"},
},
{
[]string{"-ps", "value", "bar"},
[]string{"bar"},
},
{
[]string{"-ps=value", "bar"},
[]string{"bar"},
},
{
[]string{"-psvalue", "bar"},
[]string{"bar"},
},
}

c := &Command{Use: "c", Run: emptyRun}
Expand All @@ -702,7 +726,7 @@ func TestStripFlags(t *testing.T) {
c.Flags().BoolP("bool", "b", false, "")

for i, test := range tests {
got := stripFlags(test.input, c)
got, _ := stripFlags(test.input, c)
if !reflect.DeepEqual(test.output, got) {
t.Errorf("(%v) Expected: %v, got: %v", i, test.output, got)
}
Expand Down Expand Up @@ -2229,12 +2253,68 @@ func TestTraverseWithParentFlags(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--add" {
if len(args) != 1 || args[0] != "--int" {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original condition seems wrong. --add wasn't mentioned anywhere else in the repo and from how I understand the test, || makes sense here, not &&.

t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
}

func TestTraverseWithShorthandCombinationInParentFlags(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")

childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")

rootCmd.AddCommand(childCmd)

c, args, err := rootCmd.Traverse([]string{"-bs", "ok", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "ok" {
t.Errorf("Expected -s to be set to: %s, got: %s", "ok", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}

func TestTraverseWithArgumentIdenticalToCommandName(t *testing.T) {
rootCmd := &Command{Use: "root", TraverseChildren: true}
stringVal := rootCmd.Flags().StringP("str", "s", "", "")
boolVal := rootCmd.Flags().BoolP("bool", "b", false, "")

childCmd := &Command{Use: "child"}
childCmd.Flags().Int("int", -1, "")

rootCmd.AddCommand(childCmd)

c, args, err := rootCmd.Traverse([]string{"-bs", "child", "child", "--int"})
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 || args[0] != "--int" {
t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
t.Errorf("Expected command: %q, got: %q", childCmd.Name(), c.Name())
}
if *stringVal != "child" {
t.Errorf("Expected -s to be set to: %s, got: %s", "child", *stringVal)
}
if !*boolVal {
t.Errorf("Expected -b to be set")
}
}

func TestTraverseNoParentFlags(t *testing.T) {
Expand Down Expand Up @@ -2288,7 +2368,7 @@ func TestTraverseWithBadChildFlag(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(args) != 1 && args[0] != "--str" {
if len(args) != 1 || args[0] != "--str" {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above.

t.Errorf("Wrong args: %v", args)
}
if c.Name() != childCmd.Name() {
Expand Down Expand Up @@ -2688,11 +2768,13 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {

func TestFind(t *testing.T) {
var foo, bar string
var persist bool
root := &Command{
Use: "root",
}
root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "")
root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "")
root.PersistentFlags().BoolVarP(&persist, "persist", "p", false, "")

child := &Command{
Use: "child",
Expand Down Expand Up @@ -2755,6 +2837,38 @@ func TestFind(t *testing.T) {
[]string{"--foo", "child", "--bar", "something", "child"},
[]string{"--foo", "child", "--bar", "something"},
},
{
[]string{"-f", "value", "child"},
[]string{"-f", "value"},
},
{
[]string{"-f=value", "child"},
[]string{"-f=value"},
},
{
[]string{"-fvalue", "child"},
[]string{"-fvalue"},
},
{
[]string{"-pf", "value", "child"},
[]string{"-pf", "value"},
},
{
[]string{"-pf=value", "child"},
[]string{"-pf=value"},
},
{
[]string{"-pfvalue", "child"},
[]string{"-pfvalue"},
},
{
[]string{"-pf", "child", "child"},
[]string{"-pf", "child"},
},
{
[]string{"-pf", "child", "-pb", "something", "child"},
[]string{"-pf", "child", "-pb", "something"},
},
}

for _, tc := range testCases {
Expand Down