Skip to content

Commit 1fb0eef

Browse files
authored
Merge pull request cli#680 from cli/pr-create-push-default
Creating a PR now always prioritizes an existing fork as a push target
2 parents 0432419 + da2116f commit 1fb0eef

File tree

7 files changed

+193
-33
lines changed

7 files changed

+193
-33
lines changed

api/queries_repo.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/base64"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"sort"
910
"strings"
@@ -224,6 +225,49 @@ func ForkRepo(client *Client, repo ghrepo.Interface) (*Repository, error) {
224225
}, nil
225226
}
226227

228+
// RepoFindFork finds a fork of repo affiliated with the viewer
229+
func RepoFindFork(client *Client, repo ghrepo.Interface) (*Repository, error) {
230+
result := struct {
231+
Repository struct {
232+
Forks struct {
233+
Nodes []Repository
234+
}
235+
}
236+
}{}
237+
238+
variables := map[string]interface{}{
239+
"owner": repo.RepoOwner(),
240+
"repo": repo.RepoName(),
241+
}
242+
243+
if err := client.GraphQL(`
244+
query($owner: String!, $repo: String!) {
245+
repository(owner: $owner, name: $repo) {
246+
forks(first: 1, affiliations: [OWNER, COLLABORATOR]) {
247+
nodes {
248+
id
249+
name
250+
owner { login }
251+
url
252+
viewerPermission
253+
}
254+
}
255+
}
256+
}
257+
`, variables, &result); err != nil {
258+
return nil, err
259+
}
260+
261+
forks := result.Repository.Forks.Nodes
262+
// we check ViewerCanPush, even though we expect it to always be true per
263+
// `affiliations` condition, to guard against versions of GitHub with a
264+
// faulty `affiliations` implementation
265+
if len(forks) > 0 && forks[0].ViewerCanPush() {
266+
return &forks[0], nil
267+
}
268+
return nil, &NotFoundError{errors.New("no fork found")}
269+
}
270+
227271
// RepoCreateInput represents input parameters for RepoCreate
228272
type RepoCreateInput struct {
229273
Name string `json:"name"`

command/pr_create.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,11 +225,25 @@ func prCreate(cmd *cobra.Command, _ []string) error {
225225
return fmt.Errorf("error forking repo: %w", err)
226226
}
227227
didForkRepo = true
228+
}
229+
230+
headBranchLabel := headBranch
231+
if !ghrepo.IsSame(baseRepo, headRepo) {
232+
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
233+
}
234+
235+
// There are two cases when an existing remote for the head repo will be
236+
// missing:
237+
// 1. the head repo was just created by auto-forking;
238+
// 2. an existing fork was discovered by quering the API.
239+
//
240+
// In either case, we want to add the head repo as a new git remote so we
241+
// can push to it.
242+
if err != nil {
228243
// TODO: support non-HTTPS git remote URLs
229-
baseRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(baseRepo))
230244
headRepoURL := fmt.Sprintf("https://github.com/%s.git", ghrepo.FullName(headRepo))
231-
// TODO: figure out what to name the new git remote
232-
gitRemote, err := git.AddRemote("fork", baseRepoURL, headRepoURL)
245+
// TODO: prevent clashes with another remote of a same name
246+
gitRemote, err := git.AddRemote("fork", headRepoURL)
233247
if err != nil {
234248
return fmt.Errorf("error adding remote: %w", err)
235249
}
@@ -240,11 +254,6 @@ func prCreate(cmd *cobra.Command, _ []string) error {
240254
}
241255
}
242256

243-
headBranchLabel := headBranch
244-
if !ghrepo.IsSame(baseRepo, headRepo) {
245-
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
246-
}
247-
248257
// automatically push the branch if it hasn't been pushed anywhere yet
249258
if headBranchPushedTo == nil {
250259
if headRemote == nil {

command/pr_create_test.go

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ func TestPRCreate(t *testing.T) {
1515
initBlankContext("OWNER/REPO", "feature")
1616
http := initFakeHTTP()
1717
http.StubRepoResponse("OWNER", "REPO")
18+
http.StubResponse(200, bytes.NewBufferString(`
19+
{ "data": { "repository": { "forks": { "nodes": [
20+
] } } } }
21+
`))
1822
http.StubResponse(200, bytes.NewBufferString(`
1923
{ "data": { "repository": { "pullRequests": { "nodes" : [
2024
] } } } }
@@ -37,7 +41,7 @@ func TestPRCreate(t *testing.T) {
3741
output, err := RunCommand(prCreateCmd, `pr create -t "my title" -b "my body"`)
3842
eq(t, err, nil)
3943

40-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
44+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
4145
reqBody := struct {
4246
Variables struct {
4347
Input struct {
@@ -64,6 +68,10 @@ func TestPRCreate_alreadyExists(t *testing.T) {
6468
initBlankContext("OWNER/REPO", "feature")
6569
http := initFakeHTTP()
6670
http.StubRepoResponse("OWNER", "REPO")
71+
http.StubResponse(200, bytes.NewBufferString(`
72+
{ "data": { "repository": { "forks": { "nodes": [
73+
] } } } }
74+
`))
6775
http.StubResponse(200, bytes.NewBufferString(`
6876
{ "data": { "repository": { "pullRequests": { "nodes": [
6977
{ "url": "https://github.com/OWNER/REPO/pull/123",
@@ -93,6 +101,10 @@ func TestPRCreate_alreadyExistsDifferentBase(t *testing.T) {
93101
initBlankContext("OWNER/REPO", "feature")
94102
http := initFakeHTTP()
95103
http.StubRepoResponse("OWNER", "REPO")
104+
http.StubResponse(200, bytes.NewBufferString(`
105+
{ "data": { "repository": { "forks": { "nodes": [
106+
] } } } }
107+
`))
96108
http.StubResponse(200, bytes.NewBufferString(`
97109
{ "data": { "repository": { "pullRequests": { "nodes": [
98110
{ "url": "https://github.com/OWNER/REPO/pull/123",
@@ -121,6 +133,10 @@ func TestPRCreate_web(t *testing.T) {
121133
initBlankContext("OWNER/REPO", "feature")
122134
http := initFakeHTTP()
123135
http.StubRepoResponse("OWNER", "REPO")
136+
http.StubResponse(200, bytes.NewBufferString(`
137+
{ "data": { "repository": { "forks": { "nodes": [
138+
] } } } }
139+
`))
124140

125141
cs, cmdTeardown := initCmdStubber()
126142
defer cmdTeardown()
@@ -149,6 +165,10 @@ func TestPRCreate_ReportsUncommittedChanges(t *testing.T) {
149165
http := initFakeHTTP()
150166

151167
http.StubRepoResponse("OWNER", "REPO")
168+
http.StubResponse(200, bytes.NewBufferString(`
169+
{ "data": { "repository": { "forks": { "nodes": [
170+
] } } } }
171+
`))
152172
http.StubResponse(200, bytes.NewBufferString(`
153173
{ "data": { "repository": { "pullRequests": { "nodes" : [
154174
] } } } }
@@ -272,6 +292,10 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) {
272292
initBlankContext("OWNER/REPO", "cool_bug-fixes")
273293
http := initFakeHTTP()
274294
http.StubRepoResponse("OWNER", "REPO")
295+
http.StubResponse(200, bytes.NewBufferString(`
296+
{ "data": { "repository": { "forks": { "nodes": [
297+
] } } } }
298+
`))
275299
http.StubResponse(200, bytes.NewBufferString(`
276300
{ "data": { "repository": { "pullRequests": { "nodes" : [
277301
] } } } }
@@ -315,7 +339,7 @@ func TestPRCreate_survey_defaults_multicommit(t *testing.T) {
315339
output, err := RunCommand(prCreateCmd, `pr create`)
316340
eq(t, err, nil)
317341

318-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
342+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
319343
reqBody := struct {
320344
Variables struct {
321345
Input struct {
@@ -344,6 +368,10 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) {
344368
initBlankContext("OWNER/REPO", "feature")
345369
http := initFakeHTTP()
346370
http.StubRepoResponse("OWNER", "REPO")
371+
http.StubResponse(200, bytes.NewBufferString(`
372+
{ "data": { "repository": { "forks": { "nodes": [
373+
] } } } }
374+
`))
347375
http.StubResponse(200, bytes.NewBufferString(`
348376
{ "data": { "repository": { "pullRequests": { "nodes" : [
349377
] } } } }
@@ -388,7 +416,7 @@ func TestPRCreate_survey_defaults_monocommit(t *testing.T) {
388416
output, err := RunCommand(prCreateCmd, `pr create`)
389417
eq(t, err, nil)
390418

391-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
419+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
392420
reqBody := struct {
393421
Variables struct {
394422
Input struct {
@@ -417,6 +445,10 @@ func TestPRCreate_survey_autofill(t *testing.T) {
417445
initBlankContext("OWNER/REPO", "feature")
418446
http := initFakeHTTP()
419447
http.StubRepoResponse("OWNER", "REPO")
448+
http.StubResponse(200, bytes.NewBufferString(`
449+
{ "data": { "repository": { "forks": { "nodes": [
450+
] } } } }
451+
`))
420452
http.StubResponse(200, bytes.NewBufferString(`
421453
{ "data": { "repository": { "pullRequests": { "nodes" : [
422454
] } } } }
@@ -442,7 +474,7 @@ func TestPRCreate_survey_autofill(t *testing.T) {
442474
output, err := RunCommand(prCreateCmd, `pr create -f`)
443475
eq(t, err, nil)
444476

445-
bodyBytes, _ := ioutil.ReadAll(http.Requests[2].Body)
477+
bodyBytes, _ := ioutil.ReadAll(http.Requests[3].Body)
446478
reqBody := struct {
447479
Variables struct {
448480
Input struct {
@@ -507,6 +539,10 @@ func TestPRCreate_defaults_error_interactive(t *testing.T) {
507539
initBlankContext("OWNER/REPO", "feature")
508540
http := initFakeHTTP()
509541
http.StubRepoResponse("OWNER", "REPO")
542+
http.StubResponse(200, bytes.NewBufferString(`
543+
{ "data": { "repository": { "forks": { "nodes": [
544+
] } } } }
545+
`))
510546
http.StubResponse(200, bytes.NewBufferString(`
511547
{ "data": { "createPullRequest": { "pullRequest": {
512548
"URL": "https://github.com/OWNER/REPO/pull/12"

command/repo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ func repoFork(cmd *cobra.Command, args []string) error {
347347
}
348348
}
349349
if remoteDesired {
350-
_, err := git.AddRemote("fork", forkedRepo.CloneURL, "")
350+
_, err := git.AddRemote("fork", forkedRepo.CloneURL)
351351
if err != nil {
352352
return fmt.Errorf("failed to add remote: %w", err)
353353
}

context/context.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ func ResolveRemotesToRepos(remotes Remotes, client *api.Client, base string) (Re
5151
repos = append(repos, baseOverride)
5252
}
5353

54-
result := ResolvedRemotes{Remotes: remotes}
54+
result := ResolvedRemotes{
55+
Remotes: remotes,
56+
apiClient: client,
57+
}
5558
if hasBaseOverride {
5659
result.BaseOverride = baseOverride
5760
}
@@ -67,6 +70,7 @@ type ResolvedRemotes struct {
6770
BaseOverride ghrepo.Interface
6871
Remotes Remotes
6972
Network api.RepoNetworkResult
73+
apiClient *api.Client
7074
}
7175

7276
// BaseRepo is the first found repository in the "upstream", "github", "origin"
@@ -95,8 +99,30 @@ func (r ResolvedRemotes) BaseRepo() (*api.Repository, error) {
9599
return nil, errors.New("not found")
96100
}
97101

98-
// HeadRepo is the first found repository that has push access
102+
// HeadRepo is a fork of base repo (if any), or the first found repository that
103+
// has push access
99104
func (r ResolvedRemotes) HeadRepo() (*api.Repository, error) {
105+
baseRepo, err := r.BaseRepo()
106+
if err != nil {
107+
return nil, err
108+
}
109+
110+
// try to find a pushable fork among existing remotes
111+
for _, repo := range r.Network.Repositories {
112+
if repo != nil && repo.Parent != nil && repo.ViewerCanPush() && ghrepo.IsSame(repo.Parent, baseRepo) {
113+
return repo, nil
114+
}
115+
}
116+
117+
// a fork might still exist on GitHub, so let's query for it
118+
var notFound *api.NotFoundError
119+
if repo, err := api.RepoFindFork(r.apiClient, baseRepo); err == nil {
120+
return repo, nil
121+
} else if !errors.As(err, &notFound) {
122+
return nil, err
123+
}
124+
125+
// fall back to any listed repository that has push access
100126
for _, repo := range r.Network.Repositories {
101127
if repo != nil && repo.ViewerCanPush() {
102128
return repo, nil

context/remote_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package context
22

33
import (
4+
"bytes"
45
"errors"
56
"net/url"
67
"testing"
@@ -61,6 +62,14 @@ func Test_translateRemotes(t *testing.T) {
6162
}
6263

6364
func Test_resolvedRemotes_triangularSetup(t *testing.T) {
65+
http := &api.FakeHTTP{}
66+
apiClient := api.NewClient(api.ReplaceTripper(http))
67+
68+
http.StubResponse(200, bytes.NewBufferString(`
69+
{ "data": { "repository": { "forks": { "nodes": [
70+
] } } } }
71+
`))
72+
6473
resolved := ResolvedRemotes{
6574
BaseOverride: nil,
6675
Remotes: Remotes{
@@ -89,6 +98,7 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) {
8998
},
9099
},
91100
},
101+
apiClient: apiClient,
92102
}
93103

94104
baseRepo, err := resolved.BaseRepo()
@@ -118,6 +128,53 @@ func Test_resolvedRemotes_triangularSetup(t *testing.T) {
118128
}
119129
}
120130

131+
func Test_resolvedRemotes_forkLookup(t *testing.T) {
132+
http := &api.FakeHTTP{}
133+
apiClient := api.NewClient(api.ReplaceTripper(http))
134+
135+
http.StubResponse(200, bytes.NewBufferString(`
136+
{ "data": { "repository": { "forks": { "nodes": [
137+
{ "id": "FORKID",
138+
"url": "https://github.com/FORKOWNER/REPO",
139+
"name": "REPO",
140+
"owner": { "login": "FORKOWNER" },
141+
"viewerPermission": "WRITE"
142+
}
143+
] } } } }
144+
`))
145+
146+
resolved := ResolvedRemotes{
147+
BaseOverride: nil,
148+
Remotes: Remotes{
149+
&Remote{
150+
Remote: &git.Remote{Name: "origin"},
151+
Owner: "OWNER",
152+
Repo: "REPO",
153+
},
154+
},
155+
Network: api.RepoNetworkResult{
156+
Repositories: []*api.Repository{
157+
&api.Repository{
158+
Name: "NEWNAME",
159+
Owner: api.RepositoryOwner{Login: "NEWOWNER"},
160+
ViewerPermission: "READ",
161+
},
162+
},
163+
},
164+
apiClient: apiClient,
165+
}
166+
167+
headRepo, err := resolved.HeadRepo()
168+
if err != nil {
169+
t.Fatalf("got %v", err)
170+
}
171+
eq(t, ghrepo.FullName(headRepo), "FORKOWNER/REPO")
172+
_, err = resolved.RemoteForRepo(headRepo)
173+
if err == nil {
174+
t.Fatal("expected to not find a matching remote")
175+
}
176+
}
177+
121178
func Test_resolvedRemotes_clonedFork(t *testing.T) {
122179
resolved := ResolvedRemotes{
123180
BaseOverride: nil,

0 commit comments

Comments
 (0)