Skip to content

Commit 3c443ef

Browse files
committed
Improve Survey prompt stubber for tests
Both SurveyAsk and SurveyAskOne methods now share the same sets of stubs, making it possible to change which of these methods is used in the implementation without breaking tests. A new method `AskStubber.StubPrompt("<prompt>")` is added as test helper to supersede old Stub and StubOne methods. The new helper matches on prompt messages rather than on field names, enabling tests to be written based on what the user would see rather than coupling to implementation details. The new stubber also allows verifying whether a Select or MultiSelect was rendered with the expected set of options. Furthermore, if a stubbed value is not present among those options, the stubber will panic instead of continuing normally. Stubbed Selects with an int instead of a string target receiver are now transparently handled. The values for Select stubs are always strings in tests, but the stubber will write an int answer if the receiver expects one as a selected index instead of a selected string value. Lastly, this set of changes improves test resiliency since the stubs are now matched based on prompt message (or field name for legacy stubs created with Stub) instead of sequentially, enabling the implementation to reorder the prompts without breaking existing tests.
1 parent 8198cce commit 3c443ef

File tree

2 files changed

+134
-58
lines changed

2 files changed

+134
-58
lines changed

pkg/cmd/pr/shared/survey_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestMetadataSurvey_selectAll(t *testing.T) {
7171
},
7272
{
7373
Name: "milestone",
74-
Value: []string{"(none)"},
74+
Value: "(none)",
7575
},
7676
})
7777

pkg/prompt/stubber.go

Lines changed: 133 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,104 +2,180 @@ package prompt
22

33
import (
44
"fmt"
5-
"reflect"
5+
"strings"
66

77
"github.com/AlecAivazis/survey/v2"
88
"github.com/AlecAivazis/survey/v2/core"
9+
"github.com/cli/cli/v2/pkg/surveyext"
910
)
1011

1112
type AskStubber struct {
12-
Asks [][]*survey.Question
13-
AskOnes []*survey.Prompt
14-
Count int
15-
OneCount int
16-
Stubs [][]*QuestionStub
17-
StubOnes []*PromptStub
13+
stubs []*QuestionStub
1814
}
1915

2016
func InitAskStubber() (*AskStubber, func()) {
2117
origSurveyAsk := SurveyAsk
2218
origSurveyAskOne := SurveyAskOne
2319
as := AskStubber{}
2420

25-
SurveyAskOne = func(p survey.Prompt, response interface{}, opts ...survey.AskOpt) error {
26-
as.AskOnes = append(as.AskOnes, &p)
27-
count := as.OneCount
28-
as.OneCount += 1
29-
if count >= len(as.StubOnes) {
30-
panic(fmt.Sprintf("more asks than stubs. most recent call: %v", p))
31-
}
32-
stubbedPrompt := as.StubOnes[count]
33-
if stubbedPrompt.Default {
34-
// TODO this is failing for basic AskOne invocations with a string result.
35-
defaultValue := reflect.ValueOf(p).Elem().FieldByName("Default")
36-
_ = core.WriteAnswer(response, "", defaultValue)
37-
} else {
38-
_ = core.WriteAnswer(response, "", stubbedPrompt.Value)
21+
answerFromStub := func(p survey.Prompt, fieldName string, response interface{}) error {
22+
var message string
23+
var defaultValue interface{}
24+
var options []string
25+
switch pt := p.(type) {
26+
case *survey.Confirm:
27+
message = pt.Message
28+
defaultValue = pt.Default
29+
case *survey.Input:
30+
message = pt.Message
31+
defaultValue = pt.Default
32+
case *survey.Select:
33+
message = pt.Message
34+
options = pt.Options
35+
case *survey.MultiSelect:
36+
message = pt.Message
37+
options = pt.Options
38+
case *survey.Password:
39+
message = pt.Message
40+
case *surveyext.GhEditor:
41+
message = pt.Message
42+
defaultValue = pt.Default
43+
default:
44+
panic(fmt.Sprintf("prompt type %T is not supported by the stubber", pt))
3945
}
4046

41-
return nil
42-
}
43-
44-
SurveyAsk = func(qs []*survey.Question, response interface{}, opts ...survey.AskOpt) error {
45-
as.Asks = append(as.Asks, qs)
46-
count := as.Count
47-
as.Count += 1
48-
if count >= len(as.Stubs) {
49-
panic(fmt.Sprintf("more asks than stubs. most recent call: %#v", qs))
47+
var stub *QuestionStub
48+
for _, s := range as.stubs {
49+
if !s.matched && (s.message == "" && strings.EqualFold(s.Name, fieldName) || s.message == message) {
50+
stub = s
51+
stub.matched = true
52+
break
53+
}
54+
}
55+
if stub == nil {
56+
panic(fmt.Sprintf("no prompt stub for %q", message))
5057
}
5158

52-
// actually set response
53-
stubbedQuestions := as.Stubs[count]
54-
if len(stubbedQuestions) != len(qs) {
55-
panic(fmt.Sprintf("asked questions: %d; stubbed questions: %d", len(qs), len(stubbedQuestions)))
59+
if len(stub.options) > 0 {
60+
if err := compareOptions(stub.options, options); err != nil {
61+
panic(fmt.Sprintf("options mismatch for %q: %v", message, err))
62+
}
5663
}
57-
for i, sq := range stubbedQuestions {
58-
q := qs[i]
59-
if q.Name != sq.Name {
60-
panic(fmt.Sprintf("stubbed question mismatch: %s != %s", q.Name, sq.Name))
64+
65+
userValue := stub.Value
66+
67+
if stringValue, ok := stub.Value.(string); ok && len(options) > 0 {
68+
foundIndex := -1
69+
for i, o := range options {
70+
if o == stringValue {
71+
foundIndex = i
72+
break
73+
}
6174
}
62-
if sq.Default {
63-
defaultValue := reflect.ValueOf(q.Prompt).Elem().FieldByName("Default")
64-
_ = core.WriteAnswer(response, q.Name, defaultValue)
75+
if foundIndex < 0 {
76+
panic(fmt.Sprintf("answer %q not found in options for %q: %v", stringValue, message, options))
77+
}
78+
userValue = core.OptionAnswer{
79+
Value: stringValue,
80+
Index: foundIndex,
81+
}
82+
}
83+
84+
if stub.Default {
85+
if defaultIndex, ok := defaultValue.(int); ok && len(options) > 0 {
86+
userValue = core.OptionAnswer{
87+
Value: options[defaultIndex],
88+
Index: defaultIndex,
89+
}
90+
} else if defaultValue == nil && len(options) > 0 {
91+
userValue = core.OptionAnswer{
92+
Value: options[0],
93+
Index: 0,
94+
}
6595
} else {
66-
_ = core.WriteAnswer(response, q.Name, sq.Value)
96+
userValue = defaultValue
6797
}
6898
}
6999

100+
if err := core.WriteAnswer(response, fieldName, userValue); err != nil {
101+
return fmt.Errorf("AskStubber failed writing the answer for field %q: %w", fieldName, err)
102+
}
103+
return nil
104+
}
105+
106+
SurveyAskOne = func(p survey.Prompt, response interface{}, opts ...survey.AskOpt) error {
107+
return answerFromStub(p, "", response)
108+
}
109+
110+
SurveyAsk = func(qs []*survey.Question, response interface{}, opts ...survey.AskOpt) error {
111+
for _, q := range qs {
112+
if err := answerFromStub(q.Prompt, q.Name, response); err != nil {
113+
return err
114+
}
115+
}
70116
return nil
71117
}
118+
72119
teardown := func() {
73120
SurveyAsk = origSurveyAsk
74121
SurveyAskOne = origSurveyAskOne
75122
}
76123
return &as, teardown
77124
}
78125

79-
type PromptStub struct {
80-
Value interface{}
81-
Default bool
82-
}
83-
84126
type QuestionStub struct {
85127
Name string
86128
Value interface{}
87129
Default bool
130+
131+
matched bool
132+
message string
133+
options []string
88134
}
89135

90-
func (as *AskStubber) StubOne(value interface{}) {
91-
as.StubOnes = append(as.StubOnes, &PromptStub{
92-
Value: value,
93-
})
136+
// AssertOptions asserts the options presented to the user in Selects and MultiSelects.
137+
func (s *QuestionStub) AssertOptions(opts []string) *QuestionStub {
138+
s.options = opts
139+
return s
140+
}
141+
142+
// AnswerWith defines an answer for the given stub.
143+
func (s *QuestionStub) AnswerWith(v interface{}) *QuestionStub {
144+
s.Value = v
145+
return s
146+
}
147+
148+
// AnswerDefault marks the current stub to be answered with the default value for the prompt question.
149+
func (s *QuestionStub) AnswerDefault() *QuestionStub {
150+
s.Default = true
151+
return s
94152
}
95153

96-
func (as *AskStubber) StubOneDefault() {
97-
as.StubOnes = append(as.StubOnes, &PromptStub{
98-
Default: true,
99-
})
154+
// Deprecated: use StubPrompt
155+
func (as *AskStubber) StubOne(value interface{}) {
156+
as.Stub([]*QuestionStub{{Value: value}})
100157
}
101158

159+
// Deprecated: use StubPrompt
102160
func (as *AskStubber) Stub(stubbedQuestions []*QuestionStub) {
103-
// A call to .Ask takes a list of questions; a stub is then a list of questions in the same order.
104-
as.Stubs = append(as.Stubs, stubbedQuestions)
161+
as.stubs = append(as.stubs, stubbedQuestions...)
162+
}
163+
164+
// StubPrompt records a stub for an interactive prompt matched by its message.
165+
func (as *AskStubber) StubPrompt(msg string) *QuestionStub {
166+
stub := &QuestionStub{message: msg}
167+
as.stubs = append(as.stubs, stub)
168+
return stub
169+
}
170+
171+
func compareOptions(expected, got []string) error {
172+
if len(expected) != len(got) {
173+
return fmt.Errorf("expected %v, got %v (length mismatch)", expected, got)
174+
}
175+
for i, v := range expected {
176+
if v != got[i] {
177+
return fmt.Errorf("expected %v, got %v", expected, got)
178+
}
179+
}
180+
return nil
105181
}

0 commit comments

Comments
 (0)