diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 883a903..521729e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,12 +3,11 @@ name: CI on: [push] jobs: - build_multi_platform: + test: runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 - - name: Set up Go 1.x - uses: actions/setup-go@v4 + - uses: actions/setup-go@v4 with: go-version: "1.20" - name: Build for multi-platform @@ -29,3 +28,5 @@ jobs: # Build CGO_ENABLED=0 go build -o "${BUILD_PATH}/handy-sshd${EXTENSION}" main/main.go done + - name: Test + run: go test -v ./... diff --git a/CHANGELOG.md b/CHANGELOG.md index 70c89c8..fb65850 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) ## [Unreleased] +## [0.3.0] - 2023-08-11 +### Added +* Support Unix domain socket local port forwarding +* Support Unix domain socket remote port forwarding + +### Fixed +* Handle multiple global requests simultaneously + ## [0.2.1] - 2023-08-09 ### Changed * Update dependencies @@ -20,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) ### Added * Initial release -[Unreleased]: https://github.com/nwtgck/handy-sshd/compare/v0.2.1...HEAD +[Unreleased]: https://github.com/nwtgck/handy-sshd/compare/v0.3.0...HEAD +[0.3.0]: https://github.com/nwtgck/handy-sshd/compare/v0.2.1...v0.3.0 [0.2.1]: https://github.com/nwtgck/handy-sshd/compare/v0.2.0...v0.2.1 [0.2.0]: https://github.com/nwtgck/handy-sshd/compare/v0.1.0...v0.2.0 diff --git a/README.md b/README.md index 16017d8..a648d72 100644 --- a/README.md +++ b/README.md @@ -41,20 +41,30 @@ handy-sshd --unix-socket /tmp/my-unix-socket --user "john:" ``` ## Permissions -**All permissions are allowed when nothing is specified.** There are some permissions. +There are several permissions: +* --allow-direct-streamlocal * --allow-direct-tcpip * --allow-execute * --allow-sftp +* --allow-streamlocal-forward * --allow-tcpip-forward -Specifying `--allow-direct-tcpip` and `--allow-execute` for example allows only them. -The log shows "allowed: " and "NOT allowed: " permissions as follows. +**All permissions are allowed when nothing is specified.** The log shows "allowed: " and "NOT allowed: " permissions as follows: + +```console +$ handy-sshd --user "john:" +2023/08/11 11:40:44 INFO listening on :2222... +2023/08/11 11:40:44 INFO allowed: "tcpip-forward", "direct-tcpip", "execute", "sftp", "streamlocal-forward", "direct-streamlocal" +2023/08/11 11:40:44 INFO NOT allowed: none +``` + +For example, specifying `--allow-direct-tcpip` and `--allow-execute` allows only them: ```console $ handy-sshd --user "john:" --allow-direct-tcpip --allow-execute -2023/08/09 20:49:35 INFO listening on :2222... -2023/08/09 20:49:35 INFO allowed: "direct-tcpip", "execute" -2023/08/09 20:49:35 INFO NOT allowed: "tcpip-forward", "sftp" +2023/08/11 11:41:03 INFO listening on :2222... +2023/08/11 11:41:03 INFO allowed: "direct-tcpip", "execute" +2023/08/11 11:41:03 INFO NOT allowed: "tcpip-forward", "sftp", "streamlocal-forward", "direct-streamlocal" ``` ## --help @@ -66,15 +76,17 @@ Usage: handy-sshd [flags] Flags: - --allow-direct-tcpip client can use local forwarding and SOCKS proxy - --allow-execute client can use shell/interactive shell - --allow-sftp client can use SFTP and SSHFS - --allow-tcpip-forward client can use remote forwarding - -h, --help help for handy-sshd - --host string SSH server host (e.g. 127.0.0.1) - -p, --port uint16 SSH server port (default 2222) - --shell string Shell - --unix-socket string Unix-domain socket - --user stringArray SSH user name (e.g. "john:mypassword") - -v, --version show version + --allow-direct-streamlocal client can use Unix domain socket local forwarding + --allow-direct-tcpip client can use local forwarding and SOCKS proxy + --allow-execute client can use shell/interactive shell + --allow-sftp client can use SFTP and SSHFS + --allow-streamlocal-forward client can use Unix domain socket remote forwarding + --allow-tcpip-forward client can use remote forwarding + -h, --help help for handy-sshd + --host string SSH server host (e.g. 127.0.0.1) + -p, --port uint16 SSH server port (default 2222) + --shell string Shell + --unix-socket string Unix domain socket + --user stringArray SSH user name (e.g. "john:mypassword") + -v, --version show version ``` diff --git a/cmd/root.go b/cmd/root.go index 9581e71..a2d6b66 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,7 +13,7 @@ import ( "strings" ) -var flag struct { +type flagType struct { //dnsServer string showsVersion bool sshHost string @@ -22,20 +22,17 @@ var flag struct { sshShell string sshUsers []string - allowTcpipForward bool - allowDirectTcpip bool - allowExecute bool - allowSftp bool + allowTcpipForward bool + allowDirectTcpip bool + allowExecute bool + allowSftp bool + allowStreamlocalForward bool + allowDirectStreamlocal bool } -var allPermissionFlags = []struct { +type permissionFlagType = struct { name string flagPtr *bool -}{ - {name: "tcpip-forward", flagPtr: &flag.allowTcpipForward}, - {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, - {name: "execute", flagPtr: &flag.allowExecute}, - {name: "sftp", flagPtr: &flag.allowSftp}, } type sshUser struct { @@ -45,136 +42,158 @@ type sshUser struct { func init() { cobra.OnInitialize() - RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") - RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") - RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") - // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) - RootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") - RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") - //RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") - RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) - - // Permission flags - RootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding") - RootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") - RootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") - RootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") } -var RootCmd = &cobra.Command{ - Use: os.Args[0], - Short: "handy-sshd", - Long: "Portable SSH server", - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - if flag.showsVersion { - fmt.Println(version.Version) - return nil - } - logger := slog.Default() +func RootCmd() *cobra.Command { + var flag flagType + allPermissionFlags := []permissionFlagType{ + {name: "tcpip-forward", flagPtr: &flag.allowTcpipForward}, + {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, + {name: "execute", flagPtr: &flag.allowExecute}, + {name: "sftp", flagPtr: &flag.allowSftp}, + {name: "streamlocal-forward", flagPtr: &flag.allowStreamlocalForward}, + {name: "direct-streamlocal", flagPtr: &flag.allowDirectStreamlocal}, + } + rootCmd := cobra.Command{ + Use: os.Args[0], + Short: "handy-sshd", + Long: "Portable SSH server", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return rootRunEWithExtra(cmd, args, &flag, allPermissionFlags) + }, + } - // Allow all permissions if all permission is not set - { - allPermissionFalse := true + rootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") + rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") + rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") + // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) + rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix domain socket") + rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") + //rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") + rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) + + // Permission flags + rootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding") + rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") + rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") + rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") + rootCmd.PersistentFlags().BoolVarP(&flag.allowStreamlocalForward, "allow-streamlocal-forward", "", false, "client can use Unix domain socket remote forwarding") + rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectStreamlocal, "allow-direct-streamlocal", "", false, "client can use Unix domain socket local forwarding") + + return &rootCmd +} + +func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPermissionFlags []permissionFlagType) error { + if flag.showsVersion { + fmt.Fprintln(cmd.OutOrStdout(), version.Version) + return nil + } + logger := slog.Default() + + // Allow all permissions if all permission is not set + { + allPermissionFalse := true + for _, permissionFlag := range allPermissionFlags { + allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr + } + if allPermissionFalse { for _, permissionFlag := range allPermissionFlags { - allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr - } - if allPermissionFalse { - for _, permissionFlag := range allPermissionFlags { - *permissionFlag.flagPtr = true - } + *permissionFlag.flagPtr = true } } + } - sshServer := &handy_sshd.Server{ - Logger: logger, - AllowTcpipForward: flag.allowTcpipForward, - AllowDirectTcpip: flag.allowDirectTcpip, - AllowExecute: flag.allowExecute, - AllowSftp: flag.allowSftp, + sshServer := &handy_sshd.Server{ + Logger: logger, + AllowTcpipForward: flag.allowTcpipForward, + AllowDirectTcpip: flag.allowDirectTcpip, + AllowExecute: flag.allowExecute, + AllowSftp: flag.allowSftp, + AllowStreamlocalForward: flag.allowStreamlocalForward, + AllowDirectStreamlocal: flag.allowDirectStreamlocal, + } + var sshUsers []sshUser + for _, u := range flag.sshUsers { + splits := strings.SplitN(u, ":", 2) + if len(splits) != 2 { + return fmt.Errorf("invalid user format: %s", u) } - var sshUsers []sshUser - for _, u := range flag.sshUsers { - splits := strings.SplitN(u, ":", 2) - if len(splits) != 2 { - return fmt.Errorf("invalid user format: %s", u) - } - sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]}) - } - if len(sshUsers) == 0 { - return fmt.Errorf(`No user specified + sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]}) + } + if len(sshUsers) == 0 { + return fmt.Errorf(`No user specified e.g. --user "john:mypassword" e.g. --user "john:"`) - } - // (base: https://gist.github.com/jpillora/b480fde82bff51a06238) - sshConfig := &ssh.ServerConfig{ - //Define a function to run when a client attempts a password login - PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { - for _, user := range sshUsers { - // No auth required - if user.name == metadata.User() && user.password == string(pass) { - return nil, nil - } + } + // (base: https://gist.github.com/jpillora/b480fde82bff51a06238) + sshConfig := &ssh.ServerConfig{ + //Define a function to run when a client attempts a password login + PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + for _, user := range sshUsers { + // No auth required + if user.name == metadata.User() && user.password == string(pass) { + return nil, nil } - return nil, fmt.Errorf("password rejected for %q", metadata.User()) - }, - NoClientAuth: true, - NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) { - for _, user := range sshUsers { - // No auth required - if user.name == metadata.User() && user.password == "" { - return nil, nil - } + } + return nil, fmt.Errorf("password rejected for %q", metadata.User()) + }, + NoClientAuth: true, + NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) { + for _, user := range sshUsers { + // No auth required + if user.name == metadata.User() && user.password == "" { + return nil, nil } - return nil, fmt.Errorf("%s auth required", metadata.User()) - }, - } - // TODO: specify priv_key by flags - pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem)) + } + return nil, fmt.Errorf("%s auth required", metadata.User()) + }, + } + // TODO: specify priv_key by flags + pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem)) + if err != nil { + return err + } + sshConfig.AddHostKey(pri) + + var ln net.Listener + if flag.sshUnixSocket == "" { + address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) + ln, err = net.Listen("tcp", address) if err != nil { return err } - sshConfig.AddHostKey(pri) - - var ln net.Listener - if flag.sshUnixSocket == "" { - address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) - ln, err = net.Listen("tcp", address) - if err != nil { - return err - } - logger.Info(fmt.Sprintf("listening on %s...", address)) - } else { - ln, err = net.Listen("unix", flag.sshUnixSocket) - if err != nil { - return err - } - logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket)) + logger.Info(fmt.Sprintf("listening on %s...", address)) + } else { + ln, err = net.Listen("unix", flag.sshUnixSocket) + if err != nil { + return err } - defer ln.Close() + logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket)) + } + defer ln.Close() - showPermissions(logger) + showPermissions(logger, allPermissionFlags) - for { - conn, err := ln.Accept() - if err != nil { - logger.Error("failed to accept TCP connection", "err", err) - continue - } - sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig) - if err != nil { - logger.Info("failed to handshake", "err", err) - conn.Close() - continue - } - logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion())) - go sshServer.HandleGlobalRequests(sshConn, reqs) - go sshServer.HandleChannels(flag.sshShell, chans) + for { + conn, err := ln.Accept() + if err != nil { + logger.Error("failed to accept TCP connection", "err", err) + continue } - }, + sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig) + if err != nil { + logger.Info("failed to handshake", "err", err) + conn.Close() + continue + } + logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion())) + go sshServer.HandleGlobalRequests(sshConn, reqs) + go sshServer.HandleChannels(flag.sshShell, chans) + } } -func showPermissions(logger *slog.Logger) { +func showPermissions(logger *slog.Logger, allPermissionFlags []permissionFlagType) { var allowedList []string var notAllowedList []string for _, permissionFlag := range allPermissionFlags { diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..1960213 --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,327 @@ +package cmd + +import ( + "bytes" + "context" + "github.com/nwtgck/handy-sshd/version" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" + "net" + "strconv" + "testing" +) + +func TestVersion(t *testing.T) { + rootCmd := RootCmd() + rootCmd.SetArgs([]string{"--version"}) + var stdoutBuf bytes.Buffer + rootCmd.SetOut(&stdoutBuf) + assert.NoError(t, rootCmd.Execute()) + assert.Equal(t, version.Version+"\n", stdoutBuf.String()) +} + +func TestZeroUsers(t *testing.T) { + rootCmd := RootCmd() + rootCmd.SetArgs([]string{}) + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + assert.Error(t, rootCmd.Execute()) + assert.Equal(t, `Error: No user specified +e.g. --user "john:mypassword" +e.g. --user "john:" +`, stderrBuf.String()) +} + +func TestAllPermissionsAllowed(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertRemotePortForwarding(t, client) + assertLocalPortForwarding(t, client) + assertExec(t, client) + assertPtyTerminal(t, client) + assertSftp(t, client) + assertUnixRemotePortForwarding(t, client) + assertUnixLocalPortForwarding(t, client) +} + +func TestEmptyPassword(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() +} + +func TestMultipleUsers(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword1", "--user", "alex:mypassword2"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + for _, user := range []struct { + name string + password string + }{{name: "john", password: "mypassword1"}, {name: "alex", password: "mypassword2"}} { + sshClientConfig := &ssh.ClientConfig{ + User: user.name, + Auth: []ssh.AuthMethod{ssh.Password(user.password)}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + } +} + +func TestWrongPassword(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mywrongpassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + _, err := ssh.Dial("tcp", address, sshClientConfig) + assert.Error(t, err) + assert.Equal(t, `ssh: handshake failed: ssh: unable to authenticate, attempted methods [none password], no supported methods remain`, err.Error()) +} + +func TestAllowExecute(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-execute"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertExec(t, client) + assertPtyTerminal(t, client) + assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowTcpipForward(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-tcpip-forward"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowStreamlocalForward(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-streamlocal-forward"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertUnixRemotePortForwarding(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowDirectTcpip(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-tcpip"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertNoUnixLocalPortForwarding(t, client) +} + +func TestAllowDirectStreamlocal(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-direct-streamlocal"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoSftp(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertUnixLocalPortForwarding(t, client) +} + +func TestAllowSftp(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword", "--allow-sftp"}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + rootCmd.ExecuteContext(ctx) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + client, err := ssh.Dial("tcp", address, sshClientConfig) + assert.NoError(t, err) + defer client.Close() + assert.NoError(t, err) + assertNoRemotePortForwarding(t, client) + assertNoLocalPortForwarding(t, client) + assertNoExec(t, client) + assertNoPtyTerminal(t, client) + assertNoUnixRemotePortForwarding(t, client) + assertSftp(t, client) +} diff --git a/cmd/test_util_test.go b/cmd/test_util_test.go new file mode 100644 index 0000000..9d7e5ab --- /dev/null +++ b/cmd/test_util_test.go @@ -0,0 +1,283 @@ +package cmd + +import ( + "bytes" + "github.com/google/uuid" + "github.com/pkg/sftp" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" + "io" + "net" + "os" + "os/exec" + "path" + "strconv" + "testing" + "time" +) + +func getAvailableTcpPort() int { + ln, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + defer ln.Close() + return ln.Addr().(*net.TCPAddr).Port +} + +func waitTCPServer(port int) { + for { + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err == nil { + conn.Close() + break + } + } +} + +func assertExec(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + whoamiBytes, err := session.Output("whoami") + assert.NoError(t, err) + expectedWhoamiBytes, err := exec.Command("whoami").Output() + assert.NoError(t, err) + assert.Equal(t, string(expectedWhoamiBytes), string(whoamiBytes)) +} + +func assertNoExec(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + _, err = session.Output("whoami") + assert.Error(t, err) + assert.Equal(t, "ssh: command whoami failed", err.Error()) +} + +func assertPtyTerminal(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + + err = session.RequestPty("xterm", 100, 200, ssh.TerminalModes{}) + assert.NoError(t, err) + stdin, err := session.StdinPipe() + assert.NoError(t, err) + _, err = stdin.Write([]byte("echo helloworldviapty\r")) + assert.NoError(t, err) + stdout, err := session.StdoutPipe() + assert.NoError(t, err) + stdoutBytesChan := make(chan []byte) + go func() { + var buff bytes.Buffer + _, err := io.Copy(&buff, stdout) + assert.NoError(t, err) + stdoutBytesChan <- buff.Bytes() + }() + err = session.Shell() + assert.NoError(t, err) + time.Sleep(1 * time.Second) + session.Close() + stdoutBytes := <-stdoutBytesChan + assert.Contains(t, string(stdoutBytes), "helloworldviapty") +} + +func assertNoPtyTerminal(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + err = session.RequestPty("xterm", 100, 200, ssh.TerminalModes{}) + assert.Error(t, err) + assert.Equal(t, "ssh: pty-req failed", err.Error()) +} + +func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { + var remoteTcpPort int + acceptedConnChan := make(chan net.Conn) + { + ln, err := net.Listen("tcp", ":0") + assert.NoError(t, err) + remoteTcpPort = ln.Addr().(*net.TCPAddr).Port + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + } + raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort} + conn, err := client.DialTCP("tcp", nil, raddr) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoLocalPortForwarding(t *testing.T, client *ssh.Client) { + raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + _, err := client.DialTCP("tcp", nil, raddr) + assert.Error(t, err) + assert.Equal(t, "ssh: rejected: administratively prohibited (direct-tcpip not allowed)", err.Error()) +} + +func assertUnixLocalPortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + acceptedConnChan := make(chan net.Conn) + { + ln, err := net.Listen("unix", remoteUnixSocket) + assert.NoError(t, err) + defer os.Remove(remoteUnixSocket) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + } + conn, err := client.Dial("unix", remoteUnixSocket) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoUnixLocalPortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + _, err := client.Dial("unix", remoteUnixSocket) + assert.Error(t, err) + assert.Equal(t, "ssh: rejected: administratively prohibited (direct-streamlocal (Unix domain socket) not allowed)", err.Error()) +} + +func assertRemotePortForwarding(t *testing.T, client *ssh.Client) { + remotePort := getAvailableTcpPort() + ln, err := client.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) + assert.NoError(t, err) + defer ln.Close() + acceptedConnChan := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(remotePort))) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoRemotePortForwarding(t *testing.T, client *ssh.Client) { + _, err := client.Listen("tcp", "127.0.0.1:5678") + assert.Error(t, err) + assert.Equal(t, "ssh: tcpip-forward request denied by peer", err.Error()) +} + +func assertUnixRemotePortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + ln, err := client.ListenUnix(remoteUnixSocket) + assert.NoError(t, err) + defer os.Remove(remoteUnixSocket) + defer ln.Close() + acceptedConnChan := make(chan net.Conn) + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + conn, err := net.Dial("unix", remoteUnixSocket) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func assertNoUnixRemotePortForwarding(t *testing.T, client *ssh.Client) { + remoteUnixSocket := path.Join(os.TempDir(), "test-unix-socket-"+uuid.New().String()) + _, err := client.ListenUnix(remoteUnixSocket) + assert.Error(t, err) + assert.Equal(t, "ssh: streamlocal-forward@openssh.com request denied by peer", err.Error()) +} + +func assertSftp(t *testing.T, client *ssh.Client) { + sftpClient, err := sftp.NewClient(client) + assert.NoError(t, err) + _, err = sftpClient.Getwd() + assert.NoError(t, err) +} + +func assertNoSftp(t *testing.T, client *ssh.Client) { + _, err := sftp.NewClient(client) + assert.Error(t, err) + assert.Equal(t, "ssh: subsystem request failed", err.Error()) +} diff --git a/go.mod b/go.mod index a261369..de21a8d 100644 --- a/go.mod +++ b/go.mod @@ -4,18 +4,22 @@ go 1.20 require ( github.com/creack/pty v1.1.18 + github.com/google/uuid v1.3.0 github.com/mattn/go-shellwords v1.0.12 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.5 github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.12.0 golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.8.4 // indirect golang.org/x/sys v0.11.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d8f05e0..c0e1b80 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,9 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -40,6 +43,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main/main.go b/main/main.go index 6ac23fe..8730087 100644 --- a/main/main.go +++ b/main/main.go @@ -6,7 +6,7 @@ import ( ) func main() { - if err := cmd.RootCmd.Execute(); err != nil { + if err := cmd.RootCmd().Execute(); err != nil { os.Exit(-1) } } diff --git a/server.go b/server.go index f8f514b..7e1dedd 100644 --- a/server.go +++ b/server.go @@ -30,10 +30,12 @@ type Server struct { Logger *slog.Logger // Permissions - AllowTcpipForward bool - AllowDirectTcpip bool - AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. - AllowSftp bool + AllowTcpipForward bool + AllowDirectTcpip bool + AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. + AllowSftp bool + AllowStreamlocalForward bool + AllowDirectStreamlocal bool // TODO: DNS server ? } @@ -59,6 +61,12 @@ func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) { break } s.handleDirectTcpip(newChannel) + case "direct-streamlocal@openssh.com": + if !s.AllowDirectStreamlocal { + newChannel.Reject(ssh.Prohibited, "direct-streamlocal (Unix domain socket) not allowed") + break + } + s.handleDirectStreamlocal(newChannel) default: newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType())) } @@ -114,6 +122,8 @@ func (s *Server) handleSession(shell string, newChannel ssh.NewChannel) { } case "subsystem": s.handleSessionSubSystem(req, connection) + default: + s.Logger.Info("unknown request", "req_type", req.Type) } } } @@ -197,7 +207,7 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) { SourcePort uint32 } if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil { - s.Logger.Info("failed to parse message", "err", err) + s.Logger.Info("failed to parse direct-tcpip message", "err", err) return } channel, reqs, err := newChannel.Accept() @@ -227,6 +237,44 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) { return } +// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L52 +func (s *Server) handleDirectStreamlocal(newChannel ssh.NewChannel) { + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L237 + var msg struct { + SocketPath string + Reserved0 string + Reserved1 uint32 + } + if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil { + s.Logger.Info("failed to parse direct-streamlocal message", "err", err) + return + } + channel, reqs, err := newChannel.Accept() + if err != nil { + s.Logger.Info("failed to accept", "err", err) + return + } + go ssh.DiscardRequests(reqs) + conn, err := net.Dial("unix", msg.SocketPath) + if err != nil { + s.Logger.Info("failed to dial", "err", err) + channel.Close() + return + } + var closeOnce sync.Once + closer := func() { + channel.Close() + conn.Close() + } + go func() { + io.Copy(channel, conn) + closeOnce.Do(closer) + }() + io.Copy(conn, channel) + closeOnce.Do(closer) + return +} + // ======================= // parseDims extracts terminal dimensions (width x height) from the provided buffer. @@ -266,7 +314,20 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh. req.Reply(false, nil) break } - s.handleTcpipForward(sshConn, req) + go func() { + s.handleTcpipForward(sshConn, req) + }() + case "streamlocal-forward@openssh.com": + if !s.AllowStreamlocalForward { + s.Logger.Info("streamlocal-forward not allowed") + req.Reply(false, nil) + break + } + go func() { + s.handleStreamlocalForward(sshConn, req) + }() + // TODO: support cancel-tcpip-forward + // TODO: support cancel-streamlocal-forward@openssh.com default: // discard if req.WantReply { @@ -292,6 +353,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { req.Reply(false, nil) return } + req.Reply(true, nil) go func() { sshConn.Wait() ln.Close() @@ -300,6 +362,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { for { conn, err := ln.Accept() if err != nil { + s.Logger.Info("failed to accept", "err", err) return } var replyMsg struct { @@ -310,6 +373,14 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { } replyMsg.Addr = msg.Addr replyMsg.Port = msg.Port + originatorAddr, originatorPortStr, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err == nil { + originatorPort, _ := strconv.Atoi(originatorPortStr) + replyMsg.OriginatorAddr = originatorAddr + replyMsg.OriginatorPort = uint32(originatorPort) + } else { + s.Logger.Error("failed to split remote address", "remote_address", conn.RemoteAddr()) + } go func() { channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&replyMsg)) @@ -332,3 +403,59 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { }() } } + +// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34 +func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) { + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272 + var msg struct { + SocketPath string + } + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + req.Reply(false, nil) + return + } + ln, err := net.Listen("unix", msg.SocketPath) + if err != nil { + req.Reply(false, nil) + return + } + req.Reply(true, nil) + go func() { + sshConn.Wait() + ln.Close() + s.Logger.Info("connection closed", "address", ln.Addr().String()) + }() + for { + conn, err := ln.Accept() + if err != nil { + s.Logger.Info("failed to accept", "err", err) + return + } + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L255 + var replyMsg struct { + SocketPath string + Reserved string + } + replyMsg.SocketPath = msg.SocketPath + + go func() { + channel, reqs, err := sshConn.OpenChannel("forwarded-streamlocal@openssh.com", ssh.Marshal(&replyMsg)) + if err != nil { + req.Reply(false, nil) + conn.Close() + return + } + go ssh.DiscardRequests(reqs) + go func() { + io.Copy(channel, conn) + conn.Close() + channel.Close() + }() + go func() { + io.Copy(conn, channel) + conn.Close() + channel.Close() + }() + }() + } +} diff --git a/version/version.go b/version/version.go index fafddf1..ee59a7c 100644 --- a/version/version.go +++ b/version/version.go @@ -1,3 +1,3 @@ package version -const Version = "0.2.1" +const Version = "0.3.0"