add more tests

This commit is contained in:
Ryo Ota 2023-08-10 22:53:02 +09:00
parent 35c29a6de2
commit cad0e3fa23
3 changed files with 161 additions and 67 deletions

View file

@ -2,12 +2,11 @@ package cmd
import (
"bytes"
"context"
"github.com/nwtgck/handy-sshd/version"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"io"
"net"
"os/exec"
"strconv"
"testing"
)
@ -37,10 +36,12 @@ func TestAllPermissionsAllowed(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
assert.NoError(t, rootCmd.Execute())
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
@ -53,77 +54,66 @@ func TestAllPermissionsAllowed(t *testing.T) {
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertExec(t, client)
assertLocalPortForwarding(t, client)
assertRemotePortForwardingTODO(t, client)
}
func assertExec(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
whoamiBytes, err := session.Output("whoami")
assert.NoError(t, err)
expectedWhoamiBytes, err := exec.Command("whoami").Output()
assert.NoError(t, err)
assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes))
}
func assertLocalPortForwarding(t *testing.T, client *ssh.Client) {
var remoteTcpPort int
acceptedConnChan := make(chan net.Conn)
{
ln, err := net.Listen("tcp", ":0")
assert.NoError(t, err)
remoteTcpPort = ln.Addr().(*net.TCPAddr).Port
func TestMultipleUsers(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword1", "--user", "alex:mypassword2"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
for _, user := range []struct {
name string
password string
}{{name: "john", password: "mypassword1"}, {name: "alex", password: "mypassword2"}} {
sshClientConfig := &ssh.ClientConfig{
User: user.name,
Auth: []ssh.AuthMethod{ssh.Password(user.password)},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort}
conn, err := client.DialTCP("tcp", nil, raddr)
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
defer client.Close()
}
}
func getAvailableTcpPort() int {
ln, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer ln.Close()
return ln.Addr().(*net.TCPAddr).Port
}
func waitTCPServer(port int) {
for {
conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err == nil {
conn.Close()
break
}
func TestWrongPassword(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mywrongpassword")},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
_, err := ssh.Dial("tcp", address, sshClientConfig)
assert.Error(t, err)
assert.Equal(t, err.Error(), `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`)
}

103
cmd/test_util_test.go Normal file
View file

@ -0,0 +1,103 @@
package cmd
import (
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"io"
"net"
"os/exec"
"strconv"
"testing"
)
func getAvailableTcpPort() int {
ln, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer ln.Close()
return ln.Addr().(*net.TCPAddr).Port
}
func waitTCPServer(port int) {
for {
conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err == nil {
conn.Close()
break
}
}
}
func assertExec(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
whoamiBytes, err := session.Output("whoami")
assert.NoError(t, err)
expectedWhoamiBytes, err := exec.Command("whoami").Output()
assert.NoError(t, err)
assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes))
}
func assertLocalPortForwarding(t *testing.T, client *ssh.Client) {
var remoteTcpPort int
acceptedConnChan := make(chan net.Conn)
{
ln, err := net.Listen("tcp", ":0")
assert.NoError(t, err)
remoteTcpPort = ln.Addr().(*net.TCPAddr).Port
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
}()
}
raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort}
conn, err := client.DialTCP("tcp", nil, raddr)
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
}
}
func assertRemotePortForwardingTODO(t *testing.T, client *ssh.Client) {
remotePort := getAvailableTcpPort()
acceptedConnChan := make(chan net.Conn)
var _ = acceptedConnChan
ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
var _ = ln
assert.NoError(t, err)
go func() {
//conn, err := ln.Accept()
//assert.NoError(t, err)
//acceptedConnChan <- conn
}()
conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
assert.NoError(t, err)
defer conn.Close()
// FIXME: implement but the following suspends
//acceptedConn := <-acceptedConnChan
//defer acceptedConn.Close()
// TODO: conn <--> acceptedConn communication
}

View file

@ -293,6 +293,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
req.Reply(false, nil)
return
}
req.Reply(true, nil)
go func() {
sshConn.Wait()
ln.Close()