Skip to content

Commit 6e9fab6

Browse files
authored
Merge pull request cli#1359 from cli/scopes-admin
Improvements to checking OAuth scopes
2 parents 2b96d2c + 28cd348 commit 6e9fab6

File tree

2 files changed

+105
-6
lines changed

2 files changed

+105
-6
lines changed

api/client.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,37 @@ func ReplaceTripper(tr http.RoundTripper) ClientOption {
8989

9090
var issuedScopesWarning bool
9191

92+
const (
93+
httpOAuthAppID = "X-Oauth-Client-Id"
94+
httpOAuthScopes = "X-Oauth-Scopes"
95+
)
96+
9297
// CheckScopes checks whether an OAuth scope is present in a response
9398
func CheckScopes(wantedScope string, cb func(string) error) ClientOption {
99+
wantedCandidates := []string{wantedScope}
100+
if strings.HasPrefix(wantedScope, "read:") {
101+
wantedCandidates = append(wantedCandidates, "admin:"+strings.TrimPrefix(wantedScope, "read:"))
102+
}
103+
94104
return func(tr http.RoundTripper) http.RoundTripper {
95105
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
96106
res, err := tr.RoundTrip(req)
97-
if err != nil || res.StatusCode > 299 || issuedScopesWarning {
107+
_, hasHeader := res.Header[httpOAuthAppID]
108+
if err != nil || res.StatusCode > 299 || !hasHeader || issuedScopesWarning {
98109
return res, err
99110
}
100111

101-
appID := res.Header.Get("X-Oauth-Client-Id")
102-
hasScopes := strings.Split(res.Header.Get("X-Oauth-Scopes"), ",")
112+
appID := res.Header.Get(httpOAuthAppID)
113+
hasScopes := strings.Split(res.Header.Get(httpOAuthScopes), ",")
103114

104115
hasWanted := false
116+
outer:
105117
for _, s := range hasScopes {
106-
if wantedScope == strings.TrimSpace(s) {
107-
hasWanted = true
108-
break
118+
for _, w := range wantedCandidates {
119+
if w == strings.TrimSpace(s) {
120+
hasWanted = true
121+
break outer
122+
}
109123
}
110124
}
111125

api/client_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"errors"
66
"io/ioutil"
7+
"net/http"
78
"reflect"
89
"testing"
910

@@ -91,5 +92,89 @@ func TestRESTError(t *testing.T) {
9192
}
9293
if httpErr.Error() != "HTTP 422: OH NO (https://api.github.com/repos/branch)" {
9394
t.Errorf("got %q", httpErr.Error())
95+
96+
}
97+
}
98+
99+
func Test_CheckScopes(t *testing.T) {
100+
tests := []struct {
101+
name string
102+
wantScope string
103+
responseApp string
104+
responseScopes string
105+
expectCallback bool
106+
}{
107+
{
108+
name: "missing read:org",
109+
wantScope: "read:org",
110+
responseApp: "APPID",
111+
responseScopes: "repo, gist",
112+
expectCallback: true,
113+
},
114+
{
115+
name: "has read:org",
116+
wantScope: "read:org",
117+
responseApp: "APPID",
118+
responseScopes: "repo, read:org, gist",
119+
expectCallback: false,
120+
},
121+
{
122+
name: "has admin:org",
123+
wantScope: "read:org",
124+
responseApp: "APPID",
125+
responseScopes: "repo, admin:org, gist",
126+
expectCallback: false,
127+
},
128+
{
129+
name: "no scopes in response",
130+
wantScope: "read:org",
131+
responseApp: "",
132+
responseScopes: "",
133+
expectCallback: false,
134+
},
135+
}
136+
for _, tt := range tests {
137+
t.Run(tt.name, func(t *testing.T) {
138+
tr := &httpmock.Registry{}
139+
tr.Register(httpmock.MatchAny, func(*http.Request) (*http.Response, error) {
140+
if tt.responseScopes == "" {
141+
return &http.Response{StatusCode: 200}, nil
142+
}
143+
return &http.Response{
144+
StatusCode: 200,
145+
Header: http.Header{
146+
"X-Oauth-Client-Id": []string{tt.responseApp},
147+
"X-Oauth-Scopes": []string{tt.responseScopes},
148+
},
149+
}, nil
150+
})
151+
152+
callbackInvoked := false
153+
var gotAppID string
154+
fn := CheckScopes(tt.wantScope, func(appID string) error {
155+
callbackInvoked = true
156+
gotAppID = appID
157+
return nil
158+
})
159+
160+
rt := fn(tr)
161+
req, err := http.NewRequest("GET", "https://api.github.com/hello", nil)
162+
if err != nil {
163+
t.Fatalf("unexpected error: %v", err)
164+
}
165+
166+
issuedScopesWarning = false
167+
_, err = rt.RoundTrip(req)
168+
if err != nil {
169+
t.Fatalf("unexpected error: %v", err)
170+
}
171+
172+
if tt.expectCallback != callbackInvoked {
173+
t.Fatalf("expected CheckScopes callback: %v", tt.expectCallback)
174+
}
175+
if tt.expectCallback && gotAppID != tt.responseApp {
176+
t.Errorf("unexpected app ID: %q", gotAppID)
177+
}
178+
})
94179
}
95180
}

0 commit comments

Comments
 (0)