Skip to content

Commit 42e47a9

Browse files
committed
add docs, simplify map, error on invalid args
1 parent 8a0f8b6 commit 42e47a9

File tree

4 files changed

+50
-32
lines changed

4 files changed

+50
-32
lines changed

cmd/ghcs/logs.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,12 @@ func logs(ctx context.Context, log *output.Logger, codespaceName string, follow
8282
}
8383

8484
dst := fmt.Sprintf("%s@localhost", sshUser)
85-
cmd := codespaces.NewRemoteCommand(
85+
cmd, err := codespaces.NewRemoteCommand(
8686
ctx, localPort, dst, fmt.Sprintf("%s /workspaces/.codespaces/.persistedshare/creation.log", cmdType),
8787
)
88+
if err != nil {
89+
return fmt.Errorf("remote command: %w", err)
90+
}
8891

8992
tunnelClosed := make(chan error, 1)
9093
go func() {

internal/codespaces/ssh.go

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package codespaces
22

33
import (
44
"context"
5+
"fmt"
56
"os"
67
"os/exec"
78
"strconv"
@@ -12,7 +13,10 @@ import (
1213
// port-forwarding session. It runs until the shell is terminated
1314
// (including by cancellation of the context).
1415
func Shell(ctx context.Context, log logger, sshArgs []string, port int, destination string, usingCustomPort bool) error {
15-
cmd, connArgs := newSSHCommand(ctx, port, destination, sshArgs)
16+
cmd, connArgs, err := newSSHCommand(ctx, port, destination, sshArgs)
17+
if err != nil {
18+
return fmt.Errorf("failed to create ssh command: %w", err)
19+
}
1620

1721
if usingCustomPort {
1822
log.Println("Connection Details: ssh " + destination + " " + strings.Join(connArgs, " "))
@@ -23,17 +27,27 @@ func Shell(ctx context.Context, log logger, sshArgs []string, port int, destinat
2327

2428
// NewRemoteCommand returns an exec.Cmd that will securely run a shell
2529
// command on the remote machine.
26-
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) *exec.Cmd {
27-
cmd, _ := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
28-
return cmd
30+
func NewRemoteCommand(ctx context.Context, tunnelPort int, destination string, sshArgs ...string) (*exec.Cmd, error) {
31+
cmd, _, err := newSSHCommand(ctx, tunnelPort, destination, sshArgs)
32+
return cmd, err
2933
}
3034

3135
// newSSHCommand populates an exec.Cmd to run a command (or if blank,
3236
// an interactive shell) over ssh.
33-
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string) {
37+
func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string) (*exec.Cmd, []string, error) {
3438
connArgs := []string{"-p", strconv.Itoa(port), "-o", "NoHostAuthenticationForLocalhost=yes"}
3539

36-
cmdArgs, command := parseSSHArgs(cmdArgs)
40+
// The ssh command syntax is: ssh [flags] user@host command [args...]
41+
// There is no way to specify the user@host destination as a flag.
42+
// Unfortunately, that means we need to know which user-provided words are
43+
// SSH flags and which are command arguments so that we can place
44+
// them before or after the destination, and that means we need to know all
45+
// the flags and their arities.
46+
cmdArgs, command, err := parseSSHArgs(cmdArgs)
47+
if err != nil {
48+
return nil, []string{}, err
49+
}
50+
3751
cmdArgs = append(cmdArgs, connArgs...)
3852
cmdArgs = append(cmdArgs, "-C") // Compression
3953
cmdArgs = append(cmdArgs, dst) // user@host
@@ -47,30 +61,12 @@ func newSSHCommand(ctx context.Context, port int, dst string, cmdArgs []string)
4761
cmd.Stdin = os.Stdin
4862
cmd.Stderr = os.Stderr
4963

50-
return cmd, connArgs
64+
return cmd, connArgs, nil
5165
}
5266

53-
var sshArgumentFlags = map[string]bool{
54-
"-b": true,
55-
"-c": true,
56-
"-D": true,
57-
"-e": true,
58-
"-F": true,
59-
"-I": true,
60-
"-i": true,
61-
"-L": true,
62-
"-l": true,
63-
"-m": true,
64-
"-O": true,
65-
"-o": true,
66-
"-p": true,
67-
"-R": true,
68-
"-S": true,
69-
"-W": true,
70-
"-w": true,
71-
}
67+
var sshArgumentFlags = "-b-c-D-e-F-I-i-L-l-m-O-o-p-R-S-W-w"
7268

73-
func parseSSHArgs(sshArgs []string) ([]string, string) {
69+
func parseSSHArgs(sshArgs []string) ([]string, string, error) {
7470
var (
7571
cmdArgs []string
7672
command []string
@@ -80,8 +76,12 @@ func parseSSHArgs(sshArgs []string) ([]string, string) {
8076
for _, arg := range sshArgs {
8177
switch {
8278
case strings.HasPrefix(arg, "-"):
79+
if len(command) > 0 {
80+
return []string{}, "", fmt.Errorf("invalid flag after command: %s", arg)
81+
}
82+
8383
cmdArgs = append(cmdArgs, arg)
84-
if _, ok := sshArgumentFlags[arg]; ok {
84+
if strings.Contains(sshArgumentFlags, arg) {
8585
flagArgument = true
8686
}
8787
case flagArgument:
@@ -92,5 +92,5 @@ func parseSSHArgs(sshArgs []string) ([]string, string) {
9292
}
9393
}
9494

95-
return cmdArgs, strings.Join(command, " ")
95+
return cmdArgs, strings.Join(command, " "), nil
9696
}

internal/codespaces/ssh_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ func TestParseSSHArgs(t *testing.T) {
4848
}
4949

5050
for _, tcase := range testCases {
51-
args, command := parseSSHArgs(tcase.Args)
51+
args, command, err := parseSSHArgs(tcase.Args)
52+
if err != nil {
53+
t.Errorf("received unexpected error: %w", err)
54+
}
55+
5256
if len(args) != len(tcase.ParsedArgs) {
5357
t.Fatalf("args do not match length of expected args. %#v, got '%d', expected: '%d'", tcase, len(args), len(tcase.ParsedArgs))
5458
}
@@ -62,3 +66,10 @@ func TestParseSSHArgs(t *testing.T) {
6266
}
6367
}
6468
}
69+
70+
func TestParseSSHArgsError(t *testing.T) {
71+
_, _, err := parseSSHArgs([]string{"-X", "test", "-Y"})
72+
if err == nil {
73+
t.Error("expected an error for invalid args")
74+
}
75+
}

internal/codespaces/states.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,14 @@ func PollPostCreateStates(ctx context.Context, log logger, apiClient *api.API, u
8989
}
9090

9191
func getPostCreateOutput(ctx context.Context, tunnelPort int, codespace *api.Codespace, user string) ([]PostCreateState, error) {
92-
cmd := NewRemoteCommand(
92+
cmd, err := NewRemoteCommand(
9393
ctx, tunnelPort, fmt.Sprintf("%s@localhost", user),
9494
"cat /workspaces/.codespaces/shared/postCreateOutput.json",
9595
)
96+
if err != nil {
97+
return nil, fmt.Errorf("remote command: %w", err)
98+
}
99+
96100
stdout := new(bytes.Buffer)
97101
cmd.Stdout = stdout
98102
if err := cmd.Run(); err != nil {

0 commit comments

Comments
 (0)