mirror of
https://github.com/nwtgck/handy-sshd.git
synced 2025-06-07 22:53:05 +00:00
implement permissions
This commit is contained in:
parent
0796e5e2d9
commit
1143b17071
2 changed files with 100 additions and 4 deletions
66
cmd/root.go
66
cmd/root.go
|
@ -21,6 +21,21 @@ var flag struct {
|
|||
sshUnixSocket string
|
||||
sshShell string
|
||||
sshUsers []string
|
||||
|
||||
allowTcpipForward bool
|
||||
allowDirectTcpip bool
|
||||
allowExecute bool
|
||||
allowSftp bool
|
||||
}
|
||||
|
||||
var allPermissionFlags = []struct {
|
||||
name string
|
||||
flagPtr *bool
|
||||
}{
|
||||
{name: "tcpip-forward", flagPtr: &flag.allowTcpipForward},
|
||||
{name: "direct-tcpip", flagPtr: &flag.allowDirectTcpip},
|
||||
{name: "execute", flagPtr: &flag.allowExecute},
|
||||
{name: "sftp", flagPtr: &flag.allowSftp},
|
||||
}
|
||||
|
||||
type sshUser struct {
|
||||
|
@ -38,6 +53,12 @@ func init() {
|
|||
RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell")
|
||||
//RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)")
|
||||
RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`)
|
||||
|
||||
// Permission flags
|
||||
RootCmd.PersistentFlags().BoolVarP(&flag.allowTcpipForward, "allow-tcpip-forward", "", false, "client can use remote forwarding")
|
||||
RootCmd.PersistentFlags().BoolVarP(&flag.allowDirectTcpip, "allow-direct-tcpip", "", false, "client can use local forwarding and SOCKS proxy")
|
||||
RootCmd.PersistentFlags().BoolVarP(&flag.allowExecute, "allow-execute", "", false, "client can use shell/interactive shell")
|
||||
RootCmd.PersistentFlags().BoolVarP(&flag.allowSftp, "allow-sftp", "", false, "client can use SFTP and SSHFS")
|
||||
}
|
||||
|
||||
var RootCmd = &cobra.Command{
|
||||
|
@ -51,8 +72,26 @@ var RootCmd = &cobra.Command{
|
|||
return nil
|
||||
}
|
||||
logger := slog.Default()
|
||||
|
||||
// Allow all permissions if all permission is not set
|
||||
{
|
||||
allPermissionFalse := true
|
||||
for _, permissionFlag := range allPermissionFlags {
|
||||
allPermissionFalse = allPermissionFalse && !*permissionFlag.flagPtr
|
||||
}
|
||||
if allPermissionFalse {
|
||||
for _, permissionFlag := range allPermissionFlags {
|
||||
*permissionFlag.flagPtr = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sshServer := &handy_sshd.Server{
|
||||
Logger: logger,
|
||||
Logger: logger,
|
||||
AllowTcpipForward: flag.allowTcpipForward,
|
||||
AllowDirectTcpip: flag.allowDirectTcpip,
|
||||
AllowExecute: flag.allowExecute,
|
||||
AllowSftp: flag.allowSftp,
|
||||
}
|
||||
var sshUsers []sshUser
|
||||
for _, u := range flag.sshUsers {
|
||||
|
@ -109,6 +148,9 @@ var RootCmd = &cobra.Command{
|
|||
logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket))
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
showPermissions(logger)
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
|
@ -121,9 +163,29 @@ var RootCmd = &cobra.Command{
|
|||
conn.Close()
|
||||
continue
|
||||
}
|
||||
logger.Info("new SSH connection", "client_version", string(sshConn.ClientVersion()))
|
||||
logger.Info("new SSH connection", "remote_address", sshConn.RemoteAddr(), "client_version", string(sshConn.ClientVersion()))
|
||||
go sshServer.HandleGlobalRequests(sshConn, reqs)
|
||||
go sshServer.HandleChannels(flag.sshShell, chans)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func showPermissions(logger *slog.Logger) {
|
||||
var allowedList []string
|
||||
var notAllowedList []string
|
||||
for _, permissionFlag := range allPermissionFlags {
|
||||
if *permissionFlag.flagPtr {
|
||||
allowedList = append(allowedList, `"`+permissionFlag.name+`"`)
|
||||
} else {
|
||||
notAllowedList = append(notAllowedList, `"`+permissionFlag.name+`"`)
|
||||
}
|
||||
}
|
||||
showList := func(l []string) string {
|
||||
if len(l) == 0 {
|
||||
return "none"
|
||||
}
|
||||
return strings.Join(l, ", ")
|
||||
}
|
||||
logger.Info(fmt.Sprintf("allowed: %s", showList(allowedList)))
|
||||
logger.Info(fmt.Sprintf("NOT allowed: %s", showList(notAllowedList)))
|
||||
}
|
||||
|
|
38
server.go
38
server.go
|
@ -28,6 +28,13 @@ import (
|
|||
|
||||
type Server struct {
|
||||
Logger *slog.Logger
|
||||
|
||||
// Permissions
|
||||
AllowTcpipForward bool
|
||||
AllowDirectTcpip bool
|
||||
AllowExecute bool // this should not be split into "allow-exec" and "allow-pty-req" for now because "pty-req" can be used not for shell execution.
|
||||
AllowSftp bool
|
||||
|
||||
// TODO: DNS server ?
|
||||
}
|
||||
|
||||
|
@ -47,6 +54,10 @@ func (s *Server) handleChannel(shell string, newChannel ssh.NewChannel) {
|
|||
case "session":
|
||||
s.handleSession(shell, newChannel)
|
||||
case "direct-tcpip":
|
||||
if !s.AllowDirectTcpip {
|
||||
newChannel.Reject(ssh.Prohibited, "direct-tcpip not allowed")
|
||||
break
|
||||
}
|
||||
s.handleDirectTcpip(newChannel)
|
||||
default:
|
||||
newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", newChannel.ChannelType()))
|
||||
|
@ -67,6 +78,11 @@ func (s *Server) handleSession(shell string, newChannel ssh.NewChannel) {
|
|||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "exec":
|
||||
if !s.AllowExecute {
|
||||
s.Logger.Info("execution not allowed (exec)")
|
||||
req.Reply(false, nil)
|
||||
break
|
||||
}
|
||||
s.handleExecRequest(req, connection)
|
||||
case "shell":
|
||||
// We only accept the default shell
|
||||
|
@ -75,6 +91,11 @@ func (s *Server) handleSession(shell string, newChannel ssh.NewChannel) {
|
|||
req.Reply(true, nil)
|
||||
}
|
||||
case "pty-req":
|
||||
if !s.AllowExecute {
|
||||
s.Logger.Info("execution not allowed (pty-req)")
|
||||
req.Reply(false, nil)
|
||||
break
|
||||
}
|
||||
termLen := req.Payload[3]
|
||||
w, h := parseDims(req.Payload[termLen+4:])
|
||||
shf, err = s.createPty(shell, connection)
|
||||
|
@ -140,9 +161,17 @@ func (s *Server) handleExecRequest(req *ssh.Request, connection ssh.Channel) {
|
|||
|
||||
func (s *Server) handleSessionSubSystem(req *ssh.Request, connection ssh.Channel) {
|
||||
// https://github.com/pkg/sftp/blob/42e9800606febe03f9cdf1d1283719af4a5e6456/examples/go-sftp-server/main.go#L111
|
||||
ok := string(req.Payload[4:]) == "sftp"
|
||||
req.Reply(ok, nil)
|
||||
if string(req.Payload[4:]) != "sftp" {
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
if !s.AllowSftp {
|
||||
s.Logger.Info("sftp not allowed")
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
|
||||
req.Reply(true, nil)
|
||||
serverOptions := []sftp.ServerOption{
|
||||
sftp.WithDebug(os.Stderr),
|
||||
}
|
||||
|
@ -232,6 +261,11 @@ func (s *Server) HandleGlobalRequests(sshConn *ssh.ServerConn, reqs <-chan *ssh.
|
|||
for req := range reqs {
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
if !s.AllowTcpipForward {
|
||||
s.Logger.Info("tcpip-forward not allowed")
|
||||
req.Reply(false, nil)
|
||||
break
|
||||
}
|
||||
s.handleTcpipForward(sshConn, req)
|
||||
default:
|
||||
// discard
|
||||
|
|
Loading…
Add table
Reference in a new issue