Skip to content

Commit d75faab

Browse files
committed
Creating a PR now always prioritizes an existing fork as a push target
Before: the default push target for the current branch in `pr create` was the first repository found among git remotes that has write access. Now: the default push target is the fork the base repo, if said fork exists and has write access, falling back to old behavior otherwise. This change in the default is to facilitate contributions to projects that have a hard requirement that all pull requests (even those opened by people with write access to that project) come from forks.
1 parent 2660561 commit d75faab

File tree

5 files changed

+177
-22
lines changed

5 files changed

+177
-22
lines changed

api/queries_repo.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package api
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"sort"
89
"strings"
@@ -220,6 +221,44 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) {
220221
}, nil
221222
}
222223

224+
// RepoFindFork finds a fork of repo affiliated with the viewer
225+
func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) {
226+
result := struct {
227+
Repository struct {
228+
Forks struct {
229+
Nodes []Repository
230+
}
231+
}
232+
}{}
233+
234+
variables := map[string]interface{}{
235+
"owner": repo.RepoOwner(),
236+
"repo": repo.RepoName(),
237+
}
238+
239+
if err := client.GraphQL(`
240+
query($owner: String!, $repo: String!) {
241+
repository(owner: $owner, name: $repo) {
242+
forks(first: 1, affiliations: [OWNER, COLLABORATOR]) {
243+
nodes {
244+
id
245+
name
246+
owner { login }
247+
url
248+
}
249+
}
250+
}
251+
}
252+
`, variables, &result); err != nil {
253+
return nil, err
254+
}
255+
256+
if len(result.Repository.Forks.Nodes) > 0 {
257+
return &result.Repository.Forks.Nodes[0], nil
258+
}
259+
return nil, &NotFoundError{errors.New("no fork found")}
260+
}
261+
223262
// RepoCreateInput represents input parameters for RepoCreate
224263
type RepoCreateInput struct {
225264
Name string `json:"name"`

command/pr_create.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
193193
}
194194

195195
didForkRepo := false
196-
var headRemote *context.Remote
197196
if headRepoErr != nil {
198197
if baseRepo.IsPrivate {
199198
return fmt.Errorf("cannot fork private repository '%s'", ghrepo.FullName(baseRepo))
@@ -203,11 +202,26 @@ func prCreate(cmd *cobra.Command, _ []string) error {
203202
return fmt.Errorf("error forking repo: %w", err)
204203
}
205204
didForkRepo = true
205+
}
206+
207+
headBranchLabel := headBranch
208+
if !ghrepo.IsSame(baseRepo, headRepo) {
209+
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
210+
}
211+
212+
headRemote, err := repoContext.RemoteForRepo(headRepo)
213+
// There are two cases when an existing remote for the head repo will be
214+
// missing:
215+
// 1. the head repo was just created by auto-forking;
216+
// 2. an existing fork was discovered by quering the API.
217+
//
218+
// In either case, we want to add the head repo as a new git remote so we
219+
// can push to it.
220+
if err != nil {
206221
// TODO: support non-HTTPS git remote URLs
207-
baseRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(baseRepo))
208222
headRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(headRepo))
209-
// TODO: figure out what to name the new git remote
210-
gitRemote, err := git.AddRemote("fork", baseRepoURL, headRepoURL)
223+
// TODO: prevent clashes with another remote of a same name
224+
gitRemote, err := git.AddRemote("fork", headRepoURL, "")
211225
if err != nil {
212226
return fmt.Errorf("error adding remote: %w", err)
213227
}
@@ -218,18 +232,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
218232
}
219233
}
220234

221-
headBranchLabel := headBranch
222-
if !ghrepo.IsSame(baseRepo, headRepo) {
223-
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
224-
}
225-
226-
if headRemote == nil {
227-
headRemote, err = repoContext.RemoteForRepo(headRepo)
228-
if err != nil {
229-
return fmt.Errorf("git remote not found for head repository: %w", err)
230-
}
231-
}
232-
233235
pushTries := 0
234236
maxPushTries := 3
235237
for {

command/pr_create_test.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ func TestPRCreate(t *testing.T) {
1414
initBlankContext("OWNER/REPO", "feature")
1515
http := initFakeHTTP()
1616
http.StubRepoResponse("OWNER", "REPO")
17+
http.StubResponse(200, bytes.NewBufferString(`
18+
{ "data": { "repository": { "forks": { "nodes": [
19+
] } } } }
20+
`))
1721
http.StubResponse(200, bytes.NewBufferString(`
1822
{ "data": { "repository": { "pullRequests": { "nodes" : [
1923
] } } } }
@@ -34,7 +38,7 @@ func TestPRCreate(t *testing.T) {
3438
output, err := RunCommand(prCreateCmd, `pr create -t "my title" -b "my body"`)
3539
eq(t, err, nil)
3640

37-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
41+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
3842
reqBody := struct {
3943
Variables struct {
4044
Input struct {
@@ -61,6 +65,10 @@ func TestPRCreate_alreadyExists(t *testing.T) {
6165
initBlankContext("OWNER/REPO", "feature")
6266
http := initFakeHTTP()
6367
http.StubRepoResponse("OWNER", "REPO")
68+
http.StubResponse(200, bytes.NewBufferString(`
69+
{ "data": { "repository": { "forks": { "nodes": [
70+
] } } } }
71+
`))
6472
http.StubResponse(200, bytes.NewBufferString(`
6573
{ "data": { "repository": { "pullRequests": { "nodes": [
6674
{ "url": "https://github.com/OWNER/REPO/pull/123",
@@ -87,6 +95,10 @@ func TestPRCreate_web(t *testing.T) {
8795
initBlankContext("OWNER/REPO", "feature")
8896
http := initFakeHTTP()
8997
http.StubRepoResponse("OWNER", "REPO")
98+
http.StubResponse(200, bytes.NewBufferString(`
99+
{ "data": { "repository": { "forks": { "nodes": [
100+
] } } } }
101+
`))
90102

91103
cs, cmdTeardown := initCmdStubber()
92104
defer cmdTeardown()
@@ -113,6 +125,10 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) {
113125
http := initFakeHTTP()
114126

115127
http.StubRepoResponse("OWNER", "REPO")
128+
http.StubResponse(200, bytes.NewBufferString(`
129+
{ "data": { "repository": { "forks": { "nodes": [
130+
] } } } }
131+
`))
116132
http.StubResponse(200, bytes.NewBufferString(`
117133
{ "data": { "repository": { "pullRequests": { "nodes" : [
118134
] } } } }
@@ -232,6 +248,10 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) {
232248
initBlankContext("OWNER/REPO", "cool_bug-fixes")
233249
http := initFakeHTTP()
234250
http.StubRepoResponse("OWNER", "REPO")
251+
http.StubResponse(200, bytes.NewBufferString(`
252+
{ "data": { "repository": { "forks": { "nodes": [
253+
] } } } }
254+
`))
235255
http.StubResponse(200, bytes.NewBufferString(`
236256
{ "data": { "repository": { "pullRequests": { "nodes" : [
237257
] } } } }
@@ -273,7 +293,7 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) {
273293
output, err := RunCommand(prCreateCmd, `pr create`)
274294
eq(t, err, nil)
275295

276-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
296+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
277297
reqBody := struct {
278298
Variables struct {
279299
Input struct {
@@ -302,6 +322,10 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) {
302322
initBlankContext("OWNER/REPO", "feature")
303323
http := initFakeHTTP()
304324
http.StubRepoResponse("OWNER", "REPO")
325+
http.StubResponse(200, bytes.NewBufferString(`
326+
{ "data": { "repository": { "forks": { "nodes": [
327+
] } } } }
328+
`))
305329
http.StubResponse(200, bytes.NewBufferString(`
306330
{ "data": { "repository": { "pullRequests": { "nodes" : [
307331
] } } } }
@@ -344,7 +368,7 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) {
344368
output, err := RunCommand(prCreateCmd, `pr create`)
345369
eq(t, err, nil)
346370

347-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
371+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
348372
reqBody := struct {
349373
Variables struct {
350374
Input struct {
@@ -373,6 +397,10 @@ func TestPRCreate_survey_autofill(t *testing.T) {
373397
initBlankContext("OWNER/REPO", "feature")
374398
http := initFakeHTTP()
375399
http.StubRepoResponse("OWNER", "REPO")
400+
http.StubResponse(200, bytes.NewBufferString(`
401+
{ "data": { "repository": { "forks": { "nodes": [
402+
] } } } }
403+
`))
376404
http.StubResponse(200, bytes.NewBufferString(`
377405
{ "data": { "repository": { "pullRequests": { "nodes" : [
378406
] } } } }
@@ -396,7 +424,7 @@ func TestPRCreate_survey_autofill(t *testing.T) {
396424
output, err := RunCommand(prCreateCmd, `pr create -f`)
397425
eq(t, err, nil)
398426

399-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
427+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
400428
reqBody := struct {
401429
Variables struct {
402430
Input struct {
@@ -457,6 +485,10 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) {
457485
initBlankContext("OWNER/REPO", "feature")
458486
http := initFakeHTTP()
459487
http.StubRepoResponse("OWNER", "REPO")
488+
http.StubResponse(200, bytes.NewBufferString(`
489+
{ "data": { "repository": { "forks": { "nodes": [
490+
] } } } }
491+
`))
460492
http.StubResponse(200, bytes.NewBufferString(`
461493
{ "data": { "createPullRequest": { "pullRequest": {
462494
"URL": "https://github.com/OWNER/REPO/pull/12"

context/context.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (Re
5151
repos = append(repos, baseOverride)
5252
}
5353

54-
result := ResolvedRemotes{Remotes: remotes}
54+
result := ResolvedRemotes{
55+
Remotes: remotes,
56+
apiClient: client,
57+
}
5558
if hasBaseOverride {
5659
result.BaseOverride = baseOverride
5760
}
@@ -67,6 +70,7 @@ type ResolvedRemotes struct {
6770
BaseOverride ghrepo.Interface
6871
Remotes Remotes
6972
Network api.RepoNetworkResult
73+
apiClient *api.Client
7074
}
7175

7276
// BaseRepo is the first found repository in the "upstream", "github", "origin"
@@ -95,8 +99,30 @@ func (r ResolvedRemotes) BaseRepo() (*api.Repository, error) {
9599
return nil, errors.New("not found")
96100
}
97101

98-
// HeadRepo is the first found repository that has push access
102+
// HeadRepo is a fork of base repo (if any), or the first found repository that
103+
// has push access
99104
func (r ResolvedRemotes) HeadRepo() (*api.Repository, error) {
105+
baseRepo, err := r.BaseRepo()
106+
if err != nil {
107+
return nil, err
108+
}
109+
110+
// try to find a pushable fork among existing remotes
111+
for _, repo := range r.Network.Repositories {
112+
if repo != nil && repo.Parent != nil && repo.ViewerCanPush() && ghrepo.IsSame(repo.Parent, baseRepo) {
113+
return repo, nil
114+
}
115+
}
116+
117+
// a fork might still exist on GitHub, so let's query for it
118+
var notFound *api.NotFoundError
119+
if repo, err := api.RepoFindFork(r.apiClient, baseRepo); err == nil {
120+
return repo, nil
121+
} else if !errors.As(err, &notFound) {
122+
return nil, err
123+
}
124+
125+
// fall back to any listed repository that has push access
100126
for _, repo := range r.Network.Repositories {
101127
if repo != nil && repo.ViewerCanPush() {
102128
return repo, nil

context/remote_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package context
22

33
import (
4+
"bytes"
45
"errors"
56
"net/url"
67
"testing"
@@ -61,6 +62,14 @@ func Test_translateRemotes(t *testing.T) {
6162
}
6263

6364
func Test_resolvedRemotes_triangularSetup(t *testing.T) {
65+
http := &api.FakeHTTP{}
66+
apiClient := api.NewClient(api.ReplaceTripper(http))
67+
68+
http.StubResponse(200, bytes.NewBufferString(`
69+
{ "data": { "repository": { "forks": { "nodes": [
70+
] } } } }
71+
`))
72+
6473
resolved := ResolvedRemotes{
6574
BaseOverride: nil,
6675
Remotes: Remotes{
@@ -89,6 +98,7 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) {
8998
},
9099
},
91100
},
101+
apiClient: apiClient,
92102
}
93103

94104
baseRepo, err := resolved.BaseRepo()
@@ -118,6 +128,52 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) {
118128
}
119129
}
120130

131+
func Test_resolvedRemotes_forkLookup(t *testing.T) {
132+
http := &api.FakeHTTP{}
133+
apiClient := api.NewClient(api.ReplaceTripper(http))
134+
135+
http.StubResponse(200, bytes.NewBufferString(`
136+
{ "data": { "repository": { "forks": { "nodes": [
137+
{ "id": "FORKID",
138+
"url": "https://github.com/FORKOWNER/REPO",
139+
"name": "REPO",
140+
"owner": { "login": "FORKOWNER" }
141+
}
142+
] } } } }
143+
`))
144+
145+
resolved := ResolvedRemotes{
146+
BaseOverride: nil,
147+
Remotes: Remotes{
148+
&Remote{
149+
Remote: &git.Remote{Name: "origin"},
150+
Owner: "OWNER",
151+
Repo: "REPO",
152+
},
153+
},
154+
Network: api.RepoNetworkResult{
155+
Repositories: []*api.Repository{
156+
&api.Repository{
157+
Name: "NEWNAME",
158+
Owner: api.RepositoryOwner{Login: "NEWOWNER"},
159+
ViewerPermission: "READ",
160+
},
161+
},
162+
},
163+
apiClient: apiClient,
164+
}
165+
166+
headRepo, err := resolved.HeadRepo()
167+
if err != nil {
168+
t.Fatalf("got %v", err)
169+
}
170+
eq(t, ghrepo.FullName(headRepo), "FORKOWNER/REPO")
171+
_, err = resolved.RemoteForRepo(headRepo)
172+
if err == nil {
173+
t.Fatal("expected to not find a matching remote")
174+
}
175+
}
176+
121177
func Test_resolvedRemotes_clonedFork(t *testing.T) {
122178
resolved := ResolvedRemotes{
123179
BaseOverride: nil,

0 commit comments

Comments
 (0)