// handy-sshd/server.go (517 lines, 13 KiB, Go)

// Copyright (c) 2021 Ryo Ota
// Released under the MIT License
// Copyright (c) 2020 Jaime Pillora <dev@jpillora.com>
// Released under the MIT License
// https://github.com/jpillora/sshd-lite/tree/master#mit-license
package handy_sshd
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/binary"
"encoding/pem"
"fmt"
"github.com/mattn/go-shellwords"
"github.com/nwtgck/handy-sshd/sync_generics"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slog"
"io"
"net"
"os"
"os/exec"
"strconv"
"sync"
)
// Server holds the settings and shared state used by the SSH channel and
// request handlers in this file.
//
// The Allow* fields are permission switches. Each one is checked by the
// handler for the corresponding channel type or request, which refuses the
// request when the flag is false (the zero value).
type Server struct {
// Logger receives the handlers' informational and error messages.
Logger *slog.Logger
// bindAddressToListener maps the address a remote forward was bound to (a
// "host:port" string for tcpip-forward, a socket path for
// streamlocal-forward) to its listener, so that the matching cancel request
// can find and close it. NOTE(review): it is keyed only by address, so if one
// Server value serves several connections they all share it — confirm that
// is intended.
bindAddressToListener sync_generics.Map[string, net.Listener]
// Permissions
AllowTcpipForward bool
AllowDirectTcpip bool
AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can also be used for purposes other than shell execution.
AllowSftp bool
AllowStreamlocalForward bool
AllowDirectStreamlocal bool
// TODO: DNS server ?
}
// exitStatusMsg is the payload of the "exit-status" channel request
// (RFC 4254, section 6.10) that handleExecRequest sends after the command
// finishes. It is encoded with ssh.Marshal, so the field layout is the wire
// format and must not change.
type exitStatusMsg struct {
Status uint32
}
// HandleChannels services every channel the client opens on one SSH
// connection. Each channel is handled on its own goroutine so that a
// long-lived channel never blocks the others, and shell is passed through to
// the handler unchanged. It returns once chans is closed.
func (s *Server) HandleChannels(shell string, chans <-chan ssh.NewChannel) {
	for {
		pending, ok := <-chans
		if !ok {
			return
		}
		go s.handleChannel(shell, pending)
	}
}
// handleChannel routes a channel-open request to the handler for its type.
//
//   - "session" channels are passed to handleSession.
//   - "direct-tcpip" and "direct-streamlocal@openssh.com" channels are
//     rejected with ssh.Prohibited unless AllowDirectTcpip /
//     AllowDirectStreamlocal respectively is set.
//   - Any other channel type is rejected with ssh.UnknownChannelType.
func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) {
	channelType := newChannel.ChannelType()
	switch channelType {
	case "session":
		s.handleSession(shell, newChannel)
	case "direct-tcpip":
		if s.AllowDirectTcpip {
			s.handleDirectTcpip(newChannel)
		} else {
			newChannel.Reject(ssh.Prohibited, "direct-tcpip not allowed")
		}
	case "direct-streamlocal@openssh.com":
		if s.AllowDirectStreamlocal {
			s.handleDirectStreamlocal(newChannel)
		} else {
			newChannel.Reject(ssh.Prohibited, "direct-streamlocal (Unix domain socket) not allowed")
		}
	default:
		newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", channelType))
	}
}
// handleSession accepts a "session" channel and services its channel
// requests until the client stops sending them.
//
// Supported requests:
//   - "exec": runs a command via handleExecRequest (requires AllowExecute).
//   - "shell": accepted only with an empty payload (the default shell).
//   - "pty-req": calls createPty(shell, connection) (requires AllowExecute)
//     and applies the requested terminal size with setWinsize.
//   - "window-change": resizes the pty from an earlier successful "pty-req".
//   - "subsystem": handled by handleSessionSubSystem.
//
// Every other request is refused. Every request gets a reply because the
// client may be waiting for one; req.Reply is a no-op when no reply was
// requested. Payloads come from the client, so their lengths are checked
// before decoding. An out-of-range slice would panic and take the whole
// server process down.
func (s *Server) handleSession(shell string, newChannel ssh.NewChannel) {
	// At this point, we have the opportunity to reject the client's
	// request for another logical connection
	connection, requests, err := newChannel.Accept()
	if err != nil {
		s.Logger.Info("Could not accept channel", "err", err)
		return
	}
	var shf *os.File
	for req := range requests {
		switch req.Type {
		case "exec":
			if !s.AllowExecute {
				s.Logger.Info("execution not allowed (exec)")
				req.Reply(false, nil)
				break
			}
			s.handleExecRequest(req, connection)
		case "shell":
			// We only accept the default shell (i.e. no command in the
			// Payload); anything else is refused instead of left unanswered.
			req.Reply(len(req.Payload) == 0, nil)
		case "pty-req":
			if !s.AllowExecute {
				s.Logger.Info("execution not allowed (pty-req)")
				req.Reply(false, nil)
				break
			}
			// Payload (RFC 4254, section 6.2): string TERM, uint32 columns,
			// uint32 rows, uint32 width px, uint32 height px, string modes.
			// The TERM length is a full uint32. It is widened to uint64 so
			// the bounds arithmetic below cannot overflow.
			if len(req.Payload) < 4 {
				s.Logger.Info("malformed pty-req payload")
				req.Reply(false, nil)
				break
			}
			termLen := uint64(binary.BigEndian.Uint32(req.Payload))
			if uint64(len(req.Payload)) < 4+termLen+8 {
				s.Logger.Info("malformed pty-req payload")
				req.Reply(false, nil)
				break
			}
			w, h := parseDims(req.Payload[4+termLen:])
			pty, ptyErr := s.createPty(shell, connection)
			if ptyErr != nil {
				// Keep servicing requests. Returning here would leave the
				// request channel undrained, and the channel would hang.
				s.Logger.Info("failed to create pty", "err", ptyErr)
				req.Reply(false, nil)
				break
			}
			shf = pty
			setWinsize(shf, w, h)
			// Responding true (OK) here will let the client
			// know we have a pty ready for input
			req.Reply(true, nil)
		case "window-change":
			// Payload (RFC 4254, section 6.7): uint32 columns, uint32 rows,
			// uint32 width px, uint32 height px.
			if len(req.Payload) < 8 {
				s.Logger.Info("malformed window-change payload")
				break
			}
			w, h := parseDims(req.Payload)
			if shf != nil {
				setWinsize(shf, w, h)
			}
		case "subsystem":
			s.handleSessionSubSystem(req, connection)
		default:
			s.Logger.Info("unsupported request", "req_type", req.Type)
			req.Reply(false, nil)
		}
	}
}
// handleExecRequest serves an "exec" request (RFC 4254, section 6.5) on a
// session channel.
//
// It splits the requested command line with shellwords (no shell is
// involved) and starts the program. Channel data is copied to the program's
// stdin, and the program's stdout and stderr are copied to the client's
// stdout and stderr (extended data) streams. When the program has exited and
// all its output has been forwarded, an "exit-status" request is sent and
// the channel is closed.
//
// The request is answered with false, and nothing is run, when the payload
// cannot be decoded, the command line cannot be parsed or is empty, or the
// program cannot be started. The call blocks until the program exits and its
// output pipes reach EOF.
func (s *Server) handleExecRequest(req *ssh.Request, connection ssh.Channel) {
	var msg struct {
		Command string
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
		s.Logger.Info("failed to parse message in exec", "err", err)
		req.Reply(false, nil)
		return
	}
	cmdSlice, err := shellwords.Parse(msg.Command)
	if err != nil {
		s.Logger.Info("failed to parse command in exec", "err", err)
		req.Reply(false, nil)
		return
	}
	if len(cmdSlice) == 0 {
		// cmdSlice[0] below would panic on an empty command line.
		s.Logger.Info("empty command in exec")
		req.Reply(false, nil)
		return
	}
	cmd := exec.Command(cmdSlice[0], cmdSlice[1:]...)
	stdin, err := cmd.StdinPipe()
	if err != nil {
		s.Logger.Info("failed to create stdin pipe", "err", err)
		req.Reply(false, nil)
		return
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		s.Logger.Info("failed to create stdout pipe", "err", err)
		req.Reply(false, nil)
		return
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		s.Logger.Info("failed to create stderr pipe", "err", err)
		req.Reply(false, nil)
		return
	}
	if err := cmd.Start(); err != nil {
		s.Logger.Info("failed to start command", "err", err)
		req.Reply(false, nil)
		return
	}
	req.Reply(true, nil)

	// client -> stdin. Closing stdin once the client sends EOF lets programs
	// that read until end of input (cat, sort, ...) terminate.
	go func() {
		io.Copy(stdin, connection)
		stdin.Close()
	}()

	// stdout/stderr -> client. Wait closes these pipes (see os/exec
	// Cmd.StdoutPipe), so all reads must finish before Wait is called.
	// Otherwise trailing output can be lost.
	var outputs sync.WaitGroup
	outputs.Add(2)
	go func() {
		defer outputs.Done()
		io.Copy(connection, stdout)
	}()
	go func() {
		defer outputs.Done()
		io.Copy(connection.Stderr(), stderr)
	}()
	outputs.Wait()

	exitCode := 0
	if err := cmd.Wait(); err != nil {
		if exitErr, ok := err.(*exec.ExitError); ok {
			exitCode = exitErr.ExitCode()
		} else {
			// Do not report success for a command whose outcome is unknown.
			s.Logger.Info("failed to wait for command", "err", err)
			exitCode = 1
		}
	}
	connection.SendRequest("exit-status", false, ssh.Marshal(exitStatusMsg{
		Status: uint32(exitCode),
	}))
	connection.Close()
}
// handleSessionSubSystem serves a "subsystem" request (RFC 4254, section
// 6.5). Only the "sftp" subsystem is supported, and only when AllowSftp is
// set. When accepted, an sftp server is run on the session channel until the
// client disconnects, and is then closed.
func (s *Server) handleSessionSubSystem(req *ssh.Request, connection ssh.Channel) {
	// https://github.com/pkg/sftp/blob/42e9800606febe03f9cdf1d1283719af4a5e6456/examples/go-sftp-server/main.go#L111
	// The payload is one SSH string: the subsystem name. Decoding it with
	// ssh.Unmarshal validates the length prefix. Slicing req.Payload[4:]
	// panics on payloads shorter than four bytes.
	var msg struct {
		Name string
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil || msg.Name != "sftp" {
		req.Reply(false, nil)
		return
	}
	if !s.AllowSftp {
		s.Logger.Info("sftp not allowed")
		req.Reply(false, nil)
		return
	}
	req.Reply(true, nil)
	serverOptions := []sftp.ServerOption{
		sftp.WithDebug(os.Stderr),
	}
	sftpServer, err := sftp.NewServer(connection, serverOptions...)
	if err != nil {
		s.Logger.Info("failed to create sftp server", "err", err)
		connection.Close()
		return
	}
	// Close unconditionally. pkg/sftp's Serve reports a clean disconnect as
	// nil (not io.EOF) in current releases, so closing only on io.EOF would
	// skip the cleanup.
	defer sftpServer.Close()
	if err := sftpServer.Serve(); err != nil && err != io.EOF {
		s.Logger.Info("failed to serve sftp server", "err", err)
	}
}
// handleDirectTcpip serves a "direct-tcpip" channel (local port forwarding,
// RFC 4254, section 7.2). It dials the requested host:port from the server
// and relays bytes between that TCP connection and the channel. When either
// direction finishes, both are closed.
//
// The channel is always answered. It is rejected when its extra data cannot
// be decoded or when the destination cannot be dialed. Dialing happens before
// Accept so that an unreachable destination is reported to the client as a
// channel-open failure rather than an accepted-then-closed channel.
//
// (base: https://github.com/peertechde/zodiac/blob/110fdd2dfd27359546c1cd75a9fec5de2882bf42/pkg/server/server.go#L228)
func (s *Server) handleDirectTcpip(newChannel ssh.NewChannel) {
	var msg struct {
		RemoteAddr string
		RemotePort uint32
		SourceAddr string
		SourcePort uint32
	}
	if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
		s.Logger.Info("failed to parse direct-tcpip message", "err", err)
		newChannel.Reject(ssh.ConnectionFailed, "failed to parse direct-tcpip message")
		return
	}
	raddr := net.JoinHostPort(msg.RemoteAddr, strconv.Itoa(int(msg.RemotePort)))
	conn, err := net.Dial("tcp", raddr)
	if err != nil {
		s.Logger.Info("failed to dial", "err", err)
		newChannel.Reject(ssh.ConnectionFailed, "failed to dial")
		return
	}
	channel, reqs, err := newChannel.Accept()
	if err != nil {
		s.Logger.Info("failed to accept", "err", err)
		conn.Close()
		return
	}
	go ssh.DiscardRequests(reqs)
	var closeOnce sync.Once
	closer := func() {
		channel.Close()
		conn.Close()
	}
	go func() {
		io.Copy(channel, conn)
		closeOnce.Do(closer)
	}()
	io.Copy(conn, channel)
	closeOnce.Do(closer)
}
// handleDirectStreamlocal serves a "direct-streamlocal@openssh.com" channel
// (local forwarding to a Unix domain socket on the server). It dials the
// requested socket path and relays bytes between that connection and the
// channel. When either direction finishes, both are closed.
//
// The channel is always answered. It is rejected when its extra data cannot
// be decoded or when the socket cannot be dialed. Dialing happens before
// Accept so that failure is reported to the client as a channel-open failure.
//
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L52
func (s *Server) handleDirectStreamlocal(newChannel ssh.NewChannel) {
	// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L237
	var msg struct {
		SocketPath string
		Reserved0  string
		Reserved1  uint32
	}
	if err := ssh.Unmarshal(newChannel.ExtraData(), &msg); err != nil {
		s.Logger.Info("failed to parse direct-streamlocal message", "err", err)
		newChannel.Reject(ssh.ConnectionFailed, "failed to parse direct-streamlocal message")
		return
	}
	conn, err := net.Dial("unix", msg.SocketPath)
	if err != nil {
		s.Logger.Info("failed to dial", "err", err)
		newChannel.Reject(ssh.ConnectionFailed, "failed to dial")
		return
	}
	channel, reqs, err := newChannel.Accept()
	if err != nil {
		s.Logger.Info("failed to accept", "err", err)
		conn.Close()
		return
	}
	go ssh.DiscardRequests(reqs)
	var closeOnce sync.Once
	closer := func() {
		channel.Close()
		conn.Close()
	}
	go func() {
		io.Copy(channel, conn)
		closeOnce.Do(closer)
	}()
	io.Copy(conn, channel)
	closeOnce.Do(closer)
}
// =======================
// parseDims extracts terminal dimensions (width x height) from the provided buffer.
//
// It reads two consecutive big-endian uint32 values: the width in columns,
// then the height in rows. This is the layout of the "window-change" payload
// and of the dimension fields in "pty-req" (RFC 4254, sections 6.7 and 6.2).
// Bytes after the first eight are ignored. The buffer originates from the
// client, so a buffer shorter than eight bytes yields 0, 0 instead of a panic.
func parseDims(b []byte) (uint32, uint32) {
	if len(b) < 8 {
		return 0, 0
	}
	w := binary.BigEndian.Uint32(b[0:4])
	h := binary.BigEndian.Uint32(b[4:8])
	return w, h
}
// ======================
// GenerateKey creates a new 2048-bit RSA private key, validates it, and
// returns it PEM-encoded as a PKCS #1 "RSA PRIVATE KEY" block.
func GenerateKey() ([]byte, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, err
	}
	if err := key.Validate(); err != nil {
		return nil, err
	}
	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}
	return pem.EncodeToMemory(block), nil
}
// Borrowed from https://github.com/creack/termios/blob/master/win/win.go
// ======================================================================
// HandleGlobalRequests services the global (connection-level) requests of
// sshConn until reqs is closed.
//
// Forward requests are refused unless AllowTcpipForward /
// AllowStreamlocalForward is set. Accepted ones run on their own goroutine
// because their handlers keep accepting connections for the lifetime of the
// forward.
//
// Cancel requests are gated by the same flags: with forwarding disabled this
// connection cannot have registered a listener to cancel. Cancels are handled
// inline, so their replies are sent in request order; clients such as
// golang.org/x/crypto/ssh match global replies to requests in FIFO order.
//
// Any other request is refused.
func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request) {
	for req := range reqs {
		switch req.Type {
		case "tcpip-forward":
			if !s.AllowTcpipForward {
				s.Logger.Info("tcpip-forward not allowed")
				req.Reply(false, nil)
				break
			}
			// req is passed as an argument rather than captured by a closure,
			// so each goroutine gets its own request even with pre-Go 1.22
			// shared loop variables.
			go s.handleTcpipForward(sshConn, req)
		case "cancel-tcpip-forward":
			if !s.AllowTcpipForward {
				s.Logger.Info("cancel-tcpip-forward not allowed")
				req.Reply(false, nil)
				break
			}
			s.cancelTcpipForward(req)
		case "streamlocal-forward@openssh.com":
			if !s.AllowStreamlocalForward {
				s.Logger.Info("streamlocal-forward not allowed")
				req.Reply(false, nil)
				break
			}
			go s.handleStreamlocalForward(sshConn, req)
		case "cancel-streamlocal-forward@openssh.com":
			if !s.AllowStreamlocalForward {
				s.Logger.Info("cancel-streamlocal-forward not allowed")
				req.Reply(false, nil)
				break
			}
			s.cancelStreamlocalForward(req)
		default:
			// discard
			if req.WantReply {
				req.Reply(false, nil)
			}
			s.Logger.Info("request discarded", "request_type", req.Type)
		}
	}
}
// handleTcpipForward serves a "tcpip-forward" global request (remote port
// forwarding). It does the following:
//
//   - Listens on the requested address on the server.
//   - Registers the listener in bindAddressToListener and replies to the
//     request exactly once.
//   - For every accepted TCP connection, opens a "forwarded-tcpip" channel
//     back to the client and relays bytes between the two.
//
// It returns when the listener is closed, either by a matching
// "cancel-tcpip-forward" or when sshConn terminates.
//
// When the client asks for port 0 the system picks a free port. As RFC 4254
// requires, the allocated port is returned in the success reply. It is also
// used in the forwarded-tcpip channel data and in the registration key.
//
// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
func (s *Server) handleTcpipForward(sshConn *ssh.ServerConn, req *ssh.Request) {
	var msg struct {
		Addr string
		Port uint32
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
		req.Reply(false, nil)
		return
	}
	address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
	ln, err := net.Listen("tcp", address)
	if err != nil {
		s.Logger.Info("failed to listen", "address", address, "err", err)
		req.Reply(false, nil)
		return
	}
	boundPort := msg.Port
	var replyPayload []byte
	if msg.Port == 0 {
		if tcpAddr, ok := ln.Addr().(*net.TCPAddr); ok {
			boundPort = uint32(tcpAddr.Port)
		}
		replyPayload = ssh.Marshal(struct{ Port uint32 }{boundPort})
		address = net.JoinHostPort(msg.Addr, strconv.Itoa(int(boundPort)))
	}
	s.bindAddressToListener.Store(address, ln)
	req.Reply(true, replyPayload)
	go func() {
		sshConn.Wait()
		ln.Close()
		s.Logger.Info("connection closed", "address", ln.Addr().String())
	}()
	for {
		conn, err := ln.Accept()
		if err != nil {
			s.Logger.Info("failed to accept", "err", err)
			return
		}
		var replyMsg struct {
			Addr           string
			Port           uint32
			OriginatorAddr string
			OriginatorPort uint32
		}
		replyMsg.Addr = msg.Addr
		replyMsg.Port = boundPort
		originatorAddr, originatorPortStr, err := net.SplitHostPort(conn.RemoteAddr().String())
		if err == nil {
			originatorPort, _ := strconv.Atoi(originatorPortStr)
			replyMsg.OriginatorAddr = originatorAddr
			replyMsg.OriginatorPort = uint32(originatorPort)
		} else {
			s.Logger.Error("failed to split remote address", "remote_address", conn.RemoteAddr())
		}
		go func() {
			channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(&replyMsg))
			if err != nil {
				// The global request was already answered above. Replying
				// again would send the client an unsolicited reply, so just
				// drop this connection.
				s.Logger.Info("failed to open forwarded-tcpip channel", "err", err)
				conn.Close()
				return
			}
			go ssh.DiscardRequests(reqs)
			go func() {
				io.Copy(channel, conn)
				conn.Close()
				channel.Close()
			}()
			go func() {
				io.Copy(conn, channel)
				conn.Close()
				channel.Close()
			}()
		}()
	}
}
// cancelTcpipForward serves a "cancel-tcpip-forward" global request. It
// removes the listener registered for the given address by
// handleTcpipForward and closes it, which ends that forward's accept loop.
// Exactly one reply is sent: false if the payload is malformed, no listener
// is registered for the address, or closing fails; true otherwise.
//
// https://datatracker.ietf.org/doc/html/rfc4254#section-7.1
func (s *Server) cancelTcpipForward(req *ssh.Request) {
	var msg struct {
		Addr string
		Port uint32
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
		req.Reply(false, nil)
		return
	}
	address := net.JoinHostPort(msg.Addr, strconv.Itoa(int(msg.Port)))
	ln, loaded := s.bindAddressToListener.LoadAndDelete(address)
	if !loaded {
		// ln is not usable here (presumably a nil net.Listener), so calling
		// Close on it would panic.
		s.Logger.Info("failed to find listener", "address", address)
		req.Reply(false, nil)
		return
	}
	if err := ln.Close(); err != nil {
		s.Logger.Info("failed to close", "err", err)
		req.Reply(false, nil)
		return
	}
	req.Reply(true, nil)
}
// handleStreamlocalForward serves a "streamlocal-forward@openssh.com" global
// request (remote forwarding of a Unix domain socket). It does the following:
//
//   - Listens on the requested socket path on the server.
//   - Registers the listener in bindAddressToListener under that path and
//     replies to the request exactly once.
//   - For every accepted connection, opens a
//     "forwarded-streamlocal@openssh.com" channel back to the client and
//     relays bytes between the two.
//
// It returns when the listener is closed, either by a matching cancel
// request or when sshConn terminates.
//
// client side: https://github.com/golang/crypto/blob/b4ddeeda5bc71549846db71ba23e83ecb26f36ed/ssh/streamlocal.go#L34
func (s *Server) handleStreamlocalForward(sshConn *ssh.ServerConn, req *ssh.Request) {
	// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L272
	var msg struct {
		SocketPath string
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
		req.Reply(false, nil)
		return
	}
	ln, err := net.Listen("unix", msg.SocketPath)
	if err != nil {
		s.Logger.Info("failed to listen", "address", msg.SocketPath, "err", err)
		req.Reply(false, nil)
		return
	}
	s.bindAddressToListener.Store(msg.SocketPath, ln)
	req.Reply(true, nil)
	go func() {
		sshConn.Wait()
		ln.Close()
		s.Logger.Info("connection closed", "address", ln.Addr().String())
	}()
	for {
		conn, err := ln.Accept()
		if err != nil {
			s.Logger.Info("failed to accept", "err", err)
			return
		}
		// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L255
		var replyMsg struct {
			SocketPath string
			Reserved   string
		}
		replyMsg.SocketPath = msg.SocketPath
		go func() {
			channel, reqs, err := sshConn.OpenChannel("forwarded-streamlocal@openssh.com", ssh.Marshal(&replyMsg))
			if err != nil {
				// The global request was already answered above. Replying
				// again would send the client an unsolicited reply, so just
				// drop this connection.
				s.Logger.Info("failed to open forwarded-streamlocal channel", "err", err)
				conn.Close()
				return
			}
			go ssh.DiscardRequests(reqs)
			go func() {
				io.Copy(channel, conn)
				conn.Close()
				channel.Close()
			}()
			go func() {
				io.Copy(conn, channel)
				conn.Close()
				channel.Close()
			}()
		}()
	}
}
// cancelStreamlocalForward serves a "cancel-streamlocal-forward@openssh.com"
// global request. It removes the listener registered for the socket path by
// handleStreamlocalForward and closes it, which ends that forward's accept
// loop. Exactly one reply is sent: false if the payload is malformed, no
// listener is registered for the path, or closing fails; true otherwise.
func (s *Server) cancelStreamlocalForward(req *ssh.Request) {
	// https://github.com/openssh/openssh-portable/blob/f9f18006678d2eac8b0c5a5dddf17ab7c50d1e9f/PROTOCOL#L280
	var msg struct {
		SocketPath string
	}
	if err := ssh.Unmarshal(req.Payload, &msg); err != nil {
		req.Reply(false, nil)
		return
	}
	ln, loaded := s.bindAddressToListener.LoadAndDelete(msg.SocketPath)
	if !loaded {
		s.Logger.Info("failed to find listener", "address", msg.SocketPath)
		req.Reply(false, nil)
		return
	}
	if err := ln.Close(); err != nil {
		// Return here so the request is not answered a second time below.
		s.Logger.Info("failed to close", "err", err)
		req.Reply(false, nil)
		return
	}
	req.Reply(true, nil)
}