support Unix domain socket local forwarding

This commit is contained in:
Ryo Ota 2023-08-11 09:02:22 +09:00
parent 663c9fae81
commit 6c183ffac4
2 changed files with 64 additions and 15 deletions

View file

@ -22,10 +22,11 @@ type flagType struct {
sshShell string sshShell string
sshUsers []string sshUsers []string
allowTcpipForward bool allowTcpipForward bool
allowDirectTcpip bool allowDirectTcpip bool
allowExecute bool allowExecute bool
allowSftp bool allowSftp bool
allowDirectStreamlocal bool
} }
type permissionFlagType = struct { type permissionFlagType = struct {
@ -49,6 +50,7 @@ func RootCmd() *cobra.Command {
{name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip}, {name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip},
{name: "execute", flagPtr: &flag.allowExecute}, {name: "execute", flagPtr: &flag.allowExecute},
{name: "sftp", flagPtr: &flag.allowSftp}, {name: "sftp", flagPtr: &flag.allowSftp},
{name: "direct-streamlocal", flagPtr: &flag.allowDirectStreamlocal},
} }
rootCmd := cobra.Command{ rootCmd := cobra.Command{
Use: os.Args[0], Use: os.Args[0],
@ -64,7 +66,7 @@ func RootCmd() *cobra.Command {
rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") rootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)")
rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") rootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port")
// NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html)
rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") rootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix domain socket")
rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") rootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell")
//rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") //rootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)")
rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) rootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`)
@ -74,6 +76,7 @@ func RootCmd() *cobra.Command {
rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy") rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy")
rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell") rootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell")
rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS") rootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS")
rootCmd.PersistentFlags().BoolVarP(&flag.allowDirectStreamlocal, "allow-direct-streamlocal", "", false, "client can use Unix domain socket local forwarding")
return &rootCmd return &rootCmd
} }
@ -99,11 +102,12 @@ func rootRunEWithExtra(cmd *cobra.Command, args []string, flag *flagType, allPer
} }
sshServer := &handy_sshd.Server{ sshServer := &handy_sshd.Server{
Logger: logger, Logger: logger,
AllowTcpipForward: flag.allowTcpipForward, AllowTcpipForward: flag.allowTcpipForward,
AllowDirectTcpip: flag.allowDirectTcpip, AllowDirectTcpip: flag.allowDirectTcpip,
AllowExecute: flag.allowExecute, AllowExecute: flag.allowExecute,
AllowSftp: flag.allowSftp, AllowSftp: flag.allowSftp,
AllowDirectStreamlocal: flag.allowDirectStreamlocal,
} }
var sshUsers []sshUser var sshUsers []sshUser
for _, u := range flag.sshUsers { for _, u := range flag.sshUsers {

View file

@ -30,10 +30,11 @@ type Server struct {
Logger *slog.Logger Logger *slog.Logger
// Permissions // Permissions
AllowTcpipForward bool AllowTcpipForward bool
AllowDirectTcpip bool AllowDirectTcpip bool
AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution. AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution.
AllowSftp bool AllowSftp bool
AllowDirectStreamlocal bool
// TODO: DNS server ? // TODO: DNS server ?
} }
@ -59,6 +60,12 @@ func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) {
break break
} }
s.handleDirectTcpip(newChannel) s.handleDirectTcpip(newChannel)
case "direct-streamlocal@openssh.com":
if !s.AllowDirectStreamlocal {
newChannel.Reject(ssh.Prohibited, "direct-streamlocal (Unix domain socket) not allowed")
break
}
s.handleDirectStreamlocal(newChannel)
default: default:
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType())) newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
} }
@ -197,7 +204,7 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) {
SourcePort uint32 SourcePort uint32
} }
if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil { if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
s.Logger.Info("failed to parse message", "err", err) s.Logger.Info("failed to parse direct-tcpip message", "err", err)
return return
} }
channel, reqs, err := newChannel.Accept() channel, reqs, err := newChannel.Accept()
@ -227,6 +234,44 @@ func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) {
return return
} }
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L52
func (s *Server) handleDirectStreamlocal(newChannel ssh.NewChannel) {
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L237
var msg struct {
SocketPath string
Reserved0 string
Reserved1 uint32
}
if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
s.Logger.Info("failed to parse direct-streamlocal message", "err", err)
return
}
channel, reqs, err := newChannel.Accept()
if err != nil {
s.Logger.Info("failed to accept", "err", err)
return
}
go ssh.DiscardRequests(reqs)
conn, err := net.Dial("unix", msg.SocketPath)
if err != nil {
s.Logger.Info("failed to dial", "err", err)
channel.Close()
return
}
var closeOnce sync.Once
closer := func() {
channel.Close()
conn.Close()
}
go func() {
io.Copy(channel, conn)
closeOnce.Do(closer)
}()
io.Copy(conn, channel)
closeOnce.Do(closer)
return
}
// ======================= // =======================
// parseDims extracts terminal dimensions (width x height) from the provided buffer. // parseDims extracts terminal dimensions (width x height) from the provided buffer.