mirror of
https://github.com/nwtgck/handy-sshd.git
synced 2025-06-07 14:43:05 +00:00
support "cancel-tcpip-forward" and "cancel-streamlocal-forward@openssh.com"
This commit is contained in:
parent
445676f901
commit
3ef24a9954
2 changed files with 122 additions and 4 deletions
62
server.go
62
server.go
|
@ -15,6 +15,7 @@ import (
|
|||
"encoding/pem"
|
||||
"fmt"
|
||||
"github.com/mattn/go-shellwords"
|
||||
"github.com/nwtgck/handy-sshd/sync_generics"
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slog"
|
||||
|
@ -28,6 +29,7 @@ import (
|
|||
|
||||
type Server struct {
|
||||
Logger *slog.Logger
|
||||
bindAddressToListener sync_generics.Map[string, net.Listener]
|
||||
|
||||
// Permissions
|
||||
AllowTcpipForward bool
|
||||
|
@ -317,6 +319,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
|
|||
go func() {
|
||||
s.handleTcpipForward(sshConn, req)
|
||||
}()
|
||||
case "cancel-tcpip-forward":
|
||||
go func() {
|
||||
s.cancelTcpipForward(req)
|
||||
}()
|
||||
case "streamlocal-forward@openssh.com":
|
||||
if !s.AllowStreamlocalForward {
|
||||
s.Logger.Info("streamlocal-forward not allowed")
|
||||
|
@ -326,8 +332,10 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
|
|||
go func() {
|
||||
s.handleStreamlocalForward(sshConn, req)
|
||||
}()
|
||||
// TODO: support cancel-tcpip-forward
|
||||
// TODO: support cancel-streamlocal-forward@openssh.com
|
||||
case "cancel-streamlocal-forward@openssh.com":
|
||||
go func() {
|
||||
s.cancelStreamlocalForward(req)
|
||||
}()
|
||||
default:
|
||||
// discard
|
||||
if req.WantReply {
|
||||
|
@ -348,11 +356,13 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
|
|||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port))))
|
||||
address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
|
||||
ln, err := net.Listen("tcp", address)
|
||||
if err != nil {
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
s.bindAddressToListener.Store(address, ln)
|
||||
req.Reply(true, nil)
|
||||
go func() {
|
||||
sshConn.Wait()
|
||||
|
@ -404,6 +414,29 @@ func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
|
||||
func (s *Server) cancelTcpipForward(req *ssh.Request) {
|
||||
var msg struct {
|
||||
Addr string
|
||||
Port uint32
|
||||
}
|
||||
if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
|
||||
ln, loaded := s.bindAddressToListener.LoadAndDelete(address)
|
||||
if !loaded {
|
||||
req.Reply(false, nil)
|
||||
s.Logger.Info("failed to find listener", "address", address)
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
req.Reply(false, nil)
|
||||
s.Logger.Info("failed to close", "err", err)
|
||||
}
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
|
||||
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34
|
||||
func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) {
|
||||
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272
|
||||
|
@ -419,6 +452,7 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ
|
|||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
s.bindAddressToListener.Store(msg.SocketPath, ln)
|
||||
req.Reply(true, nil)
|
||||
go func() {
|
||||
sshConn.Wait()
|
||||
|
@ -459,3 +493,25 @@ func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Requ
|
|||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) cancelStreamlocalForward(req *ssh.Request) {
|
||||
// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L280
|
||||
var msg struct {
|
||||
SocketPath string
|
||||
}
|
||||
if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
ln, loaded := s.bindAddressToListener.LoadAndDelete(msg.SocketPath)
|
||||
if !loaded {
|
||||
s.Logger.Info("failed to find listener", "address", msg.SocketPath)
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
req.Reply(false, nil)
|
||||
s.Logger.Info("failed to close", "err", err)
|
||||
}
|
||||
req.Reply(true, nil)
|
||||
}
|
||||
|
|
62
sync_generics/map.go
Normal file
62
sync_generics/map.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package sync_generics
|
||||
|
||||
import "sync"
|
||||
|
||||
type Map[K any, V any] struct {
|
||||
inner sync.Map
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
|
||||
deleted = m.inner.CompareAndDelete(key, old)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) CompareAndSwap(key K, old V, new V) bool {
|
||||
return m.inner.CompareAndSwap(key, old, new)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Delete(key K) {
|
||||
m.inner.Delete(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Load(key K) (value V, ok bool) {
|
||||
_value, ok := m.inner.Load(key)
|
||||
value = nilSafeTypeAssertion[V](_value)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
|
||||
_value, loaded := m.inner.LoadAndDelete(key)
|
||||
value = nilSafeTypeAssertion[V](_value)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||
_actual, loaded := m.inner.LoadOrStore(key, value)
|
||||
actual = nilSafeTypeAssertion[V](_actual)
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Range(f func(key K, value V) bool) {
|
||||
m.inner.Range(func(key, value any) bool {
|
||||
return f(nilSafeTypeAssertion[K](key), nilSafeTypeAssertion[V](value))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Store(key K, value V) {
|
||||
m.inner.Store(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Swap(key K, value V) (previous V, loaded bool) {
|
||||
_previous, loaded := m.inner.Swap(key, value)
|
||||
previous = nilSafeTypeAssertion[V](_previous)
|
||||
return
|
||||
}
|
||||
|
||||
func nilSafeTypeAssertion[T any](value any) T {
|
||||
var zero T
|
||||
if value == nil {
|
||||
return zero
|
||||
}
|
||||
return value.(T)
|
||||
}
|
Loading…
Add table
Reference in a new issue