@@ -2,104 +2,180 @@ package prompt
22
33import (
44 "fmt"
5- "reflect "
5+ "strings "
66
77 "github.com/AlecAivazis/survey/v2"
88 "github.com/AlecAivazis/survey/v2/core"
9+ "github.com/cli/cli/v2/pkg/surveyext"
910)
1011
1112type AskStubber struct {
12- Asks [][]* survey.Question
13- AskOnes []* survey.Prompt
14- Count int
15- OneCount int
16- Stubs [][]* QuestionStub
17- StubOnes []* PromptStub
13+ stubs []* QuestionStub
1814}
1915
2016func InitAskStubber () (* AskStubber , func ()) {
2117 origSurveyAsk := SurveyAsk
2218 origSurveyAskOne := SurveyAskOne
2319 as := AskStubber {}
2420
25- SurveyAskOne = func (p survey.Prompt , response interface {}, opts ... survey.AskOpt ) error {
26- as .AskOnes = append (as .AskOnes , & p )
27- count := as .OneCount
28- as .OneCount += 1
29- if count >= len (as .StubOnes ) {
30- panic (fmt .Sprintf ("more asks than stubs. most recent call: %v" , p ))
31- }
32- stubbedPrompt := as .StubOnes [count ]
33- if stubbedPrompt .Default {
34- // TODO this is failing for basic AskOne invocations with a string result.
35- defaultValue := reflect .ValueOf (p ).Elem ().FieldByName ("Default" )
36- _ = core .WriteAnswer (response , "" , defaultValue )
37- } else {
38- _ = core .WriteAnswer (response , "" , stubbedPrompt .Value )
21+ answerFromStub := func (p survey.Prompt , fieldName string , response interface {}) error {
22+ var message string
23+ var defaultValue interface {}
24+ var options []string
25+ switch pt := p .(type ) {
26+ case * survey.Confirm :
27+ message = pt .Message
28+ defaultValue = pt .Default
29+ case * survey.Input :
30+ message = pt .Message
31+ defaultValue = pt .Default
32+ case * survey.Select :
33+ message = pt .Message
34+ options = pt .Options
35+ case * survey.MultiSelect :
36+ message = pt .Message
37+ options = pt .Options
38+ case * survey.Password :
39+ message = pt .Message
40+ case * surveyext.GhEditor :
41+ message = pt .Message
42+ defaultValue = pt .Default
43+ default :
44+ panic (fmt .Sprintf ("prompt type %T is not supported by the stubber" , pt ))
3945 }
4046
41- return nil
42- }
43-
44- SurveyAsk = func (qs []* survey.Question , response interface {}, opts ... survey.AskOpt ) error {
45- as .Asks = append (as .Asks , qs )
46- count := as .Count
47- as .Count += 1
48- if count >= len (as .Stubs ) {
49- panic (fmt .Sprintf ("more asks than stubs. most recent call: %#v" , qs ))
47+ var stub * QuestionStub
48+ for _ , s := range as .stubs {
49+ if ! s .matched && (s .message == "" && strings .EqualFold (s .Name , fieldName ) || s .message == message ) {
50+ stub = s
51+ stub .matched = true
52+ break
53+ }
54+ }
55+ if stub == nil {
56+ panic (fmt .Sprintf ("no prompt stub for %q" , message ))
5057 }
5158
52- // actually set response
53- stubbedQuestions := as . Stubs [ count ]
54- if len ( stubbedQuestions ) != len ( qs ) {
55- panic ( fmt . Sprintf ( "asked questions: %d; stubbed questions: %d" , len ( qs ), len ( stubbedQuestions )))
59+ if len ( stub . options ) > 0 {
60+ if err := compareOptions ( stub . options , options ); err != nil {
61+ panic ( fmt . Sprintf ( "options mismatch for %q: %v" , message , err ))
62+ }
5663 }
57- for i , sq := range stubbedQuestions {
58- q := qs [i ]
59- if q .Name != sq .Name {
60- panic (fmt .Sprintf ("stubbed question mismatch: %s != %s" , q .Name , sq .Name ))
64+
65+ userValue := stub .Value
66+
67+ if stringValue , ok := stub .Value .(string ); ok && len (options ) > 0 {
68+ foundIndex := - 1
69+ for i , o := range options {
70+ if o == stringValue {
71+ foundIndex = i
72+ break
73+ }
6174 }
62- if sq .Default {
63- defaultValue := reflect .ValueOf (q .Prompt ).Elem ().FieldByName ("Default" )
64- _ = core .WriteAnswer (response , q .Name , defaultValue )
75+ if foundIndex < 0 {
76+ panic (fmt .Sprintf ("answer %q not found in options for %q: %v" , stringValue , message , options ))
77+ }
78+ userValue = core.OptionAnswer {
79+ Value : stringValue ,
80+ Index : foundIndex ,
81+ }
82+ }
83+
84+ if stub .Default {
85+ if defaultIndex , ok := defaultValue .(int ); ok && len (options ) > 0 {
86+ userValue = core.OptionAnswer {
87+ Value : options [defaultIndex ],
88+ Index : defaultIndex ,
89+ }
90+ } else if defaultValue == nil && len (options ) > 0 {
91+ userValue = core.OptionAnswer {
92+ Value : options [0 ],
93+ Index : 0 ,
94+ }
6595 } else {
66- _ = core . WriteAnswer ( response , q . Name , sq . Value )
96+ userValue = defaultValue
6797 }
6898 }
6999
100+ if err := core .WriteAnswer (response , fieldName , userValue ); err != nil {
101+ return fmt .Errorf ("AskStubber failed writing the answer for field %q: %w" , fieldName , err )
102+ }
103+ return nil
104+ }
105+
106+ SurveyAskOne = func (p survey.Prompt , response interface {}, opts ... survey.AskOpt ) error {
107+ return answerFromStub (p , "" , response )
108+ }
109+
110+ SurveyAsk = func (qs []* survey.Question , response interface {}, opts ... survey.AskOpt ) error {
111+ for _ , q := range qs {
112+ if err := answerFromStub (q .Prompt , q .Name , response ); err != nil {
113+ return err
114+ }
115+ }
70116 return nil
71117 }
118+
72119 teardown := func () {
73120 SurveyAsk = origSurveyAsk
74121 SurveyAskOne = origSurveyAskOne
75122 }
76123 return & as , teardown
77124}
78125
79- type PromptStub struct {
80- Value interface {}
81- Default bool
82- }
83-
84126type QuestionStub struct {
85127 Name string
86128 Value interface {}
87129 Default bool
130+
131+ matched bool
132+ message string
133+ options []string
88134}
89135
90- func (as * AskStubber ) StubOne (value interface {}) {
91- as .StubOnes = append (as .StubOnes , & PromptStub {
92- Value : value ,
93- })
136+ // AssertOptions asserts the options presented to the user in Selects and MultiSelects.
137+ func (s * QuestionStub ) AssertOptions (opts []string ) * QuestionStub {
138+ s .options = opts
139+ return s
140+ }
141+
142+ // AnswerWith defines an answer for the given stub.
143+ func (s * QuestionStub ) AnswerWith (v interface {}) * QuestionStub {
144+ s .Value = v
145+ return s
146+ }
147+
148+ // AnswerDefault marks the current stub to be answered with the default value for the prompt question.
149+ func (s * QuestionStub ) AnswerDefault () * QuestionStub {
150+ s .Default = true
151+ return s
94152}
95153
96- func (as * AskStubber ) StubOneDefault () {
97- as .StubOnes = append (as .StubOnes , & PromptStub {
98- Default : true ,
99- })
154+ // Deprecated: use StubPrompt
155+ func (as * AskStubber ) StubOne (value interface {}) {
156+ as .Stub ([]* QuestionStub {{Value : value }})
100157}
101158
159+ // Deprecated: use StubPrompt
102160func (as * AskStubber ) Stub (stubbedQuestions []* QuestionStub ) {
103- // A call to .Ask takes a list of questions; a stub is then a list of questions in the same order.
104- as .Stubs = append (as .Stubs , stubbedQuestions )
161+ as .stubs = append (as .stubs , stubbedQuestions ... )
162+ }
163+
164+ // StubPrompt records a stub for an interactive prompt matched by its message.
165+ func (as * AskStubber ) StubPrompt (msg string ) * QuestionStub {
166+ stub := & QuestionStub {message : msg }
167+ as .stubs = append (as .stubs , stub )
168+ return stub
169+ }
170+
171+ func compareOptions (expected , got []string ) error {
172+ if len (expected ) != len (got ) {
173+ return fmt .Errorf ("expected %v, got %v (length mismatch)" , expected , got )
174+ }
175+ for i , v := range expected {
176+ if v != got [i ] {
177+ return fmt .Errorf ("expected %v, got %v" , expected , got )
178+ }
179+ }
180+ return nil
105181}
0 commit comments