Skip to content

Commit 4c75c8b

Browse files
committed
Reset base branch when URL is used
1 parent de6b1e0 commit 4c75c8b

File tree

7 files changed

+35
-74
lines changed

7 files changed

+35
-74
lines changed

api/queries_pr.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ func (pr *PullRequest) ChecksStatus() (summary PullRequestChecksStatus) {
205205
return
206206
}
207207

208-
func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, pr *PullRequest) (string, error) {
208+
func (c Client) PullRequestDiff(baseRepo ghrepo.Interface, prNumber int) (string, error) {
209209
url := fmt.Sprintf("https://api.github.com/repos/%s/pulls/%d",
210-
ghrepo.FullName(baseRepo), pr.Number)
210+
ghrepo.FullName(baseRepo), prNumber)
211211
req, err := http.NewRequest("GET", url, nil)
212212
if err != nil {
213213
return "", err

command/pr.go

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -305,17 +305,12 @@ func prView(cmd *cobra.Command, args []string) error {
305305
return err
306306
}
307307

308-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
309-
if err != nil {
310-
return err
311-
}
312-
313308
web, err := cmd.Flags().GetBool("web")
314309
if err != nil {
315310
return err
316311
}
317312

318-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
313+
pr, _, err := prFromArgs(ctx, apiClient, cmd, args)
319314
if err != nil {
320315
return err
321316
}
@@ -337,12 +332,7 @@ func prClose(cmd *cobra.Command, args []string) error {
337332
return err
338333
}
339334

340-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
341-
if err != nil {
342-
return err
343-
}
344-
345-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
335+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
346336
if err != nil {
347337
return err
348338
}
@@ -372,12 +362,7 @@ func prReopen(cmd *cobra.Command, args []string) error {
372362
return err
373363
}
374364

375-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
376-
if err != nil {
377-
return err
378-
}
379-
380-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
365+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
381366
if err != nil {
382367
return err
383368
}
@@ -409,12 +394,7 @@ func prMerge(cmd *cobra.Command, args []string) error {
409394
return err
410395
}
411396

412-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
413-
if err != nil {
414-
return err
415-
}
416-
417-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
397+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
418398
if err != nil {
419399
return err
420400
}
@@ -652,12 +632,7 @@ func prReady(cmd *cobra.Command, args []string) error {
652632
return err
653633
}
654634

655-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
656-
if err != nil {
657-
return err
658-
}
659-
660-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
635+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
661636
if err != nil {
662637
return err
663638
}

command/pr_checkout.go

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"fmt"
66
"os"
77
"os/exec"
8-
"regexp"
98

109
"github.com/spf13/cobra"
1110

@@ -27,22 +26,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
2726
return err
2827
}
2928

30-
var baseRepo ghrepo.Interface
31-
prString := args[0]
32-
r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`)
33-
if m := r.FindStringSubmatch(prString); m != nil {
34-
prString = m[3]
35-
baseRepo = ghrepo.New(m[1], m[2])
36-
}
37-
38-
if baseRepo == nil {
39-
baseRepo, err = determineBaseRepo(apiClient, cmd, ctx)
40-
if err != nil {
41-
return err
42-
}
43-
}
44-
45-
pr, err := prFromArgs(ctx, apiClient, baseRepo, []string{prString})
29+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
4630
if err != nil {
4731
return err
4832
}
@@ -61,6 +45,7 @@ func prCheckout(cmd *cobra.Command, args []string) error {
6145

6246
var cmdQueue [][]string
6347
newBranchName := pr.HeadRefName
48+
6449
if headRemote != nil {
6550
// there is an existing git remote for PR head
6651
remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName)

command/pr_checkout_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ func TestPRCheckout_urlArg(t *testing.T) {
7676
return ctx
7777
}
7878
http := initFakeHTTP()
79+
http.StubRepoResponse("hubot", "REPO")
7980

8081
http.StubResponse(200, bytes.NewBufferString(`
8182
{ "data": { "repository": { "pullRequest": {
@@ -125,7 +126,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
125126
return ctx
126127
}
127128
http := initFakeHTTP()
128-
129+
http.StubRepoResponse("OWNER", "REPO")
129130
http.StubResponse(200, bytes.NewBufferString(`
130131
{ "data": { "repository": { "pullRequest": {
131132
"number": 123,
@@ -160,7 +161,7 @@ func TestPRCheckout_urlArg_differentBase(t *testing.T) {
160161
eq(t, err, nil)
161162
eq(t, output.String(), "")
162163

163-
bodyBytes, _ := ioutil.ReadAll(http.Requests[0].Body)
164+
bodyBytes, _ := ioutil.ReadAll(http.Requests[1].Body)
164165
reqBody := struct {
165166
Variables struct {
166167
Owner string

command/pr_diff.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,12 @@ func prDiff(cmd *cobra.Command, args []string) error {
3636
return err
3737
}
3838

39-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
40-
if err != nil {
41-
return fmt.Errorf("could not determine base repo: %w", err)
42-
}
43-
44-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
39+
pr, baseRepo, err := prFromArgs(ctx, apiClient, cmd, args)
4540
if err != nil {
4641
return fmt.Errorf("could not find pull request: %w", err)
4742
}
4843

49-
diff, err := apiClient.PullRequestDiff(baseRepo, pr)
44+
diff, err := apiClient.PullRequestDiff(baseRepo, pr.Number)
5045
if err != nil {
5146
return fmt.Errorf("could not find pull request diff: %w", err)
5247
}

command/pr_lookup.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,36 @@ import (
1010
"github.com/cli/cli/context"
1111
"github.com/cli/cli/git"
1212
"github.com/cli/cli/internal/ghrepo"
13+
"github.com/spf13/cobra"
1314
)
1415

15-
func prFromArgs(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, args []string) (*api.PullRequest, error) {
16+
func prFromArgs(ctx context.Context, apiClient *api.Client, cmd *cobra.Command, args []string) (*api.PullRequest, ghrepo.Interface, error) {
17+
repo, err := determineBaseRepo(apiClient, cmd, ctx)
18+
if err != nil {
19+
return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
20+
}
21+
1622
if len(args) == 0 {
17-
return prForCurrentBranch(ctx, apiClient, repo)
23+
pr, err := prForCurrentBranch(ctx, apiClient, repo)
24+
return pr, repo, err
1825
}
1926

20-
// First check to see if the prString is a url
27+
// First check to see if the prString is a url, return repo from url if found
2128
prString := args[0]
22-
pr, err := prFromURL(ctx, apiClient, repo, prString)
29+
pr, r, err := prFromURL(ctx, apiClient, prString)
2330
if pr != nil || err != nil {
24-
return pr, err
31+
return pr, r, err
2532
}
2633

2734
// Next see if the prString is a number and use that to look up the url
2835
pr, err = prFromNumberString(ctx, apiClient, repo, prString)
2936
if pr != nil || err != nil {
30-
return pr, err
37+
return pr, repo, err
3138
}
3239

3340
// Last see if it is a branch name
34-
return api.PullRequestForBranch(apiClient, repo, "", prString)
41+
pr, err = api.PullRequestForBranch(apiClient, repo, "", prString)
42+
return pr, repo, err
3543
}
3644

3745
func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
@@ -42,14 +50,16 @@ func prFromNumberString(ctx context.Context, apiClient *api.Client, repo ghrepo.
4250
return nil, nil
4351
}
4452

45-
func prFromURL(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface, s string) (*api.PullRequest, error) {
53+
func prFromURL(ctx context.Context, apiClient *api.Client, s string) (*api.PullRequest, ghrepo.Interface, error) {
4654
r := regexp.MustCompile(`^https://github\.com/([^/]+)/([^/]+)/pull/(\d+)`)
4755
if m := r.FindStringSubmatch(s); m != nil {
56+
repo := ghrepo.New(m[1], m[2])
4857
prNumberString := m[3]
49-
return prFromNumberString(ctx, apiClient, repo, prNumberString)
58+
pr, err := prFromNumberString(ctx, apiClient, repo, prNumberString)
59+
return pr, repo, err
5060
}
5161

52-
return nil, nil
62+
return nil, nil, nil
5363
}
5464

5565
func prForCurrentBranch(ctx context.Context, apiClient *api.Client, repo ghrepo.Interface) (*api.PullRequest, error) {

command/pr_review.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,7 @@ func prReview(cmd *cobra.Command, args []string) error {
8989
return err
9090
}
9191

92-
baseRepo, err := determineBaseRepo(apiClient, cmd, ctx)
93-
if err != nil {
94-
return fmt.Errorf("could not determine base repo: %w", err)
95-
}
96-
97-
pr, err := prFromArgs(ctx, apiClient, baseRepo, args)
92+
pr, _, err := prFromArgs(ctx, apiClient, cmd, args)
9893
if err != nil {
9994
return err
10095
}

0 commit comments

Comments
 (0)