|
4 | 4 | "errors" |
5 | 5 | "fmt" |
6 | 6 | "net/url" |
| 7 | + "strings" |
7 | 8 | "time" |
8 | 9 |
|
9 | 10 | "github.com/cli/cli/api" |
@@ -75,7 +76,27 @@ func prCreate(cmd *cobra.Command, _ []string) error { |
75 | 76 | if err != nil { |
76 | 77 | return fmt.Errorf("could not determine the current branch: %w", err) |
77 | 78 | } |
78 | | - headRepo, headRepoErr := repoContext.HeadRepo() |
| 79 | + |
| 80 | + var headRepo ghrepo.Interface |
| 81 | + var headRemote *context.Remote |
| 82 | + |
| 83 | + // determine whether the head branch is already pushed to a remote |
| 84 | + headBranchPushedTo := determineTrackingBranch(remotes, headBranch) |
| 85 | + if headBranchPushedTo != nil { |
| 86 | + for _, r := range remotes { |
| 87 | + if r.Name != headBranchPushedTo.RemoteName { |
| 88 | + continue |
| 89 | + } |
| 90 | + headRepo = r |
| 91 | + headRemote = r |
| 92 | + break |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + // otherwise, determine the head repository with info obtained from the API |
| 97 | + if headRepo == nil { |
| 98 | + headRepo, _ = repoContext.HeadRepo() |
| 99 | + } |
79 | 100 |
|
80 | 101 | baseBranch, err := cmd.Flags().GetString("base") |
81 | 102 | if err != nil { |
@@ -193,8 +214,9 @@ func prCreate(cmd *cobra.Command, _ []string) error { |
193 | 214 | } |
194 | 215 |
|
195 | 216 | didForkRepo := false |
196 | | - var headRemote *context.Remote |
197 | | - if headRepoErr != nil { |
| 217 | + // if a head repository could not be determined so far, automatically create |
| 218 | + // one by forking the base repository |
| 219 | + if headRepo == nil { |
198 | 220 | if baseRepo.IsPrivate { |
199 | 221 | return fmt.Errorf("cannot fork private repository '%s'", ghrepo.FullName(baseRepo)) |
200 | 222 | } |
@@ -223,29 +245,31 @@ func prCreate(cmd *cobra.Command, _ []string) error { |
223 | 245 | headBranchLabel = fmt.Sprintf("%s:%s", headRepo.RepoOwner(), headBranch) |
224 | 246 | } |
225 | 247 |
|
226 | | - if headRemote == nil { |
227 | | - headRemote, err = repoContext.RemoteForRepo(headRepo) |
228 | | - if err != nil { |
229 | | - return fmt.Errorf("git remote not found for head repository: %w", err) |
| 248 | + // automatically push the branch if it hasn't been pushed anywhere yet |
| 249 | + if headBranchPushedTo == nil { |
| 250 | + if headRemote == nil { |
| 251 | + headRemote, err = repoContext.RemoteForRepo(headRepo) |
| 252 | + if err != nil { |
| 253 | + return fmt.Errorf("git remote not found for head repository: %w", err) |
| 254 | + } |
230 | 255 | } |
231 | | - } |
232 | 256 |
|
233 | | - pushTries := 0 |
234 | | - maxPushTries := 3 |
235 | | - for { |
236 | | - // TODO: respect existing upstream configuration of the current branch |
237 | | - if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil { |
238 | | - if didForkRepo && pushTries < maxPushTries { |
239 | | - pushTries++ |
240 | | - // first wait 2 seconds after forking, then 4s, then 6s |
241 | | - waitSeconds := 2 * pushTries |
242 | | - fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second")) |
243 | | - time.Sleep(time.Duration(waitSeconds) * time.Second) |
244 | | - continue |
| 257 | + pushTries := 0 |
| 258 | + maxPushTries := 3 |
| 259 | + for { |
| 260 | + if err := git.Push(headRemote.Name, fmt.Sprintf("HEAD:%s", headBranch)); err != nil { |
| 261 | + if didForkRepo && pushTries < maxPushTries { |
| 262 | + pushTries++ |
| 263 | + // first wait 2 seconds after forking, then 4s, then 6s |
| 264 | + waitSeconds := 2 * pushTries |
| 265 | + fmt.Fprintf(cmd.ErrOrStderr(), "waiting %s before retrying...\n", utils.Pluralize(waitSeconds, "second")) |
| 266 | + time.Sleep(time.Duration(waitSeconds) * time.Second) |
| 267 | + continue |
| 268 | + } |
| 269 | + return err |
245 | 270 | } |
246 | | - return err |
| 271 | + break |
247 | 272 | } |
248 | | - break |
249 | 273 | } |
250 | 274 |
|
251 | 275 | if action == SubmitAction { |
@@ -275,6 +299,47 @@ func prCreate(cmd *cobra.Command, _ []string) error { |
275 | 299 | return nil |
276 | 300 | } |
277 | 301 |
|
| 302 | +func determineTrackingBranch(remotes context.Remotes, headBranch string) *git.TrackingRef { |
| 303 | + refsForLookup := []string{"HEAD"} |
| 304 | + var trackingRefs []git.TrackingRef |
| 305 | + |
| 306 | + headBranchConfig := git.ReadBranchConfig(headBranch) |
| 307 | + if headBranchConfig.RemoteName != "" { |
| 308 | + tr := git.TrackingRef{ |
| 309 | + RemoteName: headBranchConfig.RemoteName, |
| 310 | + BranchName: strings.TrimPrefix(headBranchConfig.MergeRef, "refs/heads/"), |
| 311 | + } |
| 312 | + trackingRefs = append(trackingRefs, tr) |
| 313 | + refsForLookup = append(refsForLookup, tr.String()) |
| 314 | + } |
| 315 | + |
| 316 | + for _, remote := range remotes { |
| 317 | + tr := git.TrackingRef{ |
| 318 | + RemoteName: remote.Name, |
| 319 | + BranchName: headBranch, |
| 320 | + } |
| 321 | + trackingRefs = append(trackingRefs, tr) |
| 322 | + refsForLookup = append(refsForLookup, tr.String()) |
| 323 | + } |
| 324 | + |
| 325 | + resolvedRefs, _ := git.ShowRefs(refsForLookup...) |
| 326 | + if len(resolvedRefs) > 1 { |
| 327 | + for _, r := range resolvedRefs[1:] { |
| 328 | + if r.Hash != resolvedRefs[0].Hash { |
| 329 | + continue |
| 330 | + } |
| 331 | + for _, tr := range trackingRefs { |
| 332 | + if tr.String() != r.Name { |
| 333 | + continue |
| 334 | + } |
| 335 | + return &tr |
| 336 | + } |
| 337 | + } |
| 338 | + } |
| 339 | + |
| 340 | + return nil |
| 341 | +} |
| 342 | + |
278 | 343 | func generateCompareURL(r ghrepo.Interface, base, head, title, body string) string { |
279 | 344 | u := fmt.Sprintf( |
280 | 345 | "https://github.com/%s/compare/%s...%s?expand=1", |
|
0 commit comments