Skip to content

Commit 774ade9

Browse files
committed
Add a helper to safely override env variables with in tests
1 parent 8235140 commit 774ade9

File tree

6 files changed

+104
-89
lines changed

6 files changed

+104
-89
lines changed

git/git_test.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ import (
66
"testing"
77

88
"github.com/cli/cli/internal/run"
9+
"github.com/cli/cli/pkg/env"
910
)
1011

1112
func setGitDir(t *testing.T, dir string) {
12-
// TODO: also set XDG_CONFIG_HOME, GIT_CONFIG_NOSYSTEM
13-
old_GIT_DIR := os.Getenv("GIT_DIR")
14-
os.Setenv("GIT_DIR", dir)
15-
t.Cleanup(func() {
16-
os.Setenv("GIT_DIR", old_GIT_DIR)
13+
wd, _ := os.Getwd()
14+
reset := env.WithEnv(map[string]string{
15+
"GIT_DIR": dir,
16+
"XDG_CONFIG_HOME": wd,
17+
"GIT_CONFIG_NOSYSTEM": "1",
1718
})
19+
t.Cleanup(reset)
1820
}
1921

2022
func TestLastCommit(t *testing.T) {

internal/config/from_env_test.go

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,14 @@
11
package config
22

33
import (
4-
"os"
54
"testing"
65

76
"github.com/MakeNowJust/heredoc"
7+
"github.com/cli/cli/pkg/env"
88
"github.com/stretchr/testify/assert"
99
)
1010

1111
func TestInheritEnv(t *testing.T) {
12-
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
13-
orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
14-
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
15-
orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
16-
t.Cleanup(func() {
17-
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
18-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN)
19-
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
20-
os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN)
21-
})
22-
2312
type wants struct {
2413
hosts []string
2514
token string
@@ -260,10 +249,12 @@ func TestInheritEnv(t *testing.T) {
260249
}
261250
for _, tt := range tests {
262251
t.Run(tt.name, func(t *testing.T) {
263-
os.Setenv("GITHUB_TOKEN", tt.GITHUB_TOKEN)
264-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN)
265-
os.Setenv("GH_TOKEN", tt.GH_TOKEN)
266-
os.Setenv("GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN)
252+
t.Cleanup(env.WithEnv(map[string]string{
253+
"GITHUB_TOKEN": tt.GITHUB_TOKEN,
254+
"GITHUB_ENTERPRISE_TOKEN": tt.GITHUB_ENTERPRISE_TOKEN,
255+
"GH_TOKEN": tt.GH_TOKEN,
256+
"GH_ENTERPRISE_TOKEN": tt.GH_ENTERPRISE_TOKEN,
257+
}))
267258

268259
baseCfg := NewFromString(tt.baseConfig)
269260
cfg := InheritEnv(baseCfg)
@@ -287,17 +278,6 @@ func TestInheritEnv(t *testing.T) {
287278
}
288279

289280
func TestAuthTokenProvidedFromEnv(t *testing.T) {
290-
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
291-
orig_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
292-
orig_GH_TOKEN := os.Getenv("GH_TOKEN")
293-
orig_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
294-
t.Cleanup(func() {
295-
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
296-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", orig_GITHUB_ENTERPRISE_TOKEN)
297-
os.Setenv("GH_TOKEN", orig_GH_TOKEN)
298-
os.Setenv("GH_ENTERPRISE_TOKEN", orig_GH_ENTERPRISE_TOKEN)
299-
})
300-
301281
tests := []struct {
302282
name string
303283
GITHUB_TOKEN string
@@ -334,10 +314,12 @@ func TestAuthTokenProvidedFromEnv(t *testing.T) {
334314

335315
for _, tt := range tests {
336316
t.Run(tt.name, func(t *testing.T) {
337-
os.Setenv("GITHUB_TOKEN", tt.GITHUB_TOKEN)
338-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.GITHUB_ENTERPRISE_TOKEN)
339-
os.Setenv("GH_TOKEN", tt.GH_TOKEN)
340-
os.Setenv("GH_ENTERPRISE_TOKEN", tt.GH_ENTERPRISE_TOKEN)
317+
t.Cleanup(env.WithEnv(map[string]string{
318+
"GITHUB_TOKEN": tt.GITHUB_TOKEN,
319+
"GITHUB_ENTERPRISE_TOKEN": tt.GITHUB_ENTERPRISE_TOKEN,
320+
"GH_TOKEN": tt.GH_TOKEN,
321+
"GH_ENTERPRISE_TOKEN": tt.GH_ENTERPRISE_TOKEN,
322+
}))
341323
assert.Equal(t, tt.provided, AuthTokenProvidedFromEnv())
342324
})
343325
}

pkg/cmd/auth/login/login_test.go

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@ package login
33
import (
44
"bytes"
55
"net/http"
6-
"os"
76
"regexp"
87
"testing"
98

109
"github.com/MakeNowJust/heredoc"
1110
"github.com/cli/cli/internal/config"
1211
"github.com/cli/cli/internal/run"
1312
"github.com/cli/cli/pkg/cmdutil"
13+
"github.com/cli/cli/pkg/env"
1414
"github.com/cli/cli/pkg/httpmock"
1515
"github.com/cli/cli/pkg/iostreams"
1616
"github.com/cli/cli/pkg/prompt"
@@ -204,6 +204,12 @@ func Test_loginRun_nontty(t *testing.T) {
204204
Hostname: "github.com",
205205
Token: "abc123",
206206
},
207+
env: map[string]string{
208+
"GH_TOKEN": "",
209+
"GITHUB_TOKEN": "",
210+
"GH_ENTERPRISE_TOKEN": "",
211+
"GITHUB_ENTERPRISE_TOKEN": "",
212+
},
207213
httpStubs: func(reg *httpmock.Registry) {
208214
reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,read:org"))
209215
},
@@ -215,6 +221,12 @@ func Test_loginRun_nontty(t *testing.T) {
215221
Hostname: "albert.wesker",
216222
Token: "abc123",
217223
},
224+
env: map[string]string{
225+
"GH_TOKEN": "",
226+
"GITHUB_TOKEN": "",
227+
"GH_ENTERPRISE_TOKEN": "",
228+
"GITHUB_ENTERPRISE_TOKEN": "",
229+
},
218230
httpStubs: func(reg *httpmock.Registry) {
219231
reg.Register(httpmock.REST("GET", "api/v3/"), httpmock.ScopesResponder("repo,read:org"))
220232
},
@@ -226,6 +238,12 @@ func Test_loginRun_nontty(t *testing.T) {
226238
Hostname: "github.com",
227239
Token: "abc456",
228240
},
241+
env: map[string]string{
242+
"GH_TOKEN": "",
243+
"GITHUB_TOKEN": "",
244+
"GH_ENTERPRISE_TOKEN": "",
245+
"GITHUB_ENTERPRISE_TOKEN": "",
246+
},
229247
httpStubs: func(reg *httpmock.Registry) {
230248
reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("read:org"))
231249
},
@@ -237,6 +255,12 @@ func Test_loginRun_nontty(t *testing.T) {
237255
Hostname: "github.com",
238256
Token: "abc456",
239257
},
258+
env: map[string]string{
259+
"GH_TOKEN": "",
260+
"GITHUB_TOKEN": "",
261+
"GH_ENTERPRISE_TOKEN": "",
262+
"GITHUB_ENTERPRISE_TOKEN": "",
263+
},
240264
httpStubs: func(reg *httpmock.Registry) {
241265
reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo"))
242266
},
@@ -248,6 +272,12 @@ func Test_loginRun_nontty(t *testing.T) {
248272
Hostname: "github.com",
249273
Token: "abc456",
250274
},
275+
env: map[string]string{
276+
"GH_TOKEN": "",
277+
"GITHUB_TOKEN": "",
278+
"GH_ENTERPRISE_TOKEN": "",
279+
"GITHUB_ENTERPRISE_TOKEN": "",
280+
},
251281
httpStubs: func(reg *httpmock.Registry) {
252282
reg.Register(httpmock.REST("GET", ""), httpmock.ScopesResponder("repo,admin:org"))
253283
},
@@ -260,7 +290,10 @@ func Test_loginRun_nontty(t *testing.T) {
260290
Token: "abc456",
261291
},
262292
env: map[string]string{
263-
"GH_TOKEN": "value_from_env",
293+
"GH_TOKEN": "ENVTOKEN",
294+
"GITHUB_TOKEN": "",
295+
"GH_ENTERPRISE_TOKEN": "",
296+
"GITHUB_ENTERPRISE_TOKEN": "",
264297
},
265298
wantErr: "SilentError",
266299
wantStderr: heredoc.Doc(`
@@ -275,7 +308,10 @@ func Test_loginRun_nontty(t *testing.T) {
275308
Token: "abc456",
276309
},
277310
env: map[string]string{
278-
"GH_ENTERPRISE_TOKEN": "value_from_env",
311+
"GH_TOKEN": "",
312+
"GITHUB_TOKEN": "",
313+
"GH_ENTERPRISE_TOKEN": "ENVTOKEN",
314+
"GITHUB_ENTERPRISE_TOKEN": "",
279315
},
280316
wantErr: "SilentError",
281317
wantStderr: heredoc.Doc(`
@@ -303,20 +339,7 @@ func Test_loginRun_nontty(t *testing.T) {
303339
return &http.Client{Transport: reg}, nil
304340
}
305341

306-
old_GH_TOKEN := os.Getenv("GH_TOKEN")
307-
os.Setenv("GH_TOKEN", tt.env["GH_TOKEN"])
308-
old_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
309-
os.Setenv("GITHUB_TOKEN", tt.env["GITHUB_TOKEN"])
310-
old_GH_ENTERPRISE_TOKEN := os.Getenv("GH_ENTERPRISE_TOKEN")
311-
os.Setenv("GH_ENTERPRISE_TOKEN", tt.env["GH_ENTERPRISE_TOKEN"])
312-
old_GITHUB_ENTERPRISE_TOKEN := os.Getenv("GITHUB_ENTERPRISE_TOKEN")
313-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", tt.env["GITHUB_ENTERPRISE_TOKEN"])
314-
defer func() {
315-
os.Setenv("GH_TOKEN", old_GH_TOKEN)
316-
os.Setenv("GITHUB_TOKEN", old_GITHUB_TOKEN)
317-
os.Setenv("GH_ENTERPRISE_TOKEN", old_GH_ENTERPRISE_TOKEN)
318-
os.Setenv("GITHUB_ENTERPRISE_TOKEN", old_GITHUB_ENTERPRISE_TOKEN)
319-
}()
342+
t.Cleanup(env.WithEnv(tt.env))
320343

321344
if tt.httpStubs != nil {
322345
tt.httpStubs(reg)

pkg/cmdutil/auth_check_test.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
package cmdutil
22

33
import (
4-
"os"
54
"testing"
65

76
"github.com/cli/cli/internal/config"
7+
"github.com/cli/cli/pkg/env"
88
"github.com/stretchr/testify/assert"
99
)
1010

1111
func Test_CheckAuth(t *testing.T) {
12-
orig_GITHUB_TOKEN := os.Getenv("GITHUB_TOKEN")
13-
t.Cleanup(func() {
14-
os.Setenv("GITHUB_TOKEN", orig_GITHUB_TOKEN)
15-
})
16-
1712
tests := []struct {
1813
name string
1914
cfg func(config.Config)
@@ -51,11 +46,13 @@ func Test_CheckAuth(t *testing.T) {
5146

5247
for _, tt := range tests {
5348
t.Run(tt.name, func(t *testing.T) {
49+
tokenValue := ""
5450
if tt.envToken {
55-
os.Setenv("GITHUB_TOKEN", "TOKEN")
56-
} else {
57-
os.Setenv("GITHUB_TOKEN", "")
51+
tokenValue = "TOKEN"
5852
}
53+
t.Cleanup(env.WithEnv(map[string]string{
54+
"GITHUB_TOKEN": tokenValue,
55+
}))
5956

6057
cfg := config.NewBlankConfig()
6158
tt.cfg(cfg)

pkg/env/with_env.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package env
2+
3+
import "os"
4+
5+
// WithEnv changes environment variables and returns a function to restore them to their original values.
6+
func WithEnv(vars map[string]string) func() {
7+
originalValues := map[string]*string{}
8+
for name, value := range vars {
9+
if oldValue, ok := os.LookupEnv(name); ok {
10+
originalValues[name] = &oldValue
11+
} else {
12+
originalValues[name] = nil
13+
}
14+
os.Setenv(name, value)
15+
}
16+
17+
return func() {
18+
for name, oldValue := range originalValues {
19+
if oldValue == nil {
20+
os.Unsetenv(name)
21+
} else {
22+
os.Setenv(name, *oldValue)
23+
}
24+
}
25+
}
26+
}

pkg/iostreams/color_test.go

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,12 @@
11
package iostreams
22

33
import (
4-
"os"
54
"testing"
5+
6+
"github.com/cli/cli/pkg/env"
67
)
78

89
func TestEnvColorDisabled(t *testing.T) {
9-
orig_NO_COLOR := os.Getenv("NO_COLOR")
10-
orig_CLICOLOR := os.Getenv("CLICOLOR")
11-
orig_CLICOLOR_FORCE := os.Getenv("CLICOLOR_FORCE")
12-
t.Cleanup(func() {
13-
os.Setenv("NO_COLOR", orig_NO_COLOR)
14-
os.Setenv("CLICOLOR", orig_CLICOLOR)
15-
os.Setenv("CLICOLOR_FORCE", orig_CLICOLOR_FORCE)
16-
})
17-
1810
tests := []struct {
1911
name string
2012
NO_COLOR string
@@ -60,10 +52,11 @@ func TestEnvColorDisabled(t *testing.T) {
6052
}
6153
for _, tt := range tests {
6254
t.Run(tt.name, func(t *testing.T) {
63-
os.Setenv("NO_COLOR", tt.NO_COLOR)
64-
os.Setenv("CLICOLOR", tt.CLICOLOR)
65-
os.Setenv("CLICOLOR_FORCE", tt.CLICOLOR_FORCE)
66-
55+
t.Cleanup(env.WithEnv(map[string]string{
56+
"CLICOLOR": tt.CLICOLOR,
57+
"NO_COLOR": tt.NO_COLOR,
58+
"CLICOLOR_FORCE": tt.CLICOLOR_FORCE,
59+
}))
6760
if got := EnvColorDisabled(); got != tt.want {
6861
t.Errorf("EnvColorDisabled(): want %v, got %v", tt.want, got)
6962
}
@@ -72,15 +65,6 @@ func TestEnvColorDisabled(t *testing.T) {
7265
}
7366

7467
func TestEnvColorForced(t *testing.T) {
75-
orig_NO_COLOR := os.Getenv("NO_COLOR")
76-
orig_CLICOLOR := os.Getenv("CLICOLOR")
77-
orig_CLICOLOR_FORCE := os.Getenv("CLICOLOR_FORCE")
78-
t.Cleanup(func() {
79-
os.Setenv("NO_COLOR", orig_NO_COLOR)
80-
os.Setenv("CLICOLOR", orig_CLICOLOR)
81-
os.Setenv("CLICOLOR_FORCE", orig_CLICOLOR_FORCE)
82-
})
83-
8468
tests := []struct {
8569
name string
8670
NO_COLOR string
@@ -133,10 +117,11 @@ func TestEnvColorForced(t *testing.T) {
133117
}
134118
for _, tt := range tests {
135119
t.Run(tt.name, func(t *testing.T) {
136-
os.Setenv("NO_COLOR", tt.NO_COLOR)
137-
os.Setenv("CLICOLOR", tt.CLICOLOR)
138-
os.Setenv("CLICOLOR_FORCE", tt.CLICOLOR_FORCE)
139-
120+
t.Cleanup(env.WithEnv(map[string]string{
121+
"CLICOLOR": tt.CLICOLOR,
122+
"NO_COLOR": tt.NO_COLOR,
123+
"CLICOLOR_FORCE": tt.CLICOLOR_FORCE,
124+
}))
140125
if got := EnvColorForced(); got != tt.want {
141126
t.Errorf("EnvColorForced(): want %v, got %v", tt.want, got)
142127
}

0 commit comments

Comments
 (0)