support "cancel-tcpip-forward" and "cancel-streamlocal-forward@openssh.com"

This commit is contained in:
Ryo Ota 2023-08-11 21:40:14 +09:00
parent 445676f901
commit 3ef24a9954
2 changed files with 122 additions and 4 deletions

View file

@ -15,6 +15,7 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"github.com/mattn/go-shellwords" "github.com/mattn/go-shellwords"
"github.com/nwtgck/handy-sshd/sync_generics"
"github.com/pkg/sftp" "github.com/pkg/sftp"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
@ -27,7 +28,8 @@ import (
) )
type Server struct { type Server struct {
Logger *slog.Logger Logger *slog.Logger
bindAddressToListener sync_generics.Map[string, net.Listener]
// Permissions // Permissions
AllowTcpipForward bool AllowTcpipForward bool
@ -317,6 +319,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
go func() { go func() {
s.handleTcpipForward(sshConn, req) s.handleTcpipForward(sshConn, req)
}() }()
case "cancel-tcpip-forward":
go func() {
s.cancelTcpipForward(req)
}()
case "streamlocal-forward@openssh.com": case "streamlocal-forward@openssh.com":
if !s.AllowStreamlocalForward { if !s.AllowStreamlocalForward {
s.Logger.Info("streamlocal-forward not allowed") s.Logger.Info("streamlocal-forward not allowed")
@ -326,8 +332,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
go func() { go func() {
s.handleStreamlocalForward(sshConn, req) s.handleStreamlocalForward(sshConn, req)
}() }()
// TODO: support cancel-tcpip-forward case "cancel-streamlocal-forward@openssh.com":
// TODO: support cancel-streamlocal-forward@openssh.com go func() {
s.cancelStreamlocalForward(req)
}()
default: default:
// discard // discard
if req.WantReply { if req.WantReply {
@ -348,11 +356,13 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
req.Reply(false, nil) req.Reply(false, nil)
return return
} }
ln, err := net.Listen("tcp", net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))) address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
ln, err := net.Listen("tcp", address)
if err != nil { if err != nil {
req.Reply(false, nil) req.Reply(false, nil)
return return
} }
s.bindAddressToListener.Store(address, ln)
req.Reply(true, nil) req.Reply(true, nil)
go func() { go func() {
sshConn.Wait() sshConn.Wait()
@ -404,6 +414,29 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
} }
} }
// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
func (s *Server) cancelTcpipForward(req *ssh.Request) {
var msg struct {
Addr string
Port uint32
}
if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
req.Reply(false, nil)
return
}
address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
ln, loaded := s.bindAddressToListener.LoadAndDelete(address)
if !loaded {
req.Reply(false, nil)
s.Logger.Info("failed to find listener", "address", address)
}
if err := ln.Close(); err != nil {
req.Reply(false, nil)
s.Logger.Info("failed to close", "err", err)
}
req.Reply(true, nil)
}
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34 // client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34
func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) { func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) {
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272 // https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272
@ -419,6 +452,7 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ
req.Reply(false, nil) req.Reply(false, nil)
return return
} }
s.bindAddressToListener.Store(msg.SocketPath, ln)
req.Reply(true, nil) req.Reply(true, nil)
go func() { go func() {
sshConn.Wait() sshConn.Wait()
@ -459,3 +493,25 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ
}() }()
} }
} }
func (s *Server) cancelStreamlocalForward(req *ssh.Request) {
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L280
var msg struct {
SocketPath string
}
if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
req.Reply(false, nil)
return
}
ln, loaded := s.bindAddressToListener.LoadAndDelete(msg.SocketPath)
if !loaded {
s.Logger.Info("failed to find listener", "address", msg.SocketPath)
req.Reply(false, nil)
return
}
if err := ln.Close(); err != nil {
req.Reply(false, nil)
s.Logger.Info("failed to close", "err", err)
}
req.Reply(true, nil)
}

62
sync_generics/map.go Normal file
View file

@ -0,0 +1,62 @@
package sync_generics
import "sync"
type Map[K any, V any] struct {
inner sync.Map
}
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
deleted = m.inner.CompareAndDelete(key, old)
return
}
func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool {
return m.inner.CompareAndSwap(key, old, new)
}
func (m *Map[K, V]) Delete(key K) {
m.inner.Delete(key)
}
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
_value, ok := m.inner.Load(key)
value = nilSafeTypeAssertion[V](_value)
return
}
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
_value, loaded := m.inner.LoadAndDelete(key)
value = nilSafeTypeAssertion[V](_value)
return
}
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
_actual, loaded := m.inner.LoadOrStore(key, value)
actual = nilSafeTypeAssertion[V](_actual)
return
}
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
m.inner.Range(func(key, value any) bool {
return f(nilSafeTypeAssertion[K](key), nilSafeTypeAssertion[V](value))
})
}
func (m *Map[K, V]) Store(key K, value V) {
m.inner.Store(key, value)
}
func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) {
_previous, loaded := m.inner.Swap(key, value)
previous = nilSafeTypeAssertion[V](_previous)
return
}
func nilSafeTypeAssertion[T any](value any) T {
var zero T
if value == nil {
return zero
}
return value.(T)
}