251 lines
6.7 KiB
Go
251 lines
6.7 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/binary"
|
|
"encoding/pem"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"os/exec"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/term"
|
|
"github.com/creack/pty/v2"
|
|
)
|
|
|
|
func main() {
|
|
if os.Getenv("COMMAND") == "" {
|
|
log.Fatal("COMMAND environment variable must be set")
|
|
}
|
|
config := &ssh.ServerConfig{
|
|
Config: ssh.Config{
|
|
KeyExchanges: []string{"mlkem768x25519-sha256", "curve25519-sha256", "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521", "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512"},
|
|
},
|
|
NoClientAuth: true,
|
|
}
|
|
var signer ssh.Signer
|
|
keyFile := "/app/host_key"
|
|
if data, err := os.ReadFile(keyFile); err == nil {
|
|
signer, err = ssh.ParsePrivateKey(data)
|
|
if err != nil {
|
|
log.Fatal("Failed to parse existing host key:", err)
|
|
}
|
|
} else {
|
|
_, key, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
signer, err = ssh.NewSignerFromKey(key)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
block, err := ssh.MarshalPrivateKey(key, "")
|
|
if err != nil {
|
|
log.Fatal("Failed to marshal host key:", err)
|
|
}
|
|
privateKeyBytes := pem.EncodeToMemory(block)
|
|
if err := os.WriteFile(keyFile, privateKeyBytes, 0600); err != nil {
|
|
log.Fatal("Failed to save host key:", err)
|
|
}
|
|
}
|
|
config.AddHostKey(signer)
|
|
listener, err := net.Listen("tcp", ":22")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
log.Println("SSH server listening on :22")
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Println("Accept error:", err)
|
|
continue
|
|
}
|
|
go handleConn(conn, config)
|
|
}
|
|
}
|
|
|
|
func handleConn(conn net.Conn, config *ssh.ServerConfig) {
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
|
if err != nil {
|
|
log.Println("ServerConn error:", err)
|
|
conn.Close()
|
|
return
|
|
}
|
|
if acm, ok := sshConn.Conn.(ssh.AlgorithmsConnMetadata); ok {
|
|
log.Println("Negotiated KEX:", acm.Algorithms().KeyExchange)
|
|
}
|
|
log.Println("New connection from", sshConn.RemoteAddr(), "user", sshConn.User())
|
|
go ssh.DiscardRequests(reqs)
|
|
for newChannel := range chans {
|
|
if newChannel.ChannelType() != "session" {
|
|
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
log.Println("Channel accept error:", err)
|
|
continue
|
|
}
|
|
go handleChannel(channel, requests)
|
|
}
|
|
sshConn.Wait()
|
|
}
|
|
|
|
func handleChannel(channel ssh.Channel, requests <-chan *ssh.Request) {
|
|
defer channel.Close()
|
|
var ptmx *os.File
|
|
var termWidth, termHeight uint32 = 80, 24
|
|
var clientTerm string
|
|
var ptyAllocated bool // Track if PTY is allocated
|
|
|
|
for req := range requests {
|
|
switch req.Type {
|
|
case "env":
|
|
// Handle environment requests (e.g., client TERM propagation)
|
|
if len(req.Payload) >= 7 && string(req.Payload[:7]) == "TERM=\x00" {
|
|
termLen := int(req.Payload[7])
|
|
if len(req.Payload) >= 8+termLen {
|
|
clientTerm = string(req.Payload[8 : 8+termLen])
|
|
log.Println("Client TERM from env:", clientTerm)
|
|
}
|
|
}
|
|
req.Reply(true, nil)
|
|
case "pty-req":
|
|
// Allocate PTY early on pty-req, before shell/exec
|
|
if ptyAllocated {
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
if len(req.Payload) >= 4 {
|
|
termLen := binary.BigEndian.Uint32(req.Payload[0:4])
|
|
if len(req.Payload) >= int(4+termLen+16) {
|
|
clientTerm = string(req.Payload[4 : 4+termLen])
|
|
log.Println("Client TERM from pty-req:", clientTerm)
|
|
cols := binary.BigEndian.Uint32(req.Payload[4+termLen : 4+termLen+4])
|
|
rows := binary.BigEndian.Uint32(req.Payload[4+termLen+4 : 4+termLen+8])
|
|
if cols > 0 {
|
|
termWidth = cols
|
|
}
|
|
if rows > 0 {
|
|
termHeight = rows
|
|
}
|
|
}
|
|
}
|
|
// Set default TERM if not provided (TUI-compatible)
|
|
if clientTerm == "" {
|
|
clientTerm = "xterm-256color"
|
|
}
|
|
command := os.Getenv("COMMAND")
|
|
if command == "" {
|
|
command = "/app/tui"
|
|
}
|
|
cmd := exec.Command(command)
|
|
envTerm := "TERM=" + clientTerm
|
|
cmd.Env = []string{"PATH=/bin", envTerm}
|
|
cmd.Dir = "/"
|
|
var err error
|
|
ptmx, err = pty.StartWithSize(cmd, &pty.Winsize{Cols: uint16(termWidth), Rows: uint16(termHeight)})
|
|
if err != nil {
|
|
log.Println("PTY start error:", err)
|
|
req.Reply(false, nil)
|
|
return
|
|
}
|
|
// Make raw mode on master (server side)
|
|
if _, err := term.MakeRaw(int(ptmx.Fd())); err != nil {
|
|
log.Println("MakeRaw master error:", err)
|
|
}
|
|
// Note: Slave (cmd side) is already raw via pty.Start, but ensure via setsid if needed
|
|
ptyAllocated = true
|
|
// Start I/O bridging immediately
|
|
go func() {
|
|
defer func() {
|
|
ptmx.Close()
|
|
channel.Close()
|
|
}()
|
|
// Bidirectional copy with error handling
|
|
done := make(chan error, 2)
|
|
go func() { done <- io.Copy(channel, ptmx) }()
|
|
go func() { done <- io.Copy(ptmx, channel) }()
|
|
<-done // Wait for one to finish
|
|
cmd.Process.Signal(os.Interrupt) // Graceful shutdown
|
|
<-done
|
|
}()
|
|
req.Reply(true, nil)
|
|
// Wait for cmd to finish after PTY setup
|
|
go func() {
|
|
cmd.Wait()
|
|
channel.SendRequest("exit-status", false, []byte{0}) // Send exit status
|
|
}()
|
|
return // PTY session started, no more requests
|
|
case "window-change":
|
|
// Handle resizes post-PTY allocation
|
|
if ptyAllocated && ptmx != nil {
|
|
width := binary.BigEndian.Uint32(req.Payload)
|
|
height := binary.BigEndian.Uint32(req.Payload[4:])
|
|
if width > 0 {
|
|
termWidth = width
|
|
}
|
|
if height > 0 {
|
|
termHeight = height
|
|
}
|
|
pty.Setsize(ptmx, &pty.Winsize{Cols: uint16(termWidth), Rows: uint16(termHeight)})
|
|
}
|
|
req.Reply(true, nil)
|
|
case "shell":
|
|
// Only handle if no PTY (fallback, but PTY should be allocated first)
|
|
if !ptyAllocated {
|
|
req.Reply(true, nil)
|
|
// Fallback to original shell logic if needed
|
|
log.Println("Shell requested without PTY")
|
|
} else {
|
|
req.Reply(false, nil)
|
|
}
|
|
case "exec":
|
|
// Handle exec separately if no PTY
|
|
if ptyAllocated {
|
|
req.Reply(false, nil)
|
|
continue
|
|
}
|
|
req.Reply(true, nil)
|
|
command := string(req.Payload[4:])
|
|
runCommand(channel, command)
|
|
return
|
|
default:
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
|
|
func runCommand(channel ssh.Channel, command string) {
|
|
defer channel.Close()
|
|
cmd := exec.Command("/bin/sh", "-c", command)
|
|
cmd.Env = []string{"PATH=/bin"}
|
|
cmd.Dir = "/"
|
|
stdin, err := cmd.StdinPipe()
|
|
if err != nil {
|
|
log.Println("StdinPipe error:", err)
|
|
return
|
|
}
|
|
stdout, err := cmd.StdoutPipe()
|
|
if err != nil {
|
|
log.Println("StdoutPipe error:", err)
|
|
return
|
|
}
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
log.Println("StderrPipe error:", err)
|
|
return
|
|
}
|
|
if err := cmd.Start(); err != nil {
|
|
log.Println("Start error:", err)
|
|
return
|
|
}
|
|
go func() { io.Copy(stdin, channel); stdin.Close() }()
|
|
go func() { io.Copy(channel, stdout); stdout.Close() }()
|
|
go func() { io.Copy(channel, stderr); stderr.Close() }()
|
|
cmd.Wait()
|
|
}
|
|
|