Skip to content

Commit ad12087

Browse files
authored
Merge pull request cli#96 from github/pr-current-branch
Improve detecting PR for the current branch
2 parents 9a3b032 + 3e06cff commit ad12087

File tree

7 files changed

+259
-113
lines changed

7 files changed

+259
-113
lines changed

api/queries_pr.go

Lines changed: 74 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package api
22

33
import (
44
"fmt"
5+
"strings"
56
)
67

78
type PullRequestsPayload struct {
@@ -116,7 +117,7 @@ type Repo interface {
116117
RepoOwner() string
117118
}
118119

119-
func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername string) (*PullRequestsPayload, error) {
120+
func PullRequests(client *Client, ghRepo Repo, currentPRNumber int, currentPRHeadRef, currentUsername string) (*PullRequestsPayload, error) {
120121
type edges struct {
121122
Edges []struct {
122123
Node PullRequest
@@ -130,18 +131,18 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st
130131
type response struct {
131132
Repository struct {
132133
PullRequests edges
134+
PullRequest *PullRequest
133135
}
134136
ViewerCreated edges
135137
ReviewRequested edges
136138
}
137139

138-
query := `
140+
fragments := `
139141
fragment pr on PullRequest {
140142
number
141143
title
142144
url
143145
headRefName
144-
headRefName
145146
headRepositoryOwner {
146147
login
147148
}
@@ -170,16 +171,32 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st
170171
...pr
171172
reviewDecision
172173
}
173-
query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
174-
repository(owner: $owner, name: $repo) {
175-
pullRequests(headRefName: $headRefName, states: OPEN, first: 1) {
176-
edges {
177-
node {
178-
...prWithReviews
179-
}
180-
}
181-
}
182-
}
174+
`
175+
176+
queryPrefix := `
177+
query($owner: String!, $repo: String!, $headRefName: String!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
178+
repository(owner: $owner, name: $repo) {
179+
pullRequests(headRefName: $headRefName, states: OPEN, first: $per_page) {
180+
edges {
181+
node {
182+
...prWithReviews
183+
}
184+
}
185+
}
186+
}
187+
`
188+
if currentPRNumber > 0 {
189+
queryPrefix = `
190+
query($owner: String!, $repo: String!, $number: Int!, $viewerQuery: String!, $reviewerQuery: String!, $per_page: Int = 10) {
191+
repository(owner: $owner, name: $repo) {
192+
pullRequest(number: $number) {
193+
...prWithReviews
194+
}
195+
}
196+
`
197+
}
198+
199+
query := fragments + queryPrefix + `
183200
viewerCreated: search(query: $viewerQuery, type: ISSUE, first: $per_page) {
184201
edges {
185202
node {
@@ -201,20 +218,26 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st
201218
}
202219
}
203220
}
204-
`
221+
`
205222

206223
owner := ghRepo.RepoOwner()
207224
repo := ghRepo.RepoName()
208225

209226
viewerQuery := fmt.Sprintf("repo:%s/%s state:open is:pr author:%s", owner, repo, currentUsername)
210227
reviewerQuery := fmt.Sprintf("repo:%s/%s state:open review-requested:%s", owner, repo, currentUsername)
211228

229+
branchWithoutOwner := currentPRHeadRef
230+
if idx := strings.Index(currentPRHeadRef, ":"); idx >= 0 {
231+
branchWithoutOwner = currentPRHeadRef[idx+1:]
232+
}
233+
212234
variables := map[string]interface{}{
213235
"viewerQuery": viewerQuery,
214236
"reviewerQuery": reviewerQuery,
215237
"owner": owner,
216238
"repo": repo,
217-
"headRefName": currentBranch,
239+
"headRefName": branchWithoutOwner,
240+
"number": currentPRNumber,
218241
}
219242

220243
var resp response
@@ -233,9 +256,13 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st
233256
reviewRequested = append(reviewRequested, edge.Node)
234257
}
235258

236-
var currentPR *PullRequest
237-
for _, edge := range resp.Repository.PullRequests.Edges {
238-
currentPR = &edge.Node
259+
var currentPR = resp.Repository.PullRequest
260+
if currentPR == nil {
261+
for _, edge := range resp.Repository.PullRequests.Edges {
262+
if edge.Node.HeadLabel() == currentPRHeadRef {
263+
currentPR = &edge.Node
264+
}
265+
}
239266
}
240267

241268
payload := PullRequestsPayload{
@@ -289,36 +316,42 @@ func PullRequestByNumber(client *Client, ghRepo Repo, number int) (*PullRequest,
289316
return &resp.Repository.PullRequest, nil
290317
}
291318

292-
func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) {
319+
func PullRequestForBranch(client *Client, ghRepo Repo, branch string) (*PullRequest, error) {
293320
type response struct {
294321
Repository struct {
295322
PullRequests struct {
296-
Edges []struct {
297-
Node PullRequest
298-
}
323+
Nodes []PullRequest
299324
}
300325
}
301326
}
302327

303328
query := `
304-
query($owner: String!, $repo: String!, $headRefName: String!) {
305-
repository(owner: $owner, name: $repo) {
306-
pullRequests(headRefName: $headRefName, states: OPEN, first: 1) {
307-
edges {
308-
node {
309-
number
310-
title
311-
url
312-
}
313-
}
314-
}
315-
}
316-
}`
329+
query($owner: String!, $repo: String!, $headRefName: String!) {
330+
repository(owner: $owner, name: $repo) {
331+
pullRequests(headRefName: $headRefName, states: OPEN, first: 30) {
332+
nodes {
333+
number
334+
title
335+
url
336+
headRefName
337+
headRepositoryOwner {
338+
login
339+
}
340+
isCrossRepository
341+
}
342+
}
343+
}
344+
}`
345+
346+
branchWithoutOwner := branch
347+
if idx := strings.Index(branch, ":"); idx >= 0 {
348+
branchWithoutOwner = branch[idx+1:]
349+
}
317350

318351
variables := map[string]interface{}{
319352
"owner": ghRepo.RepoOwner(),
320353
"repo": ghRepo.RepoName(),
321-
"headRefName": branch,
354+
"headRefName": branchWithoutOwner,
322355
}
323356

324357
var resp response
@@ -327,12 +360,13 @@ func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRe
327360
return nil, err
328361
}
329362

330-
prs := []PullRequest{}
331-
for _, edge := range resp.Repository.PullRequests.Edges {
332-
prs = append(prs, edge.Node)
363+
for _, pr := range resp.Repository.PullRequests.Nodes {
364+
if pr.HeadLabel() == branch {
365+
return &pr, nil
366+
}
333367
}
334368

335-
return prs, nil
369+
return nil, fmt.Errorf("no open pull requests found for branch %q", branch)
336370
}
337371

338372
func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) {

command/pr.go

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@ import (
55
"io"
66
"os"
77
"os/exec"
8+
"regexp"
89
"strconv"
10+
"strings"
911

1012
"github.com/github/gh-cli/api"
13+
"github.com/github/gh-cli/context"
1114
"github.com/github/gh-cli/git"
1215
"github.com/github/gh-cli/utils"
1316
"github.com/spf13/cobra"
@@ -65,7 +68,7 @@ func prStatus(cmd *cobra.Command, args []string) error {
6568
if err != nil {
6669
return err
6770
}
68-
currentBranch, err := ctx.Branch()
71+
currentPRNumber, currentPRHeadRef, err := prSelectorForCurrentBranch(ctx)
6972
if err != nil {
7073
return err
7174
}
@@ -74,7 +77,7 @@ func prStatus(cmd *cobra.Command, args []string) error {
7477
return err
7578
}
7679

77-
prPayload, err := api.PullRequests(apiClient, baseRepo, currentBranch, currentUser)
80+
prPayload, err := api.PullRequests(apiClient, baseRepo, currentPRNumber, currentPRHeadRef, currentUser)
7881
if err != nil {
7982
return err
8083
}
@@ -85,7 +88,7 @@ func prStatus(cmd *cobra.Command, args []string) error {
8588
if prPayload.CurrentPR != nil {
8689
printPrs(out, *prPayload.CurrentPR)
8790
} else {
88-
message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentBranch+"]"))
91+
message := fmt.Sprintf(" There is no pull request associated with %s", utils.Cyan("["+currentPRHeadRef+"]"))
8992
printMessage(out, message)
9093
}
9194
fmt.Fprintln(out)
@@ -217,28 +220,76 @@ func prView(cmd *cobra.Command, args []string) error {
217220
return fmt.Errorf("invalid pull request number: '%s'", args[0])
218221
}
219222
} else {
220-
apiClient, err := apiClientForContext(ctx)
221-
if err != nil {
222-
return err
223-
}
224-
currentBranch, err := ctx.Branch()
223+
prNumber, branchWithOwner, err := prSelectorForCurrentBranch(ctx)
225224
if err != nil {
226225
return err
227226
}
228227

229-
prs, err := api.PullRequestsForBranch(apiClient, baseRepo, currentBranch)
230-
if err != nil {
231-
return err
232-
} else if len(prs) < 1 {
233-
return fmt.Errorf("the '%s' branch has no open pull requests", currentBranch)
228+
if prNumber > 0 {
229+
openURL = fmt.Sprintf("https://github.com/%s/%s/pull/%d", baseRepo.RepoOwner(), baseRepo.RepoName(), prNumber)
230+
} else {
231+
apiClient, err := apiClientForContext(ctx)
232+
if err != nil {
233+
return err
234+
}
235+
236+
pr, err := api.PullRequestForBranch(apiClient, baseRepo, branchWithOwner)
237+
if err != nil {
238+
return err
239+
}
240+
openURL = pr.URL
234241
}
235-
openURL = prs[0].URL
236242
}
237243

238244
fmt.Printf("Opening %s in your browser.\n", openURL)
239245
return utils.OpenInBrowser(openURL)
240246
}
241247

248+
func prSelectorForCurrentBranch(ctx context.Context) (prNumber int, prHeadRef string, err error) {
249+
baseRepo, err := ctx.BaseRepo()
250+
if err != nil {
251+
return
252+
}
253+
prHeadRef, err = ctx.Branch()
254+
if err != nil {
255+
return
256+
}
257+
branchConfig := git.ReadBranchConfig(prHeadRef)
258+
259+
// the branch is configured to merge a special PR head ref
260+
prHeadRE := regexp.MustCompile(`^refs/pull/(\d+)/head$`)
261+
if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
262+
prNumber, _ = strconv.Atoi(m[1])
263+
return
264+
}
265+
266+
var branchOwner string
267+
if branchConfig.RemoteURL != nil {
268+
// the branch merges from a remote specified by URL
269+
if r, err := context.RepoFromURL(branchConfig.RemoteURL); err == nil {
270+
branchOwner = r.RepoOwner()
271+
}
272+
} else if branchConfig.RemoteName != "" {
273+
// the branch merges from a remote specified by name
274+
rem, _ := ctx.Remotes()
275+
if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
276+
branchOwner = r.RepoOwner()
277+
}
278+
}
279+
280+
if branchOwner != "" {
281+
if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
282+
prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
283+
}
284+
// prepend `OWNER:` if this branch is pushed to a fork
285+
if !strings.EqualFold(branchOwner, baseRepo.RepoOwner()) {
286+
prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
287+
}
288+
}
289+
290+
return
291+
}
292+
242293
func prCheckout(cmd *cobra.Command, args []string) error {
243294
prNumber, err := strconv.Atoi(args[0])
244295
if err != nil {

0 commit comments

Comments
 (0)