Skip to content

Commit 8a0f8b6

Browse files
committed
parse ssh args and command
1 parent 514448d commit 8a0f8b6

File tree

2 files changed

+117
-10
lines changed

2 files changed

+117
-10
lines changed

internal/codespaces/ssh.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// port-forwarding session. It runs until the shell is terminated
1313
// (including by cancellation of the context).
1414
func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error {
15-
cmd, connArgs := newSSHCommand(ctx, port, destination, "")
15+
cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs)
1616

1717
if usingCustomPort {
1818
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
@@ -23,23 +23,21 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat
2323

2424
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
2525
// command on the remote machine.
26-
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination, command string) *exec.Cmd {
27-
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, command)
26+
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd {
27+
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
2828
return cmd
2929
}
3030

3131
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
3232
// an interactive shell) over ssh.
33-
func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cmd, []string) {
33+
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) {
3434
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
3535

36-
cmdArgs := []string{dst, "-C"} // Always use Compression
37-
if command == "" {
38-
// if we are in a shell send X11 and X11Trust
39-
cmdArgs = append(cmdArgs, "-X", "-Y")
40-
}
41-
36+
cmdArgs, command := parseSSHArgs(cmdArgs)
4237
cmdArgs = append(cmdArgs, connArgs...)
38+
cmdArgs = append(cmdArgs, "-C") // Compression
39+
cmdArgs = append(cmdArgs, dst) // user@host
40+
4341
if command != "" {
4442
cmdArgs = append(cmdArgs, command)
4543
}
@@ -51,3 +49,48 @@ func newSSHCommand(ctx context.Context, port int, dst, command string) (*exec.Cm
5149

5250
return cmd, connArgs
5351
}
52+
53+
var sshArgumentFlags = map[string]bool{
54+
"-b": true,
55+
"-c": true,
56+
"-D": true,
57+
"-e": true,
58+
"-F": true,
59+
"-I": true,
60+
"-i": true,
61+
"-L": true,
62+
"-l": true,
63+
"-m": true,
64+
"-O": true,
65+
"-o": true,
66+
"-p": true,
67+
"-R": true,
68+
"-S": true,
69+
"-W": true,
70+
"-w": true,
71+
}
72+
73+
func parseSSHArgs(sshArgs []string) ([]string, string) {
74+
var (
75+
cmdArgs []string
76+
command []string
77+
flagArgument bool
78+
)
79+
80+
for _, arg := range sshArgs {
81+
switch {
82+
case strings.HasPrefix(arg, "-"):
83+
cmdArgs = append(cmdArgs, arg)
84+
if _, ok := sshArgumentFlags[arg]; ok {
85+
flagArgument = true
86+
}
87+
case flagArgument:
88+
cmdArgs = append(cmdArgs, arg)
89+
flagArgument = false
90+
default:
91+
command = append(command, arg)
92+
}
93+
}
94+
95+
return cmdArgs, strings.Join(command, " ")
96+
}

internal/codespaces/ssh_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package codespaces
2+
3+
import "testing"
4+
5+
func TestParseSSHArgs(t *testing.T) {
6+
type testCase struct {
7+
Args []string
8+
ParsedArgs []string
9+
Command string
10+
}
11+
12+
testCases := []testCase{
13+
{
14+
Args: []string{"-X", "-Y"},
15+
ParsedArgs: []string{"-X", "-Y"},
16+
Command: "",
17+
},
18+
{
19+
Args: []string{"-X", "-Y", "-o", "someoption=test"},
20+
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
21+
Command: "",
22+
},
23+
{
24+
Args: []string{"-X", "-Y", "-o", "someoption=test", "somecommand"},
25+
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
26+
Command: "somecommand",
27+
},
28+
{
29+
Args: []string{"-X", "-Y", "-o", "someoption=test", "echo", "test"},
30+
ParsedArgs: []string{"-X", "-Y", "-o", "someoption=test"},
31+
Command: "echo test",
32+
},
33+
{
34+
Args: []string{"somecommand"},
35+
ParsedArgs: []string{},
36+
Command: "somecommand",
37+
},
38+
{
39+
Args: []string{"echo", "test"},
40+
ParsedArgs: []string{},
41+
Command: "echo test",
42+
},
43+
{
44+
Args: []string{"-v", "echo", "hello", "world"},
45+
ParsedArgs: []string{"-v"},
46+
Command: "echo hello world",
47+
},
48+
}
49+
50+
for _, tcase := range testCases {
51+
args, command := parseSSHArgs(tcase.Args)
52+
if len(args) != len(tcase.ParsedArgs) {
53+
t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs))
54+
}
55+
for i, arg := range args {
56+
if arg != tcase.ParsedArgs[i] {
57+
t.Fatalf("arg does not match expected parsed arg. %v, got '%s', expected: '%s'", tcase, arg, tcase.ParsedArgs[i])
58+
}
59+
}
60+
if command != tcase.Command {
61+
t.Fatalf("command does not match expected command. %v, got: '%s', expected: '%s'", tcase, command, tcase.Command)
62+
}
63+
}
64+
}

0 commit comments

Comments
 (0)