diff --git a/cmd/root_test.go b/cmd/root_test.go index 5eeadf5..1960213 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -60,6 +60,7 @@ func TestAllPermissionsAllowed(t *testing.T) { assertExec(t, client) assertPtyTerminal(t, client) assertSftp(t, client) + assertUnixRemotePortForwarding(t, client) assertUnixLocalPortForwarding(t, client) } @@ -167,6 +168,7 @@ func TestAllowExecute(t *testing.T) { assertExec(t, client) assertPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) assertNoUnixLocalPortForwarding(t, client) } @@ -197,6 +199,38 @@ func TestAllowTcpipForward(t *testing.T) { assertNoExec(t, client) assertNoPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowStreamlocalForward(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-streamlocal-forward"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertUnixRemotePortForwarding(t, client) assertNoUnixLocalPortForwarding(t, client) } @@ -227,6 +261,7 @@ func TestAllowDirectTcpip(t *testing.T) { assertNoExec(t, client) assertNoPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) assertNoUnixLocalPortForwarding(t, client) } @@ -257,6 +292,7 @@ func TestAllowDirectStreamlocal(t *testing.T) { assertNoExec(t, client) assertNoPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) assertUnixLocalPortForwarding(t, client) } @@ -286,5 +322,6 @@ func TestAllowSftp(t *testing.T) { assertNoLocalPortForwarding(t, client) assertNoExec(t, client) assertNoPtyTerminal(t, client) + assertNoUnixRemotePortForwarding(t, client) assertSftp(t, client) } diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go index b870aaa..9d7e5ab 100644 --- a/cmd/test_util_test.go +++ b/cmd/test_util_test.go @@ -187,6 +187,7 @@ func assertRemotePortForwarding(t *testing.T, client *ssh.Client) { remotePort := getAvailableTcpPort() ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) assert.NoError(t, err) + defer ln.Close() acceptedConnChan := make(chan net.Conn) go func() { conn, err := ln.Accept() @@ -224,6 +225,50 @@ func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) { assert.Equal(t, "ssh: tcpip-forward request denied by peer", err.Error()) } +func assertUnixRemotePortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + ln, err := client.ListenUnix(remoteUnixSocket) + assert.NoError(t, err) + defer os.Remove(remoteUnixSocket) + defer ln.Close() + acceptedConnChan := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + conn, err := net.Dial("unix", remoteUnixSocket) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoUnixRemotePortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + _, err := client.ListenUnix(remoteUnixSocket) + assert.Error(t, err) + assert.Equal(t, "ssh: streamlocal-forward@openssh.com request denied by peer", err.Error()) +} + func assertSftp(t *testing.T, client *ssh.Client) { sftpClient, err := sftp.NewClient(client) assert.NoError(t, err)