diff --git a/cmd/root_test.go b/cmd/root_test.go index acfcce5..1f292c9 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -60,6 +60,33 @@ func TestAllPermissionsAllowed(t *testing.T) { assertExec(t, client) assertLocalPortForwarding(t, client) assertRemotePortForwardingTODO(t, client) + // TODO: pty + // TODO: sftp +} + +func TestEmptyPassword(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() } func TestMultipleUsers(t *testing.T) { @@ -115,5 +142,100 @@ func TestWrongPassword(t *testing.T) { address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) _, err := ssh.Dial("tcp", address, sshClientConfig) assert.Error(t, err) - assert.Equal(t, err.Error(), `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`) + assert.Equal(t, `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`, err.Error()) } + +func TestAllowExecute(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-execute"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertExec(t, client) + // TODO: no pty + // TODO: no sftp +} + +func TestAllowTcpipForward(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-tcpip-forward"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertRemotePortForwardingTODO(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + // TODO: no pty + // TODO: no sftp +} + +func TestAllowDirectTcpip(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-tcpip"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertLocalPortForwarding(t, client) + assertNoExec(t, client) + // TODO: no pty + // TODO: no sftp +} + +// TODO: TestAllowSftp diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go index ef0f62a..53b3450 100644 --- a/cmd/test_util_test.go +++ b/cmd/test_util_test.go @@ -37,7 +37,16 @@ func assertExec(t *testing.T, client *ssh.Client) { assert.NoError(t, err) expectedWhoamiBytes, err := exec.Command("whoami").Output() assert.NoError(t, err) - assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes)) + assert.Equal(t, string(expectedWhoamiBytes), string(whoamiBytes)) +} + +func assertNoExec(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + _, err = session.Output("whoami") + assert.Error(t, err) + assert.Equal(t, "ssh: command whoami failed", err.Error()) } func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { @@ -79,6 +88,13 @@ func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { } } +func assertNoLocalPortForwarding(t *testing.T, client *ssh.Client) { + raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + _, err := client.DialTCP("tcp", nil, raddr) + assert.Error(t, err) + assert.Equal(t, "ssh: rejected: administratively prohibited (direct-tcpip not allowed)", err.Error()) +} + func assertRemotePortForwardingTODO(t *testing.T, client *ssh.Client) { remotePort := getAvailableTcpPort() acceptedConnChan := make(chan net.Conn) @@ -101,3 +117,9 @@ func assertRemotePortForwardingTODO(t *testing.T, client *ssh.Client) { //defer acceptedConn.Close() // TODO: conn <--> acceptedConn communication } + +func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) { + _, err := client.Listen("tcp", "127.0.0.1:5678") + assert.Error(t, err) + assert.Equal(t, "ssh: tcpip-forward request denied by peer", err.Error()) +}