From aac75f81a6ba5b1c3a60678a7cbeb3d3ff1dc12e Mon Sep 17 00:00:00 2001 From: Ryo Ota Date: Thu, 10 Aug 2023 22:14:43 +0900 Subject: [PATCH] test RootCmd --- cmd/root.go | 243 +++++++++++++++++++++++++---------------------- cmd/root_test.go | 129 +++++++++++++++++++++++++ go.mod | 5 +- go.sum | 2 + main/main.go | 2 +- server.go | 1 + 6 files changed, 264 insertions(+), 118 deletions(-) create mode 100644 cmd/root_test.go diff --git a/cmd/root.go b/cmd/root.go index 9581e71..1a89961 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -13,7 +13,7 @@ import ( "strings" ) -var flag struct { +type flagType struct { //dnsServer string showsVersion bool sshHost string @@ -28,14 +28,9 @@ var flag struct { allowSftp bool } -var allPermissionFlags = []struct { +type permissionFlagType = struct { name string flagPtr *bool -}{ - {name: "tcpip-forward", flagPtr: &flag.allowTcpipForward}, - {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, - {name: "execute", flagPtr: &flag.allowExecute}, - {name: "sftp", flagPtr: &flag.allowSftp}, } type sshUser struct { @@ -45,136 +40,152 @@ type sshUser struct { func init() { cobra.OnInitialize() - RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") - RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") - RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") - // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) - RootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") - RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") - //RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") - RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) - - // Permission flags - RootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding") - RootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") - RootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") - RootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") } -var RootCmd = &cobra.Command{ - Use: os.Args[0], - Short: "handy-sshd", - Long: "Portable SSH server", - SilenceUsage: true, - RunE: func(cmd *cobra.Command, args []string) error { - if flag.showsVersion { - fmt.Println(version.Version) - return nil - } - logger := slog.Default() +func RootCmd() *cobra.Command { + var flag flagType + allPermissionFlags := []permissionFlagType{ + {name: "tcpip-forward", flagPtr: &flag.allowTcpipForward}, + {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, + {name: "execute", flagPtr: &flag.allowExecute}, + {name: "sftp", flagPtr: &flag.allowSftp}, + } + rootCmd := cobra.Command{ + Use: os.Args[0], + Short: "handy-sshd", + Long: "Portable SSH server", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { + return rootRunEWithExtra(cmd, args, &flag, allPermissionFlags) + }, + } - // Allow all permissions if all permission is not set - { - allPermissionFalse := true + rootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") + rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") + rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") + // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) + rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") + rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") + //rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") + rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) + + // Permission flags + rootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding") + rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") + rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") + rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") + + return &rootCmd +} + +func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPermissionFlags []permissionFlagType) error { + if flag.showsVersion { + fmt.Fprintln(cmd.OutOrStdout(), version.Version) + return nil + } + logger := slog.Default() + + // Allow all permissions if all permission is not set + { + allPermissionFalse := true + for _, permissionFlag := range allPermissionFlags { + allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr + } + if allPermissionFalse { for _, permissionFlag := range allPermissionFlags { - allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr - } - if allPermissionFalse { - for _, permissionFlag := range allPermissionFlags { - *permissionFlag.flagPtr = true - } + *permissionFlag.flagPtr = true } } + } - sshServer := &handy_sshd.Server{ - Logger: logger, - AllowTcpipForward: flag.allowTcpipForward, - AllowDirectTcpip: flag.allowDirectTcpip, - AllowExecute: flag.allowExecute, - AllowSftp: flag.allowSftp, + sshServer := &handy_sshd.Server{ + Logger: logger, + AllowTcpipForward: flag.allowTcpipForward, + AllowDirectTcpip: flag.allowDirectTcpip, + AllowExecute: flag.allowExecute, + AllowSftp: flag.allowSftp, + } + var sshUsers []sshUser + for _, u := range flag.sshUsers { + splits := strings.SplitN(u, ":", 2) + if len(splits) != 2 { + return fmt.Errorf("invalid user format: %s", u) } - var sshUsers []sshUser - for _, u := range flag.sshUsers { - splits := strings.SplitN(u, ":", 2) - if len(splits) != 2 { - return fmt.Errorf("invalid user format: %s", u) - } - sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]}) - } - if len(sshUsers) == 0 { - return fmt.Errorf(`No user specified + sshUsers = append(sshUsers, sshUser{name: splits[0], password: splits[1]}) + } + if len(sshUsers) == 0 { + return fmt.Errorf(`No user specified e.g. --user "john:mypassword" e.g. --user "john:"`) - } - // (base: https://gist.github.com/jpillora/b480fde82bff51a06238) - sshConfig := &ssh.ServerConfig{ - //Define a function to run when a client attempts a password login - PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { - for _, user := range sshUsers { - // No auth required - if user.name == metadata.User() && user.password == string(pass) { - return nil, nil - } + } + // (base: https://gist.github.com/jpillora/b480fde82bff51a06238) + sshConfig := &ssh.ServerConfig{ + //Define a function to run when a client attempts a password login + PasswordCallback: func(metadata ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + for _, user := range sshUsers { + // No auth required + if user.name == metadata.User() && user.password == string(pass) { + return nil, nil } - return nil, fmt.Errorf("password rejected for %q", metadata.User()) - }, - NoClientAuth: true, - NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) { - for _, user := range sshUsers { - // No auth required - if user.name == metadata.User() && user.password == "" { - return nil, nil - } + } + return nil, fmt.Errorf("password rejected for %q", metadata.User()) + }, + NoClientAuth: true, + NoClientAuthCallback: func(metadata ssh.ConnMetadata) (*ssh.Permissions, error) { + for _, user := range sshUsers { + // No auth required + if user.name == metadata.User() && user.password == "" { + return nil, nil } - return nil, fmt.Errorf("%s auth required", metadata.User()) - }, - } - // TODO: specify priv_key by flags - pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem)) + } + return nil, fmt.Errorf("%s auth required", metadata.User()) + }, + } + // TODO: specify priv_key by flags + pri, err := ssh.ParsePrivateKey([]byte(defaultHostKeyPem)) + if err != nil { + return err + } + sshConfig.AddHostKey(pri) + + var ln net.Listener + if flag.sshUnixSocket == "" { + address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) + ln, err = net.Listen("tcp", address) if err != nil { return err } - sshConfig.AddHostKey(pri) - - var ln net.Listener - if flag.sshUnixSocket == "" { - address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) - ln, err = net.Listen("tcp", address) - if err != nil { - return err - } - logger.Info(fmt.Sprintf("listening on %s...", address)) - } else { - ln, err = net.Listen("unix", flag.sshUnixSocket) - if err != nil { - return err - } - logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket)) + logger.Info(fmt.Sprintf("listening on %s...", address)) + } else { + ln, err = net.Listen("unix", flag.sshUnixSocket) + if err != nil { + return err } - defer ln.Close() + logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket)) + } + defer ln.Close() - showPermissions(logger) + showPermissions(logger, allPermissionFlags) - for { - conn, err := ln.Accept() - if err != nil { - logger.Error("failed to accept TCP connection", "err", err) - continue - } - sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig) - if err != nil { - logger.Info("failed to handshake", "err", err) - conn.Close() - continue - } - logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion())) - go sshServer.HandleGlobalRequests(sshConn, reqs) - go sshServer.HandleChannels(flag.sshShell, chans) + for { + conn, err := ln.Accept() + if err != nil { + logger.Error("failed to accept TCP connection", "err", err) + continue } - }, + sshConn, chans, reqs, err := ssh.NewServerConn(conn, sshConfig) + if err != nil { + logger.Info("failed to handshake", "err", err) + conn.Close() + continue + } + logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion())) + go sshServer.HandleGlobalRequests(sshConn, reqs) + go sshServer.HandleChannels(flag.sshShell, chans) + } } -func showPermissions(logger *slog.Logger) { +func showPermissions(logger *slog.Logger, allPermissionFlags []permissionFlagType) { var allowedList []string var notAllowedList []string for _, permissionFlag := range allPermissionFlags { diff --git a/cmd/root_test.go b/cmd/root_test.go new file mode 100644 index 0000000..b0c0546 --- /dev/null +++ b/cmd/root_test.go @@ -0,0 +1,129 @@ +package cmd + +import ( + "bytes" + "github.com/nwtgck/handy-sshd/version" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ssh" + "io" + "net" + "os/exec" + "strconv" + "testing" +) + +func TestVersion(t *testing.T) { + rootCmd := RootCmd() + rootCmd.SetArgs([]string{"--version"}) + var stdoutBuf bytes.Buffer + rootCmd.SetOut(&stdoutBuf) + assert.NoError(t, rootCmd.Execute()) + assert.Equal(t, version.Version+"\n", stdoutBuf.String()) +} + +func TestZeroUsers(t *testing.T) { + rootCmd := RootCmd() + rootCmd.SetArgs([]string{}) + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + assert.Error(t, rootCmd.Execute()) + assert.Equal(t, `Error: No user specified +e.g. --user "john:mypassword" +e.g. --user "john:" +`, stderrBuf.String()) +} + +func TestAllPermissionsAllowed(t *testing.T) { + rootCmd := RootCmd() + port := getAvailableTcpPort() + rootCmd.SetArgs([]string{"--port", strconv.Itoa(port), "--user", "john:mypassword"}) + go func() { + var stderrBuf bytes.Buffer + rootCmd.SetErr(&stderrBuf) + assert.NoError(t, rootCmd.Execute()) + }() + waitTCPServer(port) + sshClientConfig := &ssh.ClientConfig{ + User: "john", + Auth: []ssh.AuthMethod{ssh.Password("mypassword")}, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + address := net.JoinHostPort("127.0.0.1", strconv.Itoa(port)) + + client, err := ssh.Dial("tcp", address, sshClientConfig) + defer client.Close() + assert.NoError(t, err) + assertExec(t, client) + assertLocalPortForwarding(t, client) +} + +func assertExec(t *testing.T, client *ssh.Client) { + session, err := client.NewSession() + assert.NoError(t, err) + defer session.Close() + whoamiBytes, err := session.Output("whoami") + assert.NoError(t, err) + expectedWhoamiBytes, err := exec.Command("whoami").Output() + assert.NoError(t, err) + assert.Equal(t, string(whoamiBytes), string(expectedWhoamiBytes)) +} + +func assertLocalPortForwarding(t *testing.T, client *ssh.Client) { + var remoteTcpPort int + acceptedConnChan := make(chan net.Conn) + { + ln, err := net.Listen("tcp", ":0") + assert.NoError(t, err) + remoteTcpPort = ln.Addr().(*net.TCPAddr).Port + go func() { + conn, err := ln.Accept() + assert.NoError(t, err) + acceptedConnChan <- conn + }() + } + raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: remoteTcpPort} + conn, err := client.DialTCP("tcp", nil, raddr) + assert.NoError(t, err) + defer conn.Close() + acceptedConn := <-acceptedConnChan + defer acceptedConn.Close() + { + localToRemote := [3]byte{1, 2, 3} + _, err = conn.Write(localToRemote[:]) + assert.NoError(t, err) + var buf [len(localToRemote)]byte + _, err = io.ReadFull(acceptedConn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, localToRemote) + } + { + remoteToLocal := [4]byte{10, 20, 30, 40} + _, err = acceptedConn.Write(remoteToLocal[:]) + assert.NoError(t, err) + var buf [len(remoteToLocal)]byte + _, err = io.ReadFull(conn, buf[:]) + assert.NoError(t, err) + assert.Equal(t, buf, remoteToLocal) + } +} + +func getAvailableTcpPort() int { + ln, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + defer ln.Close() + return ln.Addr().(*net.TCPAddr).Port +} + +func waitTCPServer(port int) { + for { + conn, err := net.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + if err == nil { + conn.Close() + break + } + } +} diff --git a/go.mod b/go.mod index a261369..d7e0a84 100644 --- a/go.mod +++ b/go.mod @@ -8,14 +8,17 @@ require ( github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.5 github.com/spf13/cobra v1.7.0 + github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.12.0 golang.org/x/exp v0.0.0-20230809150735-7b3493d9a819 ) require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/fs v0.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.8.4 // indirect golang.org/x/sys v0.11.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d8f05e0..b17c2a7 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,7 @@ github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -40,6 +41,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/term v0.11.0 h1:F9tnn/DA/Im8nCwm+fX+1/eBwi4qFjRT++MhtVC4ZX0= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main/main.go b/main/main.go index 6ac23fe..8730087 100644 --- a/main/main.go +++ b/main/main.go @@ -6,7 +6,7 @@ import ( ) func main() { - if err := cmd.RootCmd.Execute(); err != nil { + if err := cmd.RootCmd().Execute(); err != nil { os.Exit(-1) } } diff --git a/server.go b/server.go index f8f514b..2022c4b 100644 --- a/server.go +++ b/server.go @@ -267,6 +267,7 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh. break } s.handleTcpipForward(sshConn, req) + // TODO: support: streamlocal-forward@openssh.com https://github.com/golang/crypto/blob/master/ssh/streamlocal.go default: // discard if req.WantReply {