Skip to content

Commit 93c8fc1

Browse files
committed
Add tests for GraphQL introspection
1 parent 0ef2863 commit 93c8fc1

File tree

4 files changed

+191
-14
lines changed

4 files changed

+191
-14
lines changed

api/cache.go

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,28 @@ import (
1414
"time"
1515
)
1616

17+
func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
18+
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
19+
return &http.Client{
20+
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
21+
}
22+
}
23+
1724
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
1825
func CacheReponse(ttl time.Duration, dir string) ClientOption {
1926
return func(tr http.RoundTripper) http.RoundTripper {
2027
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
2128
key, keyErr := cacheKey(req)
2229
cacheFile := filepath.Join(dir, key)
2330
if keyErr == nil {
31+
// TODO: make thread-safe
2432
if res, err := readCache(ttl, cacheFile, req); err == nil {
2533
return res, nil
2634
}
2735
}
2836
res, err := tr.RoundTrip(req)
2937
if err == nil && keyErr == nil {
38+
// TODO: make thread-safe
3039
_ = writeCache(cacheFile, res)
3140
}
3241
return res, err
@@ -53,12 +62,16 @@ func cacheKey(req *http.Request) (string, error) {
5362
return fmt.Sprintf("%x", digest), nil
5463
}
5564

65+
type readCloser struct {
66+
io.Reader
67+
io.Closer
68+
}
69+
5670
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
5771
f, err := os.Open(cacheFile)
5872
if err != nil {
5973
return nil, err
6074
}
61-
defer f.Close()
6275

6376
fs, err := f.Stat()
6477
if err != nil {
@@ -70,7 +83,14 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
7083
return nil, errors.New("cache expired")
7184
}
7285

73-
return http.ReadResponse(bufio.NewReader(f), req)
86+
res, err := http.ReadResponse(bufio.NewReader(f), req)
87+
if res != nil {
88+
res.Body = &readCloser{
89+
Reader: res.Body,
90+
Closer: f,
91+
}
92+
}
93+
return res, err
7494
}
7595

7696
func writeCache(cacheFile string, res *http.Response) error {

api/cache_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package api
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io"
7+
"io/ioutil"
8+
"net/http"
9+
"path/filepath"
10+
"testing"
11+
"time"
12+
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func Test_CacheReponse(t *testing.T) {
18+
counter := 0
19+
fakeHTTP := funcTripper{
20+
roundTrip: func(req *http.Request) (*http.Response, error) {
21+
counter += 1
22+
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
23+
return &http.Response{
24+
StatusCode: 200,
25+
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
26+
}, nil
27+
},
28+
}
29+
30+
cacheDir := filepath.Join(t.TempDir(), "gh-cli-cache")
31+
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP), CacheReponse(time.Minute, cacheDir))
32+
33+
do := func(method, url string, body io.Reader) (string, error) {
34+
req, err := http.NewRequest(method, url, body)
35+
if err != nil {
36+
return "", err
37+
}
38+
res, err := httpClient.Do(req)
39+
if err != nil {
40+
return "", err
41+
}
42+
resBody, err := ioutil.ReadAll(res.Body)
43+
if err != nil {
44+
err = fmt.Errorf("ReadAll: %w", err)
45+
}
46+
return string(resBody), err
47+
}
48+
49+
res1, err := do("GET", "http://example.com/path", nil)
50+
require.NoError(t, err)
51+
assert.Equal(t, "1: GET http://example.com/path", res1)
52+
res2, err := do("GET", "http://example.com/path", nil)
53+
require.NoError(t, err)
54+
assert.Equal(t, "1: GET http://example.com/path", res2)
55+
56+
res3, err := do("GET", "http://example.com/path2", nil)
57+
require.NoError(t, err)
58+
assert.Equal(t, "2: GET http://example.com/path2", res3)
59+
60+
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
61+
require.NoError(t, err)
62+
assert.Equal(t, "3: POST http://example.com/path", res4)
63+
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
64+
require.NoError(t, err)
65+
assert.Equal(t, "3: POST http://example.com/path", res5)
66+
67+
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
68+
require.NoError(t, err)
69+
assert.Equal(t, "4: POST http://example.com/path", res6)
70+
}

api/queries_pr.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ import (
66
"fmt"
77
"io"
88
"net/http"
9-
"os"
10-
"path/filepath"
119
"strings"
1210
"time"
1311

@@ -268,14 +266,8 @@ func determinePullRequestFeatures(httpClient *http.Client, hostname string) (prF
268266
} `graphql:"Commit: __type(name: \"Commit\")"`
269267
}
270268

271-
cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
272-
cacheTTL := time.Duration(24 * time.Hour)
273-
cachedClient := &http.Client{
274-
Transport: CacheReponse(cacheTTL, cacheDir)(httpClient.Transport),
275-
}
276-
277-
v4 := graphQLClient(cachedClient, hostname)
278-
err = v4.Query(context.Background(), &featureDetection, nil)
269+
v4 := graphQLClient(httpClient, hostname)
270+
err = v4.QueryNamed(context.Background(), "PullRequest_fields", &featureDetection, nil)
279271
if err != nil {
280272
return
281273
}
@@ -315,7 +307,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
315307
ReviewRequested edges
316308
}
317309

318-
prFeatures, err := determinePullRequestFeatures(client.http, repo.RepoHost())
310+
cachedClient := makeCachedClient(client.http, time.Hour*24)
311+
prFeatures, err := determinePullRequestFeatures(cachedClient, repo.RepoHost())
319312
if err != nil {
320313
return nil, err
321314
}
@@ -483,7 +476,8 @@ func PullRequests(client *Client, repo ghrepo.Interface, currentPRNumber int, cu
483476
}
484477

485478
func prCommitsFragment(httpClient *http.Client, hostname string) (string, error) {
486-
if prFeatures, err := determinePullRequestFeatures(httpClient, hostname); err != nil {
479+
cachedClient := makeCachedClient(httpClient, time.Hour*24)
480+
if prFeatures, err := determinePullRequestFeatures(cachedClient, hostname); err != nil {
487481
return "", err
488482
} else if !prFeatures.HasStatusCheckRollup {
489483
return "", nil

api/queries_pr_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package api
22

33
import (
4+
"reflect"
45
"testing"
56

7+
"github.com/MakeNowJust/heredoc"
68
"github.com/cli/cli/internal/ghrepo"
79
"github.com/cli/cli/pkg/httpmock"
810
)
@@ -45,3 +47,94 @@ func TestBranchDeleteRemote(t *testing.T) {
4547
})
4648
}
4749
}
50+
51+
func Test_determinePullRequestFeatures(t *testing.T) {
52+
tests := []struct {
53+
name string
54+
hostname string
55+
queryResponse string
56+
wantPrFeatures pullRequestFeature
57+
wantErr bool
58+
}{
59+
{
60+
name: "github.com",
61+
hostname: "github.com",
62+
wantPrFeatures: pullRequestFeature{
63+
HasReviewDecision: true,
64+
HasStatusCheckRollup: true,
65+
},
66+
wantErr: false,
67+
},
68+
{
69+
name: "GHE empty response",
70+
hostname: "git.my.org",
71+
queryResponse: heredoc.Doc(`
72+
{"data": {}}
73+
`),
74+
wantPrFeatures: pullRequestFeature{
75+
HasReviewDecision: false,
76+
HasStatusCheckRollup: false,
77+
},
78+
wantErr: false,
79+
},
80+
{
81+
name: "GHE has reviewDecision",
82+
hostname: "git.my.org",
83+
queryResponse: heredoc.Doc(`
84+
{"data": {
85+
"PullRequest": {
86+
"fields": [
87+
{"name": "foo"},
88+
{"name": "reviewDecision"}
89+
]
90+
}
91+
} }
92+
`),
93+
wantPrFeatures: pullRequestFeature{
94+
HasReviewDecision: true,
95+
HasStatusCheckRollup: false,
96+
},
97+
wantErr: false,
98+
},
99+
{
100+
name: "GHE has statusCheckRollup",
101+
hostname: "git.my.org",
102+
queryResponse: heredoc.Doc(`
103+
{"data": {
104+
"Commit": {
105+
"fields": [
106+
{"name": "foo"},
107+
{"name": "statusCheckRollup"}
108+
]
109+
}
110+
} }
111+
`),
112+
wantPrFeatures: pullRequestFeature{
113+
HasReviewDecision: false,
114+
HasStatusCheckRollup: true,
115+
},
116+
wantErr: false,
117+
},
118+
}
119+
for _, tt := range tests {
120+
t.Run(tt.name, func(t *testing.T) {
121+
fakeHTTP := &httpmock.Registry{}
122+
httpClient := NewHTTPClient(ReplaceTripper(fakeHTTP))
123+
124+
if tt.queryResponse != "" {
125+
fakeHTTP.Register(
126+
httpmock.GraphQL(`query PullRequest_fields\b`),
127+
httpmock.StringResponse(tt.queryResponse))
128+
}
129+
130+
gotPrFeatures, err := determinePullRequestFeatures(httpClient, tt.hostname)
131+
if (err != nil) != tt.wantErr {
132+
t.Errorf("determinePullRequestFeatures() error = %v, wantErr %v", err, tt.wantErr)
133+
return
134+
}
135+
if !reflect.DeepEqual(gotPrFeatures, tt.wantPrFeatures) {
136+
t.Errorf("determinePullRequestFeatures() = %v, want %v", gotPrFeatures, tt.wantPrFeatures)
137+
}
138+
})
139+
}
140+
}

0 commit comments

Comments
 (0)