Skip to content

Commit cec3aa2

Browse files
committed
Support detach head for pr checkout
1 parent d0a4639 commit cec3aa2

File tree

2 files changed

+144
-66
lines changed

2 files changed

+144
-66
lines changed

pkg/cmd/pr/checkout/checkout.go

Lines changed: 108 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type CheckoutOptions struct {
3131
SelectorArg string
3232
RecurseSubmodules bool
3333
Force bool
34+
Detach bool
3435
}
3536

3637
func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
@@ -63,6 +64,7 @@ func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobr
6364

6465
cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
6566
cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
67+
cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
6668

6769
return cmd
6870
}
@@ -88,10 +90,9 @@ func checkoutRun(opts *CheckoutOptions) error {
8890
if err != nil {
8991
return err
9092
}
91-
protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
9293

94+
protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
9395
baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
94-
// baseRemoteSpec is a repository URL or a remote name to be used in git fetch
9596
baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
9697
if baseRemote != nil {
9798
baseURLOrName = baseRemote.Name
@@ -102,89 +103,131 @@ func checkoutRun(opts *CheckoutOptions) error {
102103
headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
103104
}
104105

105-
var cmdQueue [][]string
106-
newBranchName := pr.HeadRefName
107-
if strings.HasPrefix(newBranchName, "-") {
108-
return fmt.Errorf("invalid branch name: %q", newBranchName)
106+
if strings.HasPrefix(pr.HeadRefName, "-") {
107+
return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
109108
}
110109

110+
var cmdQueue [][]string
111+
111112
if headRemote != nil {
112-
// there is an existing git remote for PR head
113-
remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName)
114-
refSpec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s", pr.HeadRefName, remoteBranch)
115-
116-
cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec})
117-
118-
// local branch already exists
119-
if _, err := git.ShowRefs("refs/heads/" + newBranchName); err == nil {
120-
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
121-
if opts.Force {
122-
cmdQueue = append(cmdQueue, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
123-
} else {
124-
// TODO: check if non-fast-forward and suggest to use `--force`
125-
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
126-
}
127-
} else {
128-
cmdQueue = append(cmdQueue, []string{"git", "checkout", "-b", newBranchName, "--no-track", remoteBranch})
129-
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), headRemote.Name})
130-
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), "refs/heads/" + pr.HeadRefName})
131-
}
113+
cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
132114
} else {
133-
// no git remote for PR head
134-
currentBranch, _ := opts.Branch()
135-
136-
defaultBranchName, err := api.RepoDefaultBranch(apiClient, baseRepo)
115+
defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
137116
if err != nil {
138117
return err
139118
}
119+
cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
120+
}
140121

141-
// avoid naming the new branch the same as the default branch
142-
if newBranchName == defaultBranchName {
143-
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
144-
}
122+
if opts.RecurseSubmodules {
123+
cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"})
124+
cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"})
125+
}
145126

146-
ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
147-
if newBranchName == currentBranch {
148-
// PR head matches currently checked out branch
127+
err = executeCmds(cmdQueue)
128+
if err != nil {
129+
return err
130+
}
149131

150-
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, ref})
132+
return nil
133+
}
151134

152-
if opts.Force {
153-
cmdQueue = append(cmdQueue, []string{"git", "reset", "--hard", "FETCH_HEAD"})
154-
} else {
155-
// TODO: check if non-fast-forward and suggest to use `--force`
156-
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
157-
}
158-
} else {
159-
// create a new branch
135+
func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
136+
var cmds [][]string
160137

161-
if opts.Force {
162-
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName), "--force"})
163-
} else {
164-
// TODO: check if non-fast-forward and suggest to use `--force`
165-
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)})
166-
}
167-
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
138+
remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)
139+
140+
refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
141+
if !opts.Detach {
142+
refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
143+
}
144+
145+
cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec})
146+
147+
switch {
148+
case opts.Detach:
149+
cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
150+
case localBranchExists(pr.HeadRefName):
151+
cmds = append(cmds, []string{"git", "checkout", pr.HeadRefName})
152+
if opts.Force {
153+
cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
154+
} else {
155+
// TODO: check if non-fast-forward and suggest to use `--force`
156+
cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
168157
}
158+
default:
159+
cmds = append(cmds, []string{"git", "checkout", "-b", pr.HeadRefName, "--no-track", remoteBranch})
160+
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", pr.HeadRefName), remote.Name})
161+
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", pr.HeadRefName), "refs/heads/" + pr.HeadRefName})
162+
}
169163

170-
remote := baseURLOrName
171-
mergeRef := ref
172-
if pr.MaintainerCanModify {
173-
headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, baseRepo.RepoHost())
174-
remote = ghrepo.FormatRemoteURL(headRepo, protocol)
175-
mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
164+
return cmds
165+
}
166+
167+
func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
168+
var cmds [][]string
169+
170+
newBranchName := pr.HeadRefName
171+
// avoid naming the new branch the same as the default branch
172+
if newBranchName == defaultBranch {
173+
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
174+
}
175+
176+
ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
177+
178+
if opts.Detach {
179+
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
180+
cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
181+
return cmds
182+
}
183+
184+
currentBranch, _ := opts.Branch()
185+
if newBranchName == currentBranch {
186+
// PR head matches currently checked out branch
187+
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
188+
if opts.Force {
189+
cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"})
190+
} else {
191+
// TODO: check if non-fast-forward and suggest to use `--force`
192+
cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
176193
}
177-
if mc, err := git.Config(fmt.Sprintf("branch.%s.merge", newBranchName)); err != nil || mc == "" {
178-
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote})
179-
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef})
194+
} else {
195+
// create a new branch
196+
if opts.Force {
197+
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName), "--force"})
198+
} else {
199+
// TODO: check if non-fast-forward and suggest to use `--force`
200+
cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, newBranchName)})
180201
}
202+
cmds = append(cmds, []string{"git", "checkout", newBranchName})
181203
}
182204

183-
if opts.RecurseSubmodules {
184-
cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"})
185-
cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"})
205+
remote := baseURLOrName
206+
mergeRef := ref
207+
if pr.MaintainerCanModify {
208+
headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
209+
remote = ghrepo.FormatRemoteURL(headRepo, protocol)
210+
mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
211+
}
212+
if missingMergeConfigForBranch(newBranchName) {
213+
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote})
214+
cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef})
186215
}
187216

217+
return cmds
218+
}
219+
220+
func missingMergeConfigForBranch(b string) bool {
221+
mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b))
222+
return err != nil || mc == ""
223+
}
224+
225+
func localBranchExists(b string) bool {
226+
_, err := git.ShowRefs("refs/heads/" + b)
227+
return err == nil
228+
}
229+
230+
func executeCmds(cmdQueue [][]string) error {
188231
for _, args := range cmdQueue {
189232
// TODO: reuse the result of this lookup across loop iteration
190233
exe, err := safeexec.LookPath(args[0])
@@ -198,6 +241,5 @@ func checkoutRun(opts *CheckoutOptions) error {
198241
return err
199242
}
200243
}
201-
202244
return nil
203245
}

pkg/cmd/pr/checkout/checkout_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,3 +688,39 @@ func TestPRCheckout_force(t *testing.T) {
688688
assert.Equal(t, "git checkout feature", strings.Join(ranCommands[1], " "))
689689
assert.Equal(t, "git reset --hard refs/remotes/origin/feature", strings.Join(ranCommands[2], " "))
690690
}
691+
692+
func TestPRCheckout_detach(t *testing.T) {
693+
http := &httpmock.Registry{}
694+
defer http.Verify(t)
695+
696+
http.Register(httpmock.GraphQL(`query PullRequestByNumber\b`), httpmock.StringResponse(`
697+
{ "data": { "repository": { "pullRequest": {
698+
"number": 123,
699+
"headRef": "f8f8f8",
700+
"headRepositoryOwner": {
701+
"login": "hubot"
702+
},
703+
"headRepository": {
704+
"name": "REPO"
705+
},
706+
"isCrossRepository": true,
707+
"maintainerCanModify": true
708+
} } } }
709+
`))
710+
711+
ranCommands := [][]string{}
712+
//nolint:staticcheck // SA1019 TODO: rewrite to use run.Stub
713+
restoreCmd := run.SetPrepareCmd(func(cmd *exec.Cmd) run.Runnable {
714+
ranCommands = append(ranCommands, cmd.Args)
715+
return &test.OutputStub{}
716+
})
717+
defer restoreCmd()
718+
719+
output, err := runCommand(http, nil, "", `123 --detach`)
720+
assert.Nil(t, err)
721+
assert.Equal(t, "", output.String())
722+
723+
assert.Equal(t, 2, len(ranCommands))
724+
assert.Equal(t, "git fetch origin refs/pull/123/head", strings.Join(ranCommands[0], " "))
725+
assert.Equal(t, "git checkout --detach FETCH_HEAD", strings.Join(ranCommands[1], " "))
726+
}

0 commit comments

Comments
 (0)