Skip to content

Commit 0432419

Browse files
authored
Merge pull request cli#704 from cli/branch-push-detection
Avoid auto-forking/pushing an already pushed branch in `pr create`
2 parents c7b0abd + cba8331 commit 0432419

File tree

5 files changed

+262
-37
lines changed

5 files changed

+262
-37
lines changed

command/pr_checkout.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
6666
cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec})
6767

6868
// local branch already exists
69-
if git.VerifyRef("refs/heads/" + newBranchName) {
69+
if _, err := git.ShowRefs("refs/heads/" + newBranchName); err == nil {
7070
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
7171
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
7272
} else {

command/pr_checkout_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestPRCheckout_sameRepo(t *testing.T) {
4646
ranCommands := [][]string{}
4747
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
4848
switch strings.Join(cmd.Args, " ") {
49-
case "git show-ref --verify --quiet refs/heads/feature":
49+
case "git show-ref --verify -- refs/heads/feature":
5050
return &errorStub{"exit status: 1"}
5151
default:
5252
ranCommands = append(ranCommands, cmd.Args)
@@ -98,7 +98,7 @@ func TestPRCheckout_urlArg(t *testing.T) {
9898
ranCommands := [][]string{}
9999
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
100100
switch strings.Join(cmd.Args, " ") {
101-
case "git show-ref --verify --quiet refs/heads/feature":
101+
case "git show-ref --verify -- refs/heads/feature":
102102
return &errorStub{"exit status: 1"}
103103
default:
104104
ranCommands = append(ranCommands, cmd.Args)
@@ -147,7 +147,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
147147
ranCommands := [][]string{}
148148
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
149149
switch strings.Join(cmd.Args, " ") {
150-
case "git show-ref --verify --quiet refs/heads/feature":
150+
case "git show-ref --verify -- refs/heads/feature":
151151
return &errorStub{"exit status: 1"}
152152
default:
153153
ranCommands = append(ranCommands, cmd.Args)
@@ -210,7 +210,7 @@ func TestPRCheckout_branchArg(t *testing.T) {
210210
ranCommands := [][]string{}
211211
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
212212
switch strings.Join(cmd.Args, " ") {
213-
case "git show-ref --verify --quiet refs/heads/feature":
213+
case "git show-ref --verify -- refs/heads/feature":
214214
return &errorStub{"exit status: 1"}
215215
default:
216216
ranCommands = append(ranCommands, cmd.Args)
@@ -260,7 +260,7 @@ func TestPRCheckout_existingBranch(t *testing.T) {
260260
ranCommands := [][]string{}
261261
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
262262
switch strings.Join(cmd.Args, " ") {
263-
case "git show-ref --verify --quiet refs/heads/feature":
263+
case "git show-ref --verify -- refs/heads/feature":
264264
return &test.OutputStub{}
265265
default:
266266
ranCommands = append(ranCommands, cmd.Args)
@@ -313,7 +313,7 @@ func TestPRCheckout_differentRepo_remoteExists(t *testing.T) {
313313
ranCommands := [][]string{}
314314
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
315315
switch strings.Join(cmd.Args, " ") {
316-
case "git show-ref --verify --quiet refs/heads/feature":
316+
case "git show-ref --verify -- refs/heads/feature":
317317
return &errorStub{"exit status: 1"}
318318
default:
319319
ranCommands = append(ranCommands, cmd.Args)

command/pr_create.go

Lines changed: 87 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"errors"
55
"fmt"
66
"net/url"
7+
"strings"
78
"time"
89

910
"github.com/cli/cli/api"
@@ -75,7 +76,27 @@ func prCreate(cmd *cobra.Command, _ []string) error {
7576
if err != nil {
7677
return fmt.Errorf("could not determine the current branch: %w", err)
7778
}
78-
headRepo, headRepoErr := repoContext.HeadRepo()
79+
80+
var headRepo ghrepo.Interface
81+
var headRemote *context.Remote
82+
83+
// determine whether the head branch is already pushed to a remote
84+
headBranchPushedTo := determineTrackingBranch(remotes, headBranch)
85+
if headBranchPushedTo != nil {
86+
for _, r := range remotes {
87+
if r.Name != headBranchPushedTo.RemoteName {
88+
continue
89+
}
90+
headRepo = r
91+
headRemote = r
92+
break
93+
}
94+
}
95+
96+
// otherwise, determine the head repository with info obtained from the API
97+
if headRepo == nil {
98+
headRepo, _ = repoContext.HeadRepo()
99+
}
79100

80101
baseBranch, err := cmd.Flags().GetString("base")
81102
if err != nil {
@@ -193,8 +214,9 @@ func prCreate(cmd *cobra.Command, _ []string) error {
193214
}
194215

195216
didForkRepo := false
196-
var headRemote *context.Remote
197-
if headRepoErr != nil {
217+
// if a head repository could not be determined so far, automatically create
218+
// one by forking the base repository
219+
if headRepo == nil {
198220
if baseRepo.IsPrivate {
199221
return fmt.Errorf("cannot fork private repository '%s'", ghrepo.FullName(baseRepo))
200222
}
@@ -223,29 +245,31 @@ func prCreate(cmd *cobra.Command, _ []string) error {
223245
headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch)
224246
}
225247

226-
if headRemote == nil {
227-
headRemote, err = repoContext.RemoteForRepo(headRepo)
228-
if err != nil {
229-
return fmt.Errorf("git remote not found for head repository: %w", err)
248+
// automatically push the branch if it hasn't been pushed anywhere yet
249+
if headBranchPushedTo == nil {
250+
if headRemote == nil {
251+
headRemote, err = repoContext.RemoteForRepo(headRepo)
252+
if err != nil {
253+
return fmt.Errorf("git remote not found for head repository: %w", err)
254+
}
230255
}
231-
}
232256

233-
pushTries := 0
234-
maxPushTries := 3
235-
for {
236-
// TODO: respect existing upstream configuration of the current branch
237-
if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil {
238-
if didForkRepo && pushTries < maxPushTries {
239-
pushTries++
240-
// first wait 2 seconds after forking, then 4s, then 6s
241-
waitSeconds := 2 * pushTries
242-
fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second"))
243-
time.Sleep(time.Duration(waitSeconds) * time.Second)
244-
continue
257+
pushTries := 0
258+
maxPushTries := 3
259+
for {
260+
if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil {
261+
if didForkRepo && pushTries < maxPushTries {
262+
pushTries++
263+
// first wait 2 seconds after forking, then 4s, then 6s
264+
waitSeconds := 2 * pushTries
265+
fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second"))
266+
time.Sleep(time.Duration(waitSeconds) * time.Second)
267+
continue
268+
}
269+
return err
245270
}
246-
return err
271+
break
247272
}
248-
break
249273
}
250274

251275
if action == SubmitAction {
@@ -275,6 +299,47 @@ func prCreate(cmd *cobra.Command, _ []string) error {
275299
return nil
276300
}
277301

302+
func determineTrackingBranch(remotes context.Remotes, headBranch string) *git.TrackingRef {
303+
refsForLookup := []string{"HEAD"}
304+
var trackingRefs []git.TrackingRef
305+
306+
headBranchConfig := git.ReadBranchConfig(headBranch)
307+
if headBranchConfig.RemoteName != "" {
308+
tr := git.TrackingRef{
309+
RemoteName: headBranchConfig.RemoteName,
310+
BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"),
311+
}
312+
trackingRefs = append(trackingRefs, tr)
313+
refsForLookup = append(refsForLookup, tr.String())
314+
}
315+
316+
for _, remote := range remotes {
317+
tr := git.TrackingRef{
318+
RemoteName: remote.Name,
319+
BranchName: headBranch,
320+
}
321+
trackingRefs = append(trackingRefs, tr)
322+
refsForLookup = append(refsForLookup, tr.String())
323+
}
324+
325+
resolvedRefs, _ := git.ShowRefs(refsForLookup...)
326+
if len(resolvedRefs) > 1 {
327+
for _, r := range resolvedRefs[1:] {
328+
if r.Hash != resolvedRefs[0].Hash {
329+
continue
330+
}
331+
for _, tr := range trackingRefs {
332+
if tr.String() != r.Name {
333+
continue
334+
}
335+
return &tr
336+
}
337+
}
338+
}
339+
340+
return nil
341+
}
342+
278343
func generateCompareURL(r ghrepo.Interface, base, head, title, body string) string {
279344
u := fmt.Sprintf(
280345
"https://github.com/%s/compare/%s...%s?expand=1",

0 commit comments

Comments
 (0)