diff --git a/cmd/root.go b/cmd/root.go index ae5d21b..a2d6b66 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -22,11 +22,12 @@ type flagType struct { sshShell string sshUsers []string - allowTcpipForward bool - allowDirectTcpip bool - allowExecute bool - allowSftp bool - allowDirectStreamlocal bool + allowTcpipForward bool + allowDirectTcpip bool + allowExecute bool + allowSftp bool + allowStreamlocalForward bool + allowDirectStreamlocal bool } type permissionFlagType = struct { @@ -50,6 +51,7 @@ func RootCmd() *cobra.Command { {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, {name: "execute", flagPtr: &flag.allowExecute}, {name: "sftp", flagPtr: &flag.allowSftp}, + {name: "streamlocal-forward", flagPtr: &flag.allowStreamlocalForward}, {name: "direct-streamlocal", flagPtr: &flag.allowDirectStreamlocal}, } rootCmd := cobra.Command{ @@ -76,6 +78,7 @@ func RootCmd() *cobra.Command { rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") + rootCmd.PersistentFlags().BoolVarP(&flag.allowStreamlocalForward, "allow-streamlocal-forward", "", false, "client can use Unix domain socket remote forwarding") rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectStreamlocal, "allow-direct-streamlocal", "", false, "client can use Unix domain socket local forwarding") return &rootCmd @@ -102,12 +105,13 @@ func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPer } sshServer := &handy_sshd.Server{ - Logger: logger, - AllowTcpipForward: flag.allowTcpipForward, - AllowDirectTcpip: flag.allowDirectTcpip, - AllowExecute: flag.allowExecute, - AllowSftp: flag.allowSftp, - AllowDirectStreamlocal: flag.allowDirectStreamlocal, + Logger: logger, + AllowTcpipForward: flag.allowTcpipForward, + AllowDirectTcpip: flag.allowDirectTcpip, + AllowExecute: flag.allowExecute, + AllowSftp: flag.allowSftp, + AllowStreamlocalForward: flag.allowStreamlocalForward, + AllowDirectStreamlocal: flag.allowDirectStreamlocal, } var sshUsers []sshUser for _, u := range flag.sshUsers { diff --git a/server.go b/server.go index a6a0162..a8e6358 100644 --- a/server.go +++ b/server.go @@ -30,11 +30,12 @@ type Server struct { Logger *slog.Logger // Permissions - AllowTcpipForward bool - AllowDirectTcpip bool - AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. - AllowSftp bool - AllowDirectStreamlocal bool + AllowTcpipForward bool + AllowDirectTcpip bool + AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. + AllowSftp bool + AllowStreamlocalForward bool + AllowDirectStreamlocal bool // TODO: DNS server ? } @@ -312,7 +313,13 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh. break } s.handleTcpipForward(sshConn, req) - // TODO: support: streamlocal-forward@openssh.com https://github.com/golang/crypto/blob/master/ssh/streamlocal.go + case "streamlocal-forward@openssh.com": + if !s.AllowStreamlocalForward { + s.Logger.Info("streamlocal-forward not allowed") + req.Reply(false, nil) + break + } + s.handleStreamlocalForward(sshConn, req) default: // discard if req.WantReply { @@ -347,6 +354,7 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { for { conn, err := ln.Accept() if err != nil { + s.Logger.Info("failed to accept", "err", err) return } var replyMsg struct { @@ -387,3 +395,59 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { }() } } + +// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34 +func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) { + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272 + var msg struct { + SocketPath string + } + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + req.Reply(false, nil) + return + } + ln, err := net.Listen("unix", msg.SocketPath) + if err != nil { + req.Reply(false, nil) + return + } + req.Reply(true, nil) + go func() { + sshConn.Wait() + ln.Close() + s.Logger.Info("connection closed", "address", ln.Addr().String()) + }() + for { + conn, err := ln.Accept() + if err != nil { + s.Logger.Info("failed to accept", "err", err) + return + } + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L255 + var replyMsg struct { + SocketPath string + Reserved string + } + replyMsg.SocketPath = msg.SocketPath + + go func() { + channel, reqs, err := sshConn.OpenChannel("forwarded-streamlocal@openssh.com", ssh.Marshal(&replyMsg)) + if err != nil { + req.Reply(false, nil) + conn.Close() + return + } + go ssh.DiscardRequests(reqs) + go func() { + io.Copy(channel, conn) + conn.Close() + channel.Close() + }() + go func() { + io.Copy(conn, channel) + conn.Close() + channel.Close() + }() + }() + } +}