support --unix-socket

This commit is contained in:
Ryo Ota 2023-08-09 08:37:54 +09:00
parent 101f99e80e
commit e4f1e0989d

View file

@ -15,11 +15,12 @@ import (
var flag struct { var flag struct {
//dnsServer string //dnsServer string
showsVersion bool showsVersion bool
sshHost string sshHost string
sshPort uint16 sshPort uint16
sshShell string sshUnixSocket string
sshUsers []string sshShell string
sshUsers []string
} }
type sshUser struct { type sshUser struct {
@ -32,6 +33,8 @@ func init() {
RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version")
RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)")
RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port")
// NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html)
RootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket")
RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell")
//RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") //RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)")
RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`)
@ -90,13 +93,22 @@ var RootCmd = &cobra.Command{
} }
sshConfig.AddHostKey(pri) sshConfig.AddHostKey(pri)
// TODO: unix socket support var ln net.Listener
address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) if flag.sshUnixSocket == "" {
ln, err := net.Listen("tcp", address) address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort)))
if err != nil { ln, err = net.Listen("tcp", address)
return err if err != nil {
return err
}
logger.Info(fmt.Sprintf("listening on %s...", address))
} else {
ln, err = net.Listen("unix", flag.sshUnixSocket)
if err != nil {
return err
}
logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket))
} }
logger.Info(fmt.Sprintf("listening on %s...", address)) defer ln.Close()
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {