diff --git a/cmd/root_test.go b/cmd/root_test.go index dc00f86..34de772 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -61,7 +61,7 @@ func TestAllPermissionsAllowed(t *testing.T) { assertLocalPortForwarding(t, client) assertExec(t, client) assertPtyTerminal(t, client) - // TODO: sftp + assertSftp(t, client) } func TestEmptyPassword(t *testing.T) { @@ -173,7 +173,7 @@ func TestAllowExecute(t *testing.T) { assertNoLocalPortForwarding(t, client) assertExec(t, client) assertPtyTerminal(t, client) - // TODO: no sftp + assertNoSftp(t, client) } func TestAllowTcpipForward(t *testing.T) { @@ -204,7 +204,7 @@ func TestAllowTcpipForward(t *testing.T) { assertNoLocalPortForwarding(t, client) assertNoExec(t, client) assertNoPtyTerminal(t, client) - // TODO: no sftp + assertNoSftp(t, client) } func TestAllowDirectTcpip(t *testing.T) { @@ -235,7 +235,36 @@ func TestAllowDirectTcpip(t *testing.T) { assertLocalPortForwarding(t, client) assertNoExec(t, client) assertNoPtyTerminal(t, client) - // TODO: no sftp + assertNoSftp(t, client) } -// TODO: TestAllowSftp +func TestAllowSftp(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-sftp"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertSftp(t, client) +} diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go index 016a7f5..101e7df 100644 --- a/cmd/test_util_test.go +++ b/cmd/test_util_test.go @@ -2,6 +2,7 @@ package cmd import ( "bytes" + "github.com/pkg/sftp" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" "io" @@ -162,3 +163,16 @@ func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) { assert.Error(t, err) assert.Equal(t, "ssh: tcpip-forward request denied by peer", err.Error()) } + +func assertSftp(t *testing.T, client *ssh.Client) { + sftpClient, err := sftp.NewClient(client) + assert.NoError(t, err) + _, err = sftpClient.Getwd() + assert.NoError(t, err) +} + +func assertNoSftp(t *testing.T, client *ssh.Client) { + _, err := sftp.NewClient(client) + assert.Error(t, err) + assert.Equal(t, "ssh: subsystem request failed", err.Error()) +}