diff --git a/cmd/root_test.go b/cmd/root_test.go index 1865587..5eeadf5 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -60,6 +60,7 @@ func TestAllPermissionsAllowed(t *testing.T) { assertExec(t, client) assertPtyTerminal(t, client) assertSftp(t, client) + assertUnixLocalPortForwarding(t, client) } func TestEmptyPassword(t *testing.T) { @@ -166,6 +167,7 @@ func TestAllowExecute(t *testing.T) { assertExec(t, client) assertPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixLocalPortForwarding(t, client) } func TestAllowTcpipForward(t *testing.T) { @@ -195,6 +197,7 @@ func TestAllowTcpipForward(t *testing.T) { assertNoExec(t, client) assertNoPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixLocalPortForwarding(t, client) } func TestAllowDirectTcpip(t *testing.T) { @@ -224,6 +227,37 @@ func TestAllowDirectTcpip(t *testing.T) { assertNoExec(t, client) assertNoPtyTerminal(t, client) assertNoSftp(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowDirectStreamlocal(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-streamlocal"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertUnixLocalPortForwarding(t, client) } func TestAllowSftp(t *testing.T) { diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go index b801bf4..b870aaa 100644 --- a/cmd/test_util_test.go +++ b/cmd/test_util_test.go @@ -2,12 +2,15 @@ package cmd import ( "bytes" + "github.com/google/uuid" "github.com/pkg/sftp" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" "io" "net" + "os" "os/exec" + "path" "strconv" "testing" "time" @@ -135,6 +138,51 @@ func assertNoLocalPortForwarding(t *testing.T, client *ssh.Client) { assert.Equal(t, "ssh: rejected: administratively prohibited (direct-tcpip not allowed)", err.Error()) } +func assertUnixLocalPortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + acceptedConnChan := make(chan net.Conn) + { + ln, err := net.Listen("unix", remoteUnixSocket) + assert.NoError(t, err) + defer os.Remove(remoteUnixSocket) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + } + conn, err := client.Dial("unix", remoteUnixSocket) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoUnixLocalPortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + _, err := client.Dial("unix", remoteUnixSocket) + assert.Error(t, err) + assert.Equal(t, "ssh: rejected: administratively prohibited (direct-streamlocal (Unix domain socket) not allowed)", err.Error()) +} + func assertRemotePortForwarding(t *testing.T, client *ssh.Client) { remotePort := getAvailableTcpPort() ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) diff --git a/go.mod b/go.mod index d7e0a84..de21a8d 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/creack/pty v1.1.18 + github.com/google/uuid v1.3.0 github.com/mattn/go-shellwords v1.0.12 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.5 diff --git a/go.sum b/go.sum index b17c2a7..c0e1b80 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=