diff --git a/cmd/root_test.go b/cmd/root_test.go index 27c4b76..1865587 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -55,7 +55,7 @@ func TestAllPermissionsAllowed(t *testing.T) { assert.NoError(t, err) defer client.Close() assert.NoError(t, err) - assertRemotePortForwardingTODO(t, client) + assertRemotePortForwarding(t, client) assertLocalPortForwarding(t, client) assertExec(t, client) assertPtyTerminal(t, client) @@ -190,7 +190,7 @@ func TestAllowTcpipForward(t *testing.T) { assert.NoError(t, err) defer client.Close() assert.NoError(t, err) - assertRemotePortForwardingTODO(t, client) + assertRemotePortForwarding(t, client) assertNoLocalPortForwarding(t, client) assertNoExec(t, client) assertNoPtyTerminal(t, client) diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go index 101e7df..b801bf4 100644 --- a/cmd/test_util_test.go +++ b/cmd/test_util_test.go @@ -135,27 +135,39 @@ func assertNoLocalPortForwarding(t *testing.T, client *ssh.Client) { assert.Equal(t, "ssh: rejected: administratively prohibited (direct-tcpip not allowed)", err.Error()) } -func assertRemotePortForwardingTODO(t *testing.T, client *ssh.Client) { +func assertRemotePortForwarding(t *testing.T, client *ssh.Client) { remotePort := getAvailableTcpPort() - acceptedConnChan := make(chan net.Conn) - var _ = acceptedConnChan ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) - var _ = ln assert.NoError(t, err) + acceptedConnChan := make(chan net.Conn) go func() { - //conn, err := ln.Accept() - //assert.NoError(t, err) - //acceptedConnChan <- conn + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn }() - conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) assert.NoError(t, err) defer conn.Close() - - // FIXME: implement but the following suspends - //acceptedConn := <-acceptedConnChan - //defer acceptedConn.Close() - // TODO: conn <--> acceptedConn communication + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } } func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) { diff --git a/server.go b/server.go index 77f86ca..2defe06 100644 --- a/server.go +++ b/server.go @@ -312,6 +312,14 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { } replyMsg.Addr = msg.Addr replyMsg.Port = msg.Port + originatorAddr, originatorPortStr, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err == nil { + originatorPort, _ := strconv.Atoi(originatorPortStr) + replyMsg.OriginatorAddr = originatorAddr + replyMsg.OriginatorPort = uint32(originatorPort) + } else { + s.Logger.Error("failed to split remote address", "remote_address", conn.RemoteAddr()) + } go func() { channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&replyMsg))