From 6c183ffac44acfaab9894d8638936b4c42dad0ce Mon Sep 17 00:00:00 2001 From: Ryo Ota Date: Fri, 11 Aug 2023 09:02:22 +0900 Subject: [PATCH] support Unix domain socket local forwarding --- cmd/root.go | 24 +++++++++++++---------- server.go | 55 ++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 64 insertions(+), 15 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 1a89961..ae5d21b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -22,10 +22,11 @@ type flagType struct { sshShell string sshUsers []string - allowTcpipForward bool - allowDirectTcpip bool - allowExecute bool - allowSftp bool + allowTcpipForward bool + allowDirectTcpip bool + allowExecute bool + allowSftp bool + allowDirectStreamlocal bool } type permissionFlagType = struct { @@ -49,6 +50,7 @@ func RootCmd() *cobra.Command { {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, {name: "execute", flagPtr: &flag.allowExecute}, {name: "sftp", flagPtr: &flag.allowSftp}, + {name: "direct-streamlocal", flagPtr: &flag.allowDirectStreamlocal}, } rootCmd := cobra.Command{ Use: os.Args[0], @@ -64,7 +66,7 @@ func RootCmd() *cobra.Command { rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) - rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") + rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix domain socket") rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") //rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) @@ -74,6 +76,7 @@ func RootCmd() *cobra.Command { rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") + rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectStreamlocal, "allow-direct-streamlocal", "", false, "client can use Unix domain socket local forwarding") return &rootCmd } @@ -99,11 +102,12 @@ func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPer } sshServer := &handy_sshd.Server{ - Logger: logger, - AllowTcpipForward: flag.allowTcpipForward, - AllowDirectTcpip: flag.allowDirectTcpip, - AllowExecute: flag.allowExecute, - AllowSftp: flag.allowSftp, + Logger: logger, + AllowTcpipForward: flag.allowTcpipForward, + AllowDirectTcpip: flag.allowDirectTcpip, + AllowExecute: flag.allowExecute, + AllowSftp: flag.allowSftp, + AllowDirectStreamlocal: flag.allowDirectStreamlocal, } var sshUsers []sshUser for _, u := range flag.sshUsers { diff --git a/server.go b/server.go index 2defe06..a6a0162 100644 --- a/server.go +++ b/server.go @@ -30,10 +30,11 @@ type Server struct { Logger *slog.Logger // Permissions - AllowTcpipForward bool - AllowDirectTcpip bool - AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. - AllowSftp bool + AllowTcpipForward bool + AllowDirectTcpip bool + AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. + AllowSftp bool + AllowDirectStreamlocal bool // TODO: DNS server ? } @@ -59,6 +60,12 @@ func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) { break } s.handleDirectTcpip(newChannel) + case "direct-streamlocal@openssh.com": + if !s.AllowDirectStreamlocal { + newChannel.Reject(ssh.Prohibited, "direct-streamlocal (Unix domain socket) not allowed") + break + } + s.handleDirectStreamlocal(newChannel) default: newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType())) } @@ -197,7 +204,7 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) { SourcePort uint32 } if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil { - s.Logger.Info("failed to parse message", "err", err) + s.Logger.Info("failed to parse direct-tcpip message", "err", err) return } channel, reqs, err := newChannel.Accept() @@ -227,6 +234,44 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) { return } +// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L52 +func (s *Server) handleDirectStreamlocal(newChannel ssh.NewChannel) { + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L237 + var msg struct { + SocketPath string + Reserved0 string + Reserved1 uint32 + } + if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil { + s.Logger.Info("failed to parse direct-streamlocal message", "err", err) + return + } + channel, reqs, err := newChannel.Accept() + if err != nil { + s.Logger.Info("failed to accept", "err", err) + return + } + go ssh.DiscardRequests(reqs) + conn, err := net.Dial("unix", msg.SocketPath) + if err != nil { + s.Logger.Info("failed to dial", "err", err) + channel.Close() + return + } + var closeOnce sync.Once + closer := func() { + channel.Close() + conn.Close() + } + go func() { + io.Copy(channel, conn) + closeOnce.Do(closer) + }() + io.Copy(conn, channel) + closeOnce.Do(closer) + return +} + // ======================= // parseDims extracts terminal dimensions (width x height) from the provided buffer.