From 3ef24a9954e79579d343ac512a27e413858c3d43 Mon Sep 17 00:00:00 2001 From: Ryo Ota Date: Fri, 11 Aug 2023 21:40:14 +0900 Subject: [PATCH] support "cancel-tcpip-forward" and "cancel-streamlocal-forward@openssh.com" --- server.go | 64 +++++++++++++++++++++++++++++++++++++++++--- sync_generics/map.go | 62 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 sync_generics/map.go diff --git a/server.go b/server.go index 015dea9..882c316 100644 --- a/server.go +++ b/server.go @@ -15,6 +15,7 @@ import ( "encoding/pem" "fmt" "github.com/mattn/go-shellwords" + "github.com/nwtgck/handy-sshd/sync_generics" "github.com/pkg/sftp" "golang.org/x/crypto/ssh" "golang.org/x/exp/slog" @@ -27,7 +28,8 @@ import ( ) type Server struct { - Logger *slog.Logger + Logger *slog.Logger + bindAddressToListener sync_generics.Map[string, net.Listener] // Permissions AllowTcpipForward bool @@ -317,6 +319,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh. go func() { s.handleTcpipForward(sshConn, req) }() + case "cancel-tcpip-forward": + go func() { + s.cancelTcpipForward(req) + }() case "streamlocal-forward@openssh.com": if !s.AllowStreamlocalForward { s.Logger.Info("streamlocal-forward not allowed") @@ -326,8 +332,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh. go func() { s.handleStreamlocalForward(sshConn, req) }() - // TODO: support cancel-tcpip-forward - // TODO: support cancel-streamlocal-forward@openssh.com + case "cancel-streamlocal-forward@openssh.com": + go func() { + s.cancelStreamlocalForward(req) + }() default: // discard if req.WantReply { @@ -348,11 +356,13 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { req.Reply(false, nil) return } - ln, err := net.Listen("tcp", net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))) + address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port))) + ln, err := net.Listen("tcp", address) if err != nil { req.Reply(false, nil) return } + s.bindAddressToListener.Store(address, ln) req.Reply(true, nil) go func() { sshConn.Wait() @@ -404,6 +414,29 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) { } } +// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1 +func (s *Server) cancelTcpipForward(req *ssh.Request) { + var msg struct { + Addr string + Port uint32 + } + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + req.Reply(false, nil) + return + } + address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port))) + ln, loaded := s.bindAddressToListener.LoadAndDelete(address) + if !loaded { + req.Reply(false, nil) + s.Logger.Info("failed to find listener", "address", address) + } + if err := ln.Close(); err != nil { + req.Reply(false, nil) + s.Logger.Info("failed to close", "err", err) + } + req.Reply(true, nil) +} + // client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34 func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) { // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272 @@ -419,6 +452,7 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ req.Reply(false, nil) return } + s.bindAddressToListener.Store(msg.SocketPath, ln) req.Reply(true, nil) go func() { sshConn.Wait() @@ -459,3 +493,25 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ }() } } + +func (s *Server) cancelStreamlocalForward(req *ssh.Request) { + // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L280 + var msg struct { + SocketPath string + } + if err := ssh.Unmarshal(req.Payload, &msg); err != nil { + req.Reply(false, nil) + return + } + ln, loaded := s.bindAddressToListener.LoadAndDelete(msg.SocketPath) + if !loaded { + s.Logger.Info("failed to find listener", "address", msg.SocketPath) + req.Reply(false, nil) + return + } + if err := ln.Close(); err != nil { + req.Reply(false, nil) + s.Logger.Info("failed to close", "err", err) + } + req.Reply(true, nil) +} diff --git a/sync_generics/map.go b/sync_generics/map.go new file mode 100644 index 0000000..ab31102 --- /dev/null +++ b/sync_generics/map.go @@ -0,0 +1,62 @@ +package sync_generics + +import "sync" + +type Map[K any, V any] struct { + inner sync.Map +} + +func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + deleted = m.inner.CompareAndDelete(key, old) + return +} + +func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool { + return m.inner.CompareAndSwap(key, old, new) +} + +func (m *Map[K, V]) Delete(key K) { + m.inner.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + _value, ok := m.inner.Load(key) + value = nilSafeTypeAssertion[V](_value) + return +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + _value, loaded := m.inner.LoadAndDelete(key) + value = nilSafeTypeAssertion[V](_value) + return +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + _actual, loaded := m.inner.LoadOrStore(key, value) + actual = nilSafeTypeAssertion[V](_actual) + return +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.inner.Range(func(key, value any) bool { + return f(nilSafeTypeAssertion[K](key), nilSafeTypeAssertion[V](value)) + }) +} + +func (m *Map[K, V]) Store(key K, value V) { + m.inner.Store(key, value) +} + +func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) { + _previous, loaded := m.inner.Swap(key, value) + previous = nilSafeTypeAssertion[V](_previous) + return +} + +func nilSafeTypeAssertion[T any](value any) T { + var zero T + if value == nil { + return zero + } + return value.(T) +}