Merge branch 'release/0.3.0'

This commit is contained in:
Ryo Ota 2023-08-11 12:36:04 +09:00
commit ef2fa8f26d
11 changed files with 936 additions and 150 deletions

View file

@ -3,12 +3,11 @@ name: CI
on: [push]
jobs:
build_multi_platform:
test:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
- name: Set up Go 1.x
uses: actions/setup-go@v4
- uses: actions/setup-go@v4
with:
go-version: "1.20"
- name: Build for multi-platform
@ -29,3 +28,5 @@ jobs:
# Build
CGO_ENABLED=0 go build -o "${BUILD_PATH}/handy-sshd${EXTENSION}" main/main.go
done
- name: Test
run: go test -v ./...

View file

@ -5,6 +5,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
## [Unreleased]
## [0.3.0] - 2023-08-11
### Added
* Support Unix domain socket local port forwarding
* Support Unix domain socket remote port forwarding
### Fixed
* Handle multiple global requests simultaneously
## [0.2.1] - 2023-08-09
### Changed
* Update dependencies
@ -20,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
### Added
* Initial release
[Unreleased]: https://github.com/nwtgck/handy-sshd/compare/v0.2.1...HEAD
[Unreleased]: https://github.com/nwtgck/handy-sshd/compare/v0.3.0...HEAD
[0.3.0]: https://github.com/nwtgck/handy-sshd/compare/v0.2.1...v0.3.0
[0.2.1]: https://github.com/nwtgck/handy-sshd/compare/v0.2.0...v0.2.1
[0.2.0]: https://github.com/nwtgck/handy-sshd/compare/v0.1.0...v0.2.0

View file

@ -41,20 +41,30 @@ handy-sshd --unix-socket /tmp/my-unix-socket --user "john:"
```
## Permissions
**All permissions are allowed when nothing is specified.** There are some permissions.
There are several permissions:
* --allow-direct-streamlocal
* --allow-direct-tcpip
* --allow-execute
* --allow-sftp
* --allow-streamlocal-forward
* --allow-tcpip-forward
Specifying `--allow-direct-tcpip` and `--allow-execute` for example allows only them.
The log shows "allowed: " and "NOT allowed: " permissions as follows.
**All permissions are allowed when nothing is specified.** The log shows "allowed: " and "NOT allowed: " permissions as follows:
```console
$ handy-sshd --user "john:"
2023/08/11 11:40:44 INFO listening on :2222...
2023/08/11 11:40:44 INFO allowed: "tcpip-forward", "direct-tcpip", "execute", "sftp", "streamlocal-forward", "direct-streamlocal"
2023/08/11 11:40:44 INFO NOT allowed: none
```
For example, specifying `--allow-direct-tcpip` and `--allow-execute` allows only them:
```console
$ handy-sshd --user "john:" --allow-direct-tcpip --allow-execute
2023/08/09 20:49:35 INFO listening on :2222...
2023/08/09 20:49:35 INFO allowed: "direct-tcpip", "execute"
2023/08/09 20:49:35 INFO NOT allowed: "tcpip-forward", "sftp"
2023/08/11 11:41:03 INFO listening on :2222...
2023/08/11 11:41:03 INFO allowed: "direct-tcpip", "execute"
2023/08/11 11:41:03 INFO NOT allowed: "tcpip-forward", "sftp", "streamlocal-forward", "direct-streamlocal"
```
## --help
@ -66,15 +76,17 @@ Usage:
handy-sshd [flags]
Flags:
--allow-direct-tcpip client can use local forwarding and SOCKS proxy
--allow-execute client can use shell/interactive shell
--allow-sftp client can use SFTP and SSHFS
--allow-tcpip-forward client can use remote forwarding
-h, --help help for handy-sshd
--host string SSH server host (e.g. 127.0.0.1)
-p, --port uint16 SSH server port (default 2222)
--shell string Shell
--unix-socket string Unix-domain socket
--user stringArray SSH user name (e.g. "john:mypassword")
-v, --version show version
--allow-direct-streamlocal client can use Unix domain socket local forwarding
--allow-direct-tcpip client can use local forwarding and SOCKS proxy
--allow-execute client can use shell/interactive shell
--allow-sftp client can use SFTP and SSHFS
--allow-streamlocal-forward client can use Unix domain socket remote forwarding
--allow-tcpip-forward client can use remote forwarding
-h, --help help for handy-sshd
--host string SSH server host (e.g. 127.0.0.1)
-p, --port uint16 SSH server port (default 2222)
--shell string Shell
--unix-socket string Unix domain socket
--user stringArray SSH user name (e.g. "john:mypassword")
-v, --version show version
```

View file

@ -13,7 +13,7 @@ import (
"strings"
)
var flag struct {
type flagType struct {
//dnsServer string
showsVersion bool
sshHost string
@ -22,20 +22,17 @@ var flag struct {
sshShell string
sshUsers []string
allowTcpipForward bool
allowDirectTcpip bool
allowExecute bool
allowSftp bool
allowTcpipForward bool
allowDirectTcpip bool
allowExecute bool
allowSftp bool
allowStreamlocalForward bool
allowDirectStreamlocal bool
}
var allPermissionFlags = []struct {
type permissionFlagType = struct {
name string
flagPtr *bool
}{
{name: "tcpip-forward", flagPtr: &flag.allowTcpipForward},
{name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip},
{name: "execute", flagPtr: &flag.allowExecute},
{name: "sftp", flagPtr: &flag.allowSftp},
}
type sshUser struct {
@ -45,136 +42,158 @@ type sshUser struct {
func init() {
cobra.OnInitialize()
RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version")
RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)")
RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port")
// NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html)
RootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket")
RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell")
//RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)")
RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`)
// Permission flags
RootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding")
RootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy")
RootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell")
RootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS")
}
var RootCmd = &cobra.Command{
Use: os.Args[0],
Short: "handy-sshd",
Long: "Portable SSH server",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
if flag.showsVersion {
fmt.Println(version.Version)
return nil
}
logger := slog.Default()
func RootCmd() *cobra.Command {
var flag flagType
allPermissionFlags := []permissionFlagType{
{name: "tcpip-forward", flagPtr: &flag.allowTcpipForward},
{name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip},
{name: "execute", flagPtr: &flag.allowExecute},
{name: "sftp", flagPtr: &flag.allowSftp},
{name: "streamlocal-forward", flagPtr: &flag.allowStreamlocalForward},
{name: "direct-streamlocal", flagPtr: &flag.allowDirectStreamlocal},
}
rootCmd := cobra.Command{
Use: os.Args[0],
Short: "handy-sshd",
Long: "Portable SSH server",
SilenceUsage: true,
RunE: func(cmd *cobra.Command, args []string) error {
return rootRunEWithExtra(cmd, args, &flag, allPermissionFlags)
},
}
// Allow all permissions if all permission is not set
{
allPermissionFalse := true
rootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version")
rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)")
rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port")
// NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html)
rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix domain socket")
rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell")
//rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)")
rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`)
// Permission flags
rootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding")
rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy")
rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell")
rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS")
rootCmd.PersistentFlags().BoolVarP(&flag.allowStreamlocalForward, "allow-streamlocal-forward", "", false, "client can use Unix domain socket remote forwarding")
rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectStreamlocal, "allow-direct-streamlocal", "", false, "client can use Unix domain socket local forwarding")
return &rootCmd
}
func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPermissionFlags []permissionFlagType) error {
if flag.showsVersion {
fmt.Fprintln(cmd.OutOrStdout(), version.Version)
return nil
}
logger := slog.Default()
// Allow all permissions if all permission is not set
{
allPermissionFalse := true
for _, permissionFlag := range allPermissionFlags {
allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr
}
if allPermissionFalse {
for _, permissionFlag := range allPermissionFlags {
allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr
}
if allPermissionFalse {
for _, permissionFlag := range allPermissionFlags {
*permissionFlag.flagPtr = true
}
*permissionFlag.flagPtr = true
}
}
}
sshServer := &handy_sshd.Server{
Logger: logger,
AllowTcpipForward: flag.allowTcpipForward,
AllowDirectTcpip: flag.allowDirectTcpip,
AllowExecute: flag.allowExecute,
AllowSftp: flag.allowSftp,
sshServer := &handy_sshd.Server{
Logger: logger,
AllowTcpipForward: flag.allowTcpipForward,
AllowDirectTcpip: flag.allowDirectTcpip,
AllowExecute: flag.allowExecute,
AllowSftp: flag.allowSftp,
AllowStreamlocalForward: flag.allowStreamlocalForward,
AllowDirectStreamlocal: flag.allowDirectStreamlocal,
}
var sshUsers []sshUser
for _, u := range flag.sshUsers {
splits := strings.SplitN(u, ":", 2)
if len(splits) != 2 {
return fmt.Errorf("invalid user format: %s", u)
}
var sshUsers []sshUser
for _, u := range flag.sshUsers {
splits := strings.SplitN(u, ":", 2)
if len(splits) != 2 {
return fmt.Errorf("invalid user format: %s", u)
}
sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]})
}
if len(sshUsers) == 0 {
return fmt.Errorf(`No user specified
sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]})
}
if len(sshUsers) == 0 {
return fmt.Errorf(`No user specified
e.g. --user "john:mypassword"
e.g. --user "john:"`)
}
// (base: https://gist.github.com/jpillora/b480fde82bff51a06238)
sshConfig := &ssh.ServerConfig{
//Define a function to run when a client attempts a password login
PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
for _, user := range sshUsers {
// No auth required
if user.name == metadata.User() && user.password == string(pass) {
return nil, nil
}
}
// (base: https://gist.github.com/jpillora/b480fde82bff51a06238)
sshConfig := &ssh.ServerConfig{
//Define a function to run when a client attempts a password login
PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
for _, user := range sshUsers {
// No auth required
if user.name == metadata.User() && user.password == string(pass) {
return nil, nil
}
return nil, fmt.Errorf("password rejected for %q", metadata.User())
},
NoClientAuth: true,
NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) {
for _, user := range sshUsers {
// No auth required
if user.name == metadata.User() && user.password == "" {
return nil, nil
}
}
return nil, fmt.Errorf("password rejected for %q", metadata.User())
},
NoClientAuth: true,
NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) {
for _, user := range sshUsers {
// No auth required
if user.name == metadata.User() && user.password == "" {
return nil, nil
}
return nil, fmt.Errorf("%s auth required", metadata.User())
},
}
// TODO: specify priv_key by flags
pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem))
}
return nil, fmt.Errorf("%s auth required", metadata.User())
},
}
// TODO: specify priv_key by flags
pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem))
if err != nil {
return err
}
sshConfig.AddHostKey(pri)
var ln net.Listener
if flag.sshUnixSocket == "" {
address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort)))
ln, err = net.Listen("tcp", address)
if err != nil {
return err
}
sshConfig.AddHostKey(pri)
var ln net.Listener
if flag.sshUnixSocket == "" {
address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort)))
ln, err = net.Listen("tcp", address)
if err != nil {
return err
}
logger.Info(fmt.Sprintf("listening on %s...", address))
} else {
ln, err = net.Listen("unix", flag.sshUnixSocket)
if err != nil {
return err
}
logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket))
logger.Info(fmt.Sprintf("listening on %s...", address))
} else {
ln, err = net.Listen("unix", flag.sshUnixSocket)
if err != nil {
return err
}
defer ln.Close()
logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket))
}
defer ln.Close()
showPermissions(logger)
showPermissions(logger, allPermissionFlags)
for {
conn, err := ln.Accept()
if err != nil {
logger.Error("failed to accept TCP connection", "err", err)
continue
}
sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig)
if err != nil {
logger.Info("failed to handshake", "err", err)
conn.Close()
continue
}
logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion()))
go sshServer.HandleGlobalRequests(sshConn, reqs)
go sshServer.HandleChannels(flag.sshShell, chans)
for {
conn, err := ln.Accept()
if err != nil {
logger.Error("failed to accept TCP connection", "err", err)
continue
}
},
sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig)
if err != nil {
logger.Info("failed to handshake", "err", err)
conn.Close()
continue
}
logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion()))
go sshServer.HandleGlobalRequests(sshConn, reqs)
go sshServer.HandleChannels(flag.sshShell, chans)
}
}
func showPermissions(logger *slog.Logger) {
func showPermissions(logger *slog.Logger, allPermissionFlags []permissionFlagType) {
var allowedList []string
var notAllowedList []string
for _, permissionFlag := range allPermissionFlags {

327
cmd/root_test.go Normal file
View file

@ -0,0 +1,327 @@
package cmd
import (
"bytes"
"context"
"github.com/nwtgck/handy-sshd/version"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"net"
"strconv"
"testing"
)
func TestVersion(t *testing.T) {
rootCmd := RootCmd()
rootCmd.SetArgs([]string{"--version"})
var stdoutBuf bytes.Buffer
rootCmd.SetOut(&stdoutBuf)
assert.NoError(t, rootCmd.Execute())
assert.Equal(t, version.Version+"\n", stdoutBuf.String())
}
func TestZeroUsers(t *testing.T) {
rootCmd := RootCmd()
rootCmd.SetArgs([]string{})
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
assert.Error(t, rootCmd.Execute())
assert.Equal(t, `Error: No user specified
e.g. --user "john:mypassword"
e.g. --user "john:"
`, stderrBuf.String())
}
func TestAllPermissionsAllowed(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertRemotePortForwarding(t, client)
assertLocalPortForwarding(t, client)
assertExec(t, client)
assertPtyTerminal(t, client)
assertSftp(t, client)
assertUnixRemotePortForwarding(t, client)
assertUnixLocalPortForwarding(t, client)
}
func TestEmptyPassword(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
}
func TestMultipleUsers(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword1", "--user", "alex:mypassword2"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
for _, user := range []struct {
name string
password string
}{{name: "john", password: "mypassword1"}, {name: "alex", password: "mypassword2"}} {
sshClientConfig := &ssh.ClientConfig{
User: user.name,
Auth: []ssh.AuthMethod{ssh.Password(user.password)},
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
}
}
func TestWrongPassword(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mywrongpassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
_, err := ssh.Dial("tcp", address, sshClientConfig)
assert.Error(t, err)
assert.Equal(t, `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`, err.Error())
}
func TestAllowExecute(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-execute"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertNoRemotePortForwarding(t, client)
assertNoLocalPortForwarding(t, client)
assertExec(t, client)
assertPtyTerminal(t, client)
assertNoSftp(t, client)
assertNoUnixRemotePortForwarding(t, client)
assertNoUnixLocalPortForwarding(t, client)
}
func TestAllowTcpipForward(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-tcpip-forward"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertRemotePortForwarding(t, client)
assertNoLocalPortForwarding(t, client)
assertNoExec(t, client)
assertNoPtyTerminal(t, client)
assertNoSftp(t, client)
assertNoUnixRemotePortForwarding(t, client)
assertNoUnixLocalPortForwarding(t, client)
}
func TestAllowStreamlocalForward(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-streamlocal-forward"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertNoRemotePortForwarding(t, client)
assertNoLocalPortForwarding(t, client)
assertNoExec(t, client)
assertNoPtyTerminal(t, client)
assertNoSftp(t, client)
assertUnixRemotePortForwarding(t, client)
assertNoUnixLocalPortForwarding(t, client)
}
func TestAllowDirectTcpip(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-tcpip"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertNoRemotePortForwarding(t, client)
assertLocalPortForwarding(t, client)
assertNoExec(t, client)
assertNoPtyTerminal(t, client)
assertNoSftp(t, client)
assertNoUnixRemotePortForwarding(t, client)
assertNoUnixLocalPortForwarding(t, client)
}
func TestAllowDirectStreamlocal(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-streamlocal"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertNoRemotePortForwarding(t, client)
assertNoLocalPortForwarding(t, client)
assertNoExec(t, client)
assertNoPtyTerminal(t, client)
assertNoSftp(t, client)
assertNoUnixRemotePortForwarding(t, client)
assertUnixLocalPortForwarding(t, client)
}
func TestAllowSftp(t *testing.T) {
rootCmd := RootCmd()
port := getAvailableTcpPort()
rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-sftp"})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
var stderrBuf bytes.Buffer
rootCmd.SetErr(&stderrBuf)
rootCmd.ExecuteContext(ctx)
}()
waitTCPServer(port)
sshClientConfig := &ssh.ClientConfig{
User: "john",
Auth: []ssh.AuthMethod{ssh.Password("mypassword")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port))
client, err := ssh.Dial("tcp", address, sshClientConfig)
assert.NoError(t, err)
defer client.Close()
assert.NoError(t, err)
assertNoRemotePortForwarding(t, client)
assertNoLocalPortForwarding(t, client)
assertNoExec(t, client)
assertNoPtyTerminal(t, client)
assertNoUnixRemotePortForwarding(t, client)
assertSftp(t, client)
}

283
cmd/test_util_test.go Normal file
View file

@ -0,0 +1,283 @@
package cmd
import (
"bytes"
"github.com/google/uuid"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"io"
"net"
"os"
"os/exec"
"path"
"strconv"
"testing"
"time"
)
func getAvailableTcpPort() int {
ln, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
defer ln.Close()
return ln.Addr().(*net.TCPAddr).Port
}
func waitTCPServer(port int) {
for {
conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)))
if err == nil {
conn.Close()
break
}
}
}
func assertExec(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
whoamiBytes, err := session.Output("whoami")
assert.NoError(t, err)
expectedWhoamiBytes, err := exec.Command("whoami").Output()
assert.NoError(t, err)
assert.Equal(t, string(expectedWhoamiBytes), string(whoamiBytes))
}
func assertNoExec(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
_, err = session.Output("whoami")
assert.Error(t, err)
assert.Equal(t, "ssh: command whoami failed", err.Error())
}
func assertPtyTerminal(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
err = session.RequestPty("xterm", 100, 200, ssh.TerminalModes{})
assert.NoError(t, err)
stdin, err := session.StdinPipe()
assert.NoError(t, err)
_, err = stdin.Write([]byte("echo helloworldviapty\r"))
assert.NoError(t, err)
stdout, err := session.StdoutPipe()
assert.NoError(t, err)
stdoutBytesChan := make(chan []byte)
go func() {
var buff bytes.Buffer
_, err := io.Copy(&buff, stdout)
assert.NoError(t, err)
stdoutBytesChan <- buff.Bytes()
}()
err = session.Shell()
assert.NoError(t, err)
time.Sleep(1 * time.Second)
session.Close()
stdoutBytes := <-stdoutBytesChan
assert.Contains(t, string(stdoutBytes), "helloworldviapty")
}
func assertNoPtyTerminal(t *testing.T, client *ssh.Client) {
session, err := client.NewSession()
assert.NoError(t, err)
defer session.Close()
err = session.RequestPty("xterm", 100, 200, ssh.TerminalModes{})
assert.Error(t, err)
assert.Equal(t, "ssh: pty-req failed", err.Error())
}
func assertLocalPortForwarding(t *testing.T, client *ssh.Client) {
var remoteTcpPort int
acceptedConnChan := make(chan net.Conn)
{
ln, err := net.Listen("tcp", ":0")
assert.NoError(t, err)
remoteTcpPort = ln.Addr().(*net.TCPAddr).Port
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
}()
}
raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort}
conn, err := client.DialTCP("tcp", nil, raddr)
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
}
}
func assertNoLocalPortForwarding(t *testing.T, client *ssh.Client) {
raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}
_, err := client.DialTCP("tcp", nil, raddr)
assert.Error(t, err)
assert.Equal(t, "ssh: rejected: administratively prohibited (direct-tcpip not allowed)", err.Error())
}
func assertUnixLocalPortForwarding(t *testing.T, client *ssh.Client) {
remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String())
acceptedConnChan := make(chan net.Conn)
{
ln, err := net.Listen("unix", remoteUnixSocket)
assert.NoError(t, err)
defer os.Remove(remoteUnixSocket)
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
}()
}
conn, err := client.Dial("unix", remoteUnixSocket)
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
}
}
func assertNoUnixLocalPortForwarding(t *testing.T, client *ssh.Client) {
remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String())
_, err := client.Dial("unix", remoteUnixSocket)
assert.Error(t, err)
assert.Equal(t, "ssh: rejected: administratively prohibited (direct-streamlocal (Unix domain socket) not allowed)", err.Error())
}
func assertRemotePortForwarding(t *testing.T, client *ssh.Client) {
remotePort := getAvailableTcpPort()
ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
assert.NoError(t, err)
defer ln.Close()
acceptedConnChan := make(chan net.Conn)
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
}()
conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort)))
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
}
}
func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) {
_, err := client.Listen("tcp", "127.0.0.1:5678")
assert.Error(t, err)
assert.Equal(t, "ssh: tcpip-forward request denied by peer", err.Error())
}
func assertUnixRemotePortForwarding(t *testing.T, client *ssh.Client) {
remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String())
ln, err := client.ListenUnix(remoteUnixSocket)
assert.NoError(t, err)
defer os.Remove(remoteUnixSocket)
defer ln.Close()
acceptedConnChan := make(chan net.Conn)
go func() {
conn, err := ln.Accept()
assert.NoError(t, err)
acceptedConnChan <- conn
}()
conn, err := net.Dial("unix", remoteUnixSocket)
assert.NoError(t, err)
defer conn.Close()
acceptedConn := <-acceptedConnChan
defer acceptedConn.Close()
{
localToRemote := [3]byte{1, 2, 3}
_, err = conn.Write(localToRemote[:])
assert.NoError(t, err)
var buf [len(localToRemote)]byte
_, err = io.ReadFull(acceptedConn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, localToRemote)
}
{
remoteToLocal := [4]byte{10, 20, 30, 40}
_, err = acceptedConn.Write(remoteToLocal[:])
assert.NoError(t, err)
var buf [len(remoteToLocal)]byte
_, err = io.ReadFull(conn, buf[:])
assert.NoError(t, err)
assert.Equal(t, buf, remoteToLocal)
}
}
func assertNoUnixRemotePortForwarding(t *testing.T, client *ssh.Client) {
remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String())
_, err := client.ListenUnix(remoteUnixSocket)
assert.Error(t, err)
assert.Equal(t, "ssh: streamlocal-forward@openssh.com request denied by peer", err.Error())
}
func assertSftp(t *testing.T, client *ssh.Client) {
sftpClient, err := sftp.NewClient(client)
assert.NoError(t, err)
_, err = sftpClient.Getwd()
assert.NoError(t, err)
}
func assertNoSftp(t *testing.T, client *ssh.Client) {
_, err := sftp.NewClient(client)
assert.Error(t, err)
assert.Equal(t, "ssh: subsystem request failed", err.Error())
}

6
go.mod
View file

@ -4,18 +4,22 @@ go 1.20
require (
github.com/creack/pty v1.1.18
github.com/google/uuid v1.3.0
github.com/mattn/go-shellwords v1.0.12
github.com/pkg/errors v0.9.1
github.com/pkg/sftp v1.13.5
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.8.4
golang.org/x/crypto v0.12.0
golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/testify v1.8.4 // indirect
golang.org/x/sys v0.11.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

4
go.sum
View file

@ -3,6 +3,9 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
@ -40,6 +43,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn
golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View file

@ -6,7 +6,7 @@ import (
)
func main() {
if err := cmd.RootCmd.Execute(); err != nil {
if err := cmd.RootCmd().Execute(); err != nil {
os.Exit(-1)
}
}

139
server.go
View file

@ -30,10 +30,12 @@ type Server struct {
Logger *slog.Logger
// Permissions
AllowTcpipForward bool
AllowDirectTcpip bool
AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution.
AllowSftp bool
AllowTcpipForward bool
AllowDirectTcpip bool
AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution.
AllowSftp bool
AllowStreamlocalForward bool
AllowDirectStreamlocal bool
// TODO: DNS server ?
}
@ -59,6 +61,12 @@ func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) {
break
}
s.handleDirectTcpip(newChannel)
case "direct-streamlocal@openssh.com":
if !s.AllowDirectStreamlocal {
newChannel.Reject(ssh.Prohibited, "direct-streamlocal (Unix domain socket) not allowed")
break
}
s.handleDirectStreamlocal(newChannel)
default:
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
}
@ -114,6 +122,8 @@ func (s *Server) handleSession(shell string, newChannel ssh.NewChannel) {
}
case "subsystem":
s.handleSessionSubSystem(req, connection)
default:
s.Logger.Info("unknown request", "req_type", req.Type)
}
}
}
@ -197,7 +207,7 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) {
SourcePort uint32
}
if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
s.Logger.Info("failed to parse message", "err", err)
s.Logger.Info("failed to parse direct-tcpip message", "err", err)
return
}
channel, reqs, err := newChannel.Accept()
@ -227,6 +237,44 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) {
return
}
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L52
func (s *Server) handleDirectStreamlocal(newChannel ssh.NewChannel) {
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L237
var msg struct {
SocketPath string
Reserved0 string
Reserved1 uint32
}
if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
s.Logger.Info("failed to parse direct-streamlocal message", "err", err)
return
}
channel, reqs, err := newChannel.Accept()
if err != nil {
s.Logger.Info("failed to accept", "err", err)
return
}
go ssh.DiscardRequests(reqs)
conn, err := net.Dial("unix", msg.SocketPath)
if err != nil {
s.Logger.Info("failed to dial", "err", err)
channel.Close()
return
}
var closeOnce sync.Once
closer := func() {
channel.Close()
conn.Close()
}
go func() {
io.Copy(channel, conn)
closeOnce.Do(closer)
}()
io.Copy(conn, channel)
closeOnce.Do(closer)
return
}
// =======================
// parseDims extracts terminal dimensions (width x height) from the provided buffer.
@ -266,7 +314,20 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
req.Reply(false, nil)
break
}
s.handleTcpipForward(sshConn, req)
go func() {
s.handleTcpipForward(sshConn, req)
}()
case "streamlocal-forward@openssh.com":
if !s.AllowStreamlocalForward {
s.Logger.Info("streamlocal-forward not allowed")
req.Reply(false, nil)
break
}
go func() {
s.handleStreamlocalForward(sshConn, req)
}()
// TODO: support cancel-tcpip-forward
// TODO: support cancel-streamlocal-forward@openssh.com
default:
// discard
if req.WantReply {
@ -292,6 +353,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
req.Reply(false, nil)
return
}
req.Reply(true, nil)
go func() {
sshConn.Wait()
ln.Close()
@ -300,6 +362,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
for {
conn, err := ln.Accept()
if err != nil {
s.Logger.Info("failed to accept", "err", err)
return
}
var replyMsg struct {
@ -310,6 +373,14 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
}
replyMsg.Addr = msg.Addr
replyMsg.Port = msg.Port
originatorAddr, originatorPortStr, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
originatorPort, _ := strconv.Atoi(originatorPortStr)
replyMsg.OriginatorAddr = originatorAddr
replyMsg.OriginatorPort = uint32(originatorPort)
} else {
s.Logger.Error("failed to split remote address", "remote_address", conn.RemoteAddr())
}
go func() {
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&replyMsg))
@ -332,3 +403,59 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
}()
}
}
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34
func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) {
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272
var msg struct {
SocketPath string
}
if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
req.Reply(false, nil)
return
}
ln, err := net.Listen("unix", msg.SocketPath)
if err != nil {
req.Reply(false, nil)
return
}
req.Reply(true, nil)
go func() {
sshConn.Wait()
ln.Close()
s.Logger.Info("connection closed", "address", ln.Addr().String())
}()
for {
conn, err := ln.Accept()
if err != nil {
s.Logger.Info("failed to accept", "err", err)
return
}
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L255
var replyMsg struct {
SocketPath string
Reserved string
}
replyMsg.SocketPath = msg.SocketPath
go func() {
channel, reqs, err := sshConn.OpenChannel("forwarded-streamlocal@openssh.com", ssh.Marshal(&replyMsg))
if err != nil {
req.Reply(false, nil)
conn.Close()
return
}
go ssh.DiscardRequests(reqs)
go func() {
io.Copy(channel, conn)
conn.Close()
channel.Close()
}()
go func() {
io.Copy(conn, channel)
conn.Close()
channel.Close()
}()
}()
}
}

View file

@ -1,3 +1,3 @@
package version
const Version = "0.2.1"
const Version = "0.3.0"