Skip to content

Commit 7a614ce

Browse files
committed
Support triangular git workflows in pr create
- The local git remotes are scanned and resolved to GitHub repositories - The "base" repo is the first result resolved to its parent repo (if a fork) - The name of the default branch is read from the base repo - The "head" repo is the first repo that has push access
1 parent 99c17c3 commit 7a614ce

File tree

5 files changed

+349
-69
lines changed

5 files changed

+349
-69
lines changed

api/client.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"io/ioutil"
99
"net/http"
10+
"strings"
1011
)
1112

1213
// ClientOption represents an argument to NewClient
@@ -69,9 +70,27 @@ type Client struct {
6970

7071
type graphQLResponse struct {
7172
Data interface{}
72-
Errors []struct {
73-
Message string
73+
Errors []GraphQLError
74+
}
75+
76+
// GraphQLError is a single error returned in a GraphQL response
77+
type GraphQLError struct {
78+
Type string
79+
Path []string
80+
Message string
81+
}
82+
83+
// GraphQLErrorResponse contains errors returned in a GraphQL response
84+
type GraphQLErrorResponse struct {
85+
Errors []GraphQLError
86+
}
87+
88+
func (gr GraphQLErrorResponse) Error() string {
89+
errorMessages := make([]string, 0, len(gr.Errors))
90+
for _, e := range gr.Errors {
91+
errorMessages = append(errorMessages, e.Message)
7492
}
93+
return fmt.Sprintf("graphql error: '%s'", strings.Join(errorMessages, ", "))
7594
}
7695

7796
// GraphQL performs a GraphQL request and parses the response
@@ -151,14 +170,9 @@ func handleResponse(resp *http.Response, data interface{}) error {
151170
}
152171

153172
if len(gr.Errors) > 0 {
154-
errorMessages := gr.Errors[0].Message
155-
for _, e := range gr.Errors[1:] {
156-
errorMessages += ", " + e.Message
157-
}
158-
return fmt.Errorf("graphql error: '%s'", errorMessages)
173+
return &GraphQLErrorResponse{Errors: gr.Errors}
159174
}
160175
return nil
161-
162176
}
163177

164178
func handleHTTPError(resp *http.Response) error {

api/queries_pr.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,8 @@ func PullRequestForBranch(client *Client, ghRepo Repo, branch string) (*PullRequ
403403
return nil, &NotFoundError{fmt.Errorf("no open pull requests found for branch %q", branch)}
404404
}
405405

406-
func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{}) (*PullRequest, error) {
407-
repo, err := GitHubRepo(client, ghRepo)
408-
if err != nil {
409-
return nil, err
410-
}
411-
406+
// CreatePullRequest creates a pull request in a GitHub repository
407+
func CreatePullRequest(client *Client, repo Repository, params map[string]interface{}) (*PullRequest, error) {
412408
query := `
413409
mutation CreatePullRequest($input: CreatePullRequestInput!) {
414410
createPullRequest(input: $input) {
@@ -434,7 +430,7 @@ func CreatePullRequest(client *Client, ghRepo Repo, params map[string]interface{
434430
}
435431
}{}
436432

437-
err = client.GraphQL(query, variables, &result)
433+
err := client.GraphQL(query, variables, &result)
438434
if err != nil {
439435
return nil, err
440436
}

api/queries_repo.go

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,59 @@
11
package api
22

33
import (
4+
"bytes"
5+
"encoding/json"
46
"fmt"
7+
"sort"
8+
"strings"
59

610
"github.com/pkg/errors"
711
)
812

913
// Repository contains information about a GitHub repo
1014
type Repository struct {
11-
ID string
15+
ID string
16+
Name string
17+
Owner struct {
18+
Login string
19+
}
20+
21+
IsPrivate bool
1222
HasIssuesEnabled bool
23+
ViewerPermission string
24+
DefaultBranchRef struct {
25+
Name string
26+
Target struct {
27+
OID string
28+
}
29+
}
30+
31+
Parent *Repository
32+
}
33+
34+
// RepoOwner is the login name of the owner
35+
func (r Repository) RepoOwner() string {
36+
return r.Owner.Login
37+
}
38+
39+
// RepoName is the name of the repository
40+
func (r Repository) RepoName() string {
41+
return r.Name
42+
}
43+
44+
// IsFork is true when this repository has a parent repository
45+
func (r Repository) IsFork() bool {
46+
return r.Parent != nil
47+
}
48+
49+
// ViewerCanPush is true when the requesting user has push access
50+
func (r Repository) ViewerCanPush() bool {
51+
switch r.ViewerPermission {
52+
case "ADMIN", "MAINTAIN", "WRITE":
53+
return true
54+
default:
55+
return false
56+
}
1357
}
1458

1559
// GitHubRepo looks up the node ID of a named repository
@@ -44,3 +88,102 @@ func GitHubRepo(client *Client, ghRepo Repo) (*Repository, error) {
4488

4589
return &result.Repository, nil
4690
}
91+
92+
// RepoNetworkResult describes the relationship between related repositories
93+
type RepoNetworkResult struct {
94+
ViewerLogin string
95+
Repositories []*Repository
96+
}
97+
98+
// RepoNetwork inspects the relationship between multiple GitHub repositories
99+
func RepoNetwork(client *Client, repos []Repo) (RepoNetworkResult, error) {
100+
queries := []string{}
101+
for i, repo := range repos {
102+
queries = append(queries, fmt.Sprintf(`
103+
repo_%03d: repository(owner: %q, name: %q) {
104+
...repo
105+
parent {
106+
...repo
107+
}
108+
}
109+
`, i, repo.RepoOwner(), repo.RepoName()))
110+
}
111+
112+
type ViewerOrRepo struct {
113+
Login string
114+
Repository
115+
}
116+
117+
graphqlResult := map[string]*json.RawMessage{}
118+
result := RepoNetworkResult{}
119+
120+
err := client.GraphQL(fmt.Sprintf(`
121+
fragment repo on Repository {
122+
id
123+
name
124+
owner { login }
125+
viewerPermission
126+
defaultBranchRef {
127+
name
128+
target { oid }
129+
}
130+
isPrivate
131+
}
132+
query {
133+
viewer { login }
134+
%s
135+
}
136+
`, strings.Join(queries, "")), nil, &graphqlResult)
137+
graphqlError, isGraphQLError := err.(*GraphQLErrorResponse)
138+
if isGraphQLError {
139+
// If the only errors are that certain repositories are not found,
140+
// continue processing this response instead of returning an error
141+
tolerated := true
142+
for _, ge := range graphqlError.Errors {
143+
if ge.Type != "NOT_FOUND" {
144+
tolerated = false
145+
}
146+
}
147+
if tolerated {
148+
err = nil
149+
}
150+
}
151+
if err != nil {
152+
return result, err
153+
}
154+
155+
keys := []string{}
156+
for key := range graphqlResult {
157+
keys = append(keys, key)
158+
}
159+
// sort keys to ensure `repo_{N}` entries are processed in order
160+
sort.Sort(sort.StringSlice(keys))
161+
162+
for _, name := range keys {
163+
jsonMessage := graphqlResult[name]
164+
if name == "viewer" {
165+
viewerResult := struct {
166+
Login string
167+
}{}
168+
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
169+
if err := decoder.Decode(&viewerResult); err != nil {
170+
return result, err
171+
}
172+
result.ViewerLogin = viewerResult.Login
173+
} else if strings.HasPrefix(name, "repo_") {
174+
if jsonMessage == nil {
175+
result.Repositories = append(result.Repositories, nil)
176+
continue
177+
}
178+
repo := Repository{}
179+
decoder := json.NewDecoder(bytes.NewReader([]byte(*jsonMessage)))
180+
if err := decoder.Decode(&repo); err != nil {
181+
return result, err
182+
}
183+
result.Repositories = append(result.Repositories, &repo)
184+
} else {
185+
return result, fmt.Errorf("unknown GraphQL result key %q", name)
186+
}
187+
}
188+
return result, nil
189+
}

0 commit comments

Comments
 (0)