From cad0e3fa233778d6553e72180a277c8063cb0b4b Mon Sep 17 00:00:00 2001 From: Ryo Ota Date: Thu, 10 Aug 2023 22:53:02 +0900 Subject: [PATCH] add more tests --- cmd/root_test.go | 124 +++++++++++++++++++----------------------- cmd/test_util_test.go | 103 +++++++++++++++++++++++++++++++++++ server.go | 1 + 3 files changed, 161 insertions(+), 67 deletions(-) create mode 100644 cmd/test_util_test.go diff --git a/cmd/root_test.go b/cmd/root_test.go index b0c0546..acfcce5 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -2,12 +2,11 @@ package cmd import ( "bytes" + "context" "github.com/nwtgck/handy-sshd/version" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" - "io" "net" - "os/exec" "strconv" "testing" ) @@ -37,10 +36,12 @@ func TestAllPermissionsAllowed(t *testing.T) { rootCmd := RootCmd() port := getAvailableTcpPort() rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() go func() { var stderrBuf bytes.Buffer rootCmd.SetErr(&stderrBuf) - assert.NoError(t, rootCmd.Execute()) + rootCmd.ExecuteContext(ctx) }() waitTCPServer(port) sshClientConfig := &ssh.ClientConfig{ @@ -53,77 +54,66 @@ func TestAllPermissionsAllowed(t *testing.T) { address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) defer client.Close() assert.NoError(t, err) assertExec(t, client) assertLocalPortForwarding(t, client) + assertRemotePortForwardingTODO(t, client) } -func assertExec(t *testing.T, client *ssh.Client) { - session, err := client.NewSession() - assert.NoError(t, err) - defer session.Close() - whoamiBytes, err := session.Output("whoami") - assert.NoError(t, err) - expectedWhoamiBytes, err := exec.Command("whoami").Output() - assert.NoError(t, err) - assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes)) -} +func TestMultipleUsers(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword1", "--user", "alex:mypassword2"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) -func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { - var remoteTcpPort int - acceptedConnChan := make(chan net.Conn) - { - ln, err := net.Listen("tcp", ":0") - assert.NoError(t, err) - remoteTcpPort = ln.Addr().(*net.TCPAddr).Port - go func() { - conn, err := ln.Accept() - assert.NoError(t, err) - acceptedConnChan <- conn - }() - } - raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort} - conn, err := client.DialTCP("tcp", nil, raddr) - assert.NoError(t, err) - defer conn.Close() - acceptedConn := <-acceptedConnChan - defer acceptedConn.Close() - { - localToRemote := [3]byte{1, 2, 3} - _, err = conn.Write(localToRemote[:]) - assert.NoError(t, err) - var buf [len(localToRemote)]byte - _, err = io.ReadFull(acceptedConn, buf[:]) - assert.NoError(t, err) - assert.Equal(t, buf, localToRemote) - } - { - remoteToLocal := [4]byte{10, 20, 30, 40} - _, err = acceptedConn.Write(remoteToLocal[:]) - assert.NoError(t, err) - var buf [len(remoteToLocal)]byte - _, err = io.ReadFull(conn, buf[:]) - assert.NoError(t, err) - assert.Equal(t, buf, remoteToLocal) - } -} - -func getAvailableTcpPort() int { - ln, err := net.Listen("tcp", ":0") - if err != nil { - panic(err) - } - defer ln.Close() - return ln.Addr().(*net.TCPAddr).Port -} - -func waitTCPServer(port int) { - for { - conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) - if err == nil { - conn.Close() - break + for _, user := range []struct { + name string + password string + }{{name: "john", password: "mypassword1"}, {name: "alex", password: "mypassword2"}} { + sshClientConfig := &ssh.ClientConfig{ + User: user.name, + Auth: []ssh.AuthMethod{ssh.Password(user.password)}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, } + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() } } + +func TestWrongPassword(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mywrongpassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + _, err := ssh.Dial("tcp", address, sshClientConfig) + assert.Error(t, err) + assert.Equal(t, err.Error(), `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`) +} diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go new file mode 100644 index 0000000..ef0f62a --- /dev/null +++ b/cmd/test_util_test.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" + "io" + "net" + "os/exec" + "strconv" + "testing" +) + +func getAvailableTcpPort() int { + ln, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + defer ln.Close() + return ln.Addr().(*net.TCPAddr).Port +} + +func waitTCPServer(port int) { + for { + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err == nil { + conn.Close() + break + } + } +} + +func assertExec(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + whoamiBytes, err := session.Output("whoami") + assert.NoError(t, err) + expectedWhoamiBytes, err := exec.Command("whoami").Output() + assert.NoError(t, err) + assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes)) +} + +func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { + var remoteTcpPort int + acceptedConnChan := make(chan net.Conn) + { + ln, err := net.Listen("tcp", ":0") + assert.NoError(t, err) + remoteTcpPort = ln.Addr().(*net.TCPAddr).Port + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + } + raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort} + conn, err := client.DialTCP("tcp", nil, raddr) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertRemotePortForwardingTODO(t *testing.T, client *ssh.Client) { + remotePort := getAvailableTcpPort() + acceptedConnChan := make(chan net.Conn) + var _ = acceptedConnChan + ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) + var _ = ln + assert.NoError(t, err) + go func() { + //conn, err := ln.Accept() + //assert.NoError(t, err) + //acceptedConnChan <- conn + }() + + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) + assert.NoError(t, err) + defer conn.Close() + + // FIXME: implement but the following suspends + //acceptedConn := <-acceptedConnChan + //defer acceptedConn.Close() + // TODO: conn <--> acceptedConn communication +} diff --git a/server.go b/server.go index 2022c4b..77f86ca 100644 --- a/server.go +++ b/server.go @@ -293,6 +293,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { req.Reply(false, nil) return } + req.Reply(true, nil) go func() { sshConn.Wait() ln.Close()