diff --git a/cmd/root.go b/cmd/root.go index 18c8db3..c964a03 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -15,11 +15,12 @@ import ( var flag struct { //dnsServer string - showsVersion bool - sshHost string - sshPort uint16 - sshShell string - sshUsers []string + showsVersion bool + sshHost string + sshPort uint16 + sshUnixSocket string + sshShell string + sshUsers []string } type sshUser struct { @@ -32,6 +33,8 @@ func init() { RootCmd.PersistentFlags().BoolVarP(&flag.showsVersion, "version", "v", false, "show version") RootCmd.PersistentFlags().StringVarP(&flag.sshHost, "host", "", "", "SSH server host (e.g. 127.0.0.1)") RootCmd.PersistentFlags().Uint16VarP(&flag.sshPort, "port", "p", 2222, "SSH server port") + // NOTE: long name 'unix-socket' is from curl (ref: https://curl.se/docs/manpage.html) + RootCmd.PersistentFlags().StringVarP(&flag.sshUnixSocket, "unix-socket", "", "", "Unix-domain socket") RootCmd.PersistentFlags().StringVarP(&flag.sshShell, "shell", "", "", "Shell") //RootCmd.PersistentFlags().StringVar(&flag.dnsServer, "dns-server", "", "DNS server (e.g. 1.1.1.1:53)") RootCmd.PersistentFlags().StringArrayVarP(&flag.sshUsers, "user", "", nil, `SSH user name (e.g. "john:mypassword")`) @@ -90,13 +93,22 @@ var RootCmd = &cobra.Command{ } sshConfig.AddHostKey(pri) - // TODO: unix socket support - address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) - ln, err := net.Listen("tcp", address) - if err != nil { - return err + var ln net.Listener + if flag.sshUnixSocket == "" { + address := net.JoinHostPort(flag.sshHost, strconv.Itoa(int(flag.sshPort))) + ln, err = net.Listen("tcp", address) + if err != nil { + return err + } + logger.Info(fmt.Sprintf("listening on %s...", address)) + } else { + ln, err = net.Listen("unix", flag.sshUnixSocket) + if err != nil { + return err + } + logger.Info(fmt.Sprintf("listening on %s...", flag.sshUnixSocket)) } - logger.Info(fmt.Sprintf("listening on %s...", address)) + defer ln.Close() for { conn, err := ln.Accept() if err != nil {