Implement all together
This commit is contained in:
428
main.go
428
main.go
@@ -1,250 +1,236 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/pem"
|
||||
"io"
|
||||
"log"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
"github.com/creack/pty/v2"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/charmbracelet/ssh"
|
||||
"github.com/charmbracelet/wish"
|
||||
"github.com/charmbracelet/wish/activeterm"
|
||||
"github.com/charmbracelet/wish/bubbletea"
|
||||
"github.com/charmbracelet/wish/logging"
|
||||
"github.com/charmbracelet/wish/recover"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if os.Getenv("COMMAND") == "" {
|
||||
log.Fatal("COMMAND environment variable must be set")
|
||||
}
|
||||
config := &ssh.ServerConfig{
|
||||
Config: ssh.Config{
|
||||
KeyExchanges: []string{"mlkem768x25519-sha256", "curve25519-sha256", "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521", "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512"},
|
||||
},
|
||||
NoClientAuth: true,
|
||||
}
|
||||
var signer ssh.Signer
|
||||
keyFile := "/app/host_key"
|
||||
if data, err := os.ReadFile(keyFile); err == nil {
|
||||
signer, err = ssh.ParsePrivateKey(data)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to parse existing host key:", err)
|
||||
}
|
||||
} else {
|
||||
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
signer, err = ssh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
block, err := ssh.MarshalPrivateKey(key, "")
|
||||
if err != nil {
|
||||
log.Fatal("Failed to marshal host key:", err)
|
||||
}
|
||||
privateKeyBytes := pem.EncodeToMemory(block)
|
||||
if err := os.WriteFile(keyFile, privateKeyBytes, 0600); err != nil {
|
||||
log.Fatal("Failed to save host key:", err)
|
||||
}
|
||||
}
|
||||
config.AddHostKey(signer)
|
||||
listener, err := net.Listen("tcp", ":22")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Println("SSH server listening on :22")
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
log.Println("Accept error:", err)
|
||||
continue
|
||||
}
|
||||
go handleConn(conn, config)
|
||||
type Theme struct {
|
||||
Bg lipgloss.Color
|
||||
Fg lipgloss.Color
|
||||
Red lipgloss.Color
|
||||
Green lipgloss.Color
|
||||
Yellow lipgloss.Color
|
||||
Blue lipgloss.Color
|
||||
Purple lipgloss.Color
|
||||
Aqua lipgloss.Color
|
||||
Orange lipgloss.Color
|
||||
Gray lipgloss.Color
|
||||
}
|
||||
|
||||
func gruvboxDark() Theme {
|
||||
return Theme{
|
||||
Bg: lipgloss.Color("#282828"),
|
||||
Fg: lipgloss.Color("#ebdbb2"),
|
||||
Red: lipgloss.Color("#cc241d"),
|
||||
Green: lipgloss.Color("#98971a"),
|
||||
Yellow: lipgloss.Color("#d79921"),
|
||||
Blue: lipgloss.Color("#458588"),
|
||||
Purple: lipgloss.Color("#b16286"),
|
||||
Aqua: lipgloss.Color("#689d6a"),
|
||||
Orange: lipgloss.Color("#d65d0e"),
|
||||
Gray: lipgloss.Color("#928374"),
|
||||
}
|
||||
}
|
||||
|
||||
func handleConn(conn net.Conn, config *ssh.ServerConfig) {
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, config)
|
||||
if err != nil {
|
||||
log.Println("ServerConn error:", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if acm, ok := sshConn.Conn.(ssh.AlgorithmsConnMetadata); ok {
|
||||
log.Println("Negotiated KEX:", acm.Algorithms().KeyExchange)
|
||||
}
|
||||
log.Println("New connection from", sshConn.RemoteAddr(), "user", sshConn.User())
|
||||
go ssh.DiscardRequests(reqs)
|
||||
for newChannel := range chans {
|
||||
if newChannel.ChannelType() != "session" {
|
||||
newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
|
||||
continue
|
||||
}
|
||||
channel, requests, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
log.Println("Channel accept error:", err)
|
||||
continue
|
||||
}
|
||||
go handleChannel(channel, requests)
|
||||
}
|
||||
sshConn.Wait()
|
||||
type model struct {
|
||||
showWelcome bool
|
||||
focus int
|
||||
showHelp bool
|
||||
theme Theme
|
||||
blink bool
|
||||
blinkCount int
|
||||
}
|
||||
|
||||
func handleChannel(channel ssh.Channel, requests <-chan *ssh.Request) {
|
||||
defer channel.Close()
|
||||
var ptmx *os.File
|
||||
var termWidth, termHeight uint32 = 80, 24
|
||||
var clientTerm string
|
||||
var ptyAllocated bool // Track if PTY is allocated
|
||||
const numBoxes = 4
|
||||
|
||||
for req := range requests {
|
||||
switch req.Type {
|
||||
case "env":
|
||||
// Handle environment requests (e.g., client TERM propagation)
|
||||
if len(req.Payload) >= 7 && string(req.Payload[:7]) == "TERM=\x00" {
|
||||
termLen := int(req.Payload[7])
|
||||
if len(req.Payload) >= 8+termLen {
|
||||
clientTerm = string(req.Payload[8 : 8+termLen])
|
||||
log.Println("Client TERM from env:", clientTerm)
|
||||
}
|
||||
func (m model) Init() tea.Cmd {
|
||||
return tea.Tick(time.Millisecond*70, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
type tickMsg time.Time
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tickMsg:
|
||||
m.blinkCount++
|
||||
m.blink = (m.blinkCount / 3) % 2 == 0
|
||||
return m, tea.Tick(time.Millisecond*70, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
case tea.KeyMsg:
|
||||
if m.showWelcome {
|
||||
m.showWelcome = false
|
||||
return m, nil
|
||||
}
|
||||
if m.showHelp {
|
||||
switch msg.String() {
|
||||
case "q", "esc", "?", "enter", "backspace":
|
||||
m.showHelp = false
|
||||
}
|
||||
req.Reply(true, nil)
|
||||
case "pty-req":
|
||||
// Allocate PTY early on pty-req, before shell/exec
|
||||
if ptyAllocated {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
}
|
||||
if len(req.Payload) >= 4 {
|
||||
termLen := binary.BigEndian.Uint32(req.Payload[0:4])
|
||||
if len(req.Payload) >= int(4+termLen+16) {
|
||||
clientTerm = string(req.Payload[4 : 4+termLen])
|
||||
log.Println("Client TERM from pty-req:", clientTerm)
|
||||
cols := binary.BigEndian.Uint32(req.Payload[4+termLen : 4+termLen+4])
|
||||
rows := binary.BigEndian.Uint32(req.Payload[4+termLen+4 : 4+termLen+8])
|
||||
if cols > 0 {
|
||||
termWidth = cols
|
||||
}
|
||||
if rows > 0 {
|
||||
termHeight = rows
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set default TERM if not provided (TUI-compatible)
|
||||
if clientTerm == "" {
|
||||
clientTerm = "xterm-256color"
|
||||
}
|
||||
command := os.Getenv("COMMAND")
|
||||
if command == "" {
|
||||
command = "/app/tui"
|
||||
}
|
||||
cmd := exec.Command(command)
|
||||
envTerm := "TERM=" + clientTerm
|
||||
cmd.Env = []string{"PATH=/bin", envTerm}
|
||||
cmd.Dir = "/"
|
||||
var err error
|
||||
ptmx, err = pty.StartWithSize(cmd, &pty.Winsize{Cols: uint16(termWidth), Rows: uint16(termHeight)})
|
||||
if err != nil {
|
||||
log.Println("PTY start error:", err)
|
||||
req.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
// Make raw mode on master (server side)
|
||||
if _, err := term.MakeRaw(int(ptmx.Fd())); err != nil {
|
||||
log.Println("MakeRaw master error:", err)
|
||||
}
|
||||
// Note: Slave (cmd side) is already raw via pty.Start, but ensure via setsid if needed
|
||||
ptyAllocated = true
|
||||
// Start I/O bridging immediately
|
||||
go func() {
|
||||
defer func() {
|
||||
ptmx.Close()
|
||||
channel.Close()
|
||||
}()
|
||||
// Bidirectional copy with error handling
|
||||
done := make(chan error, 2)
|
||||
go func() { done <- io.Copy(channel, ptmx) }()
|
||||
go func() { done <- io.Copy(ptmx, channel) }()
|
||||
<-done // Wait for one to finish
|
||||
cmd.Process.Signal(os.Interrupt) // Graceful shutdown
|
||||
<-done
|
||||
}()
|
||||
req.Reply(true, nil)
|
||||
// Wait for cmd to finish after PTY setup
|
||||
go func() {
|
||||
cmd.Wait()
|
||||
channel.SendRequest("exit-status", false, []byte{0}) // Send exit status
|
||||
}()
|
||||
return // PTY session started, no more requests
|
||||
case "window-change":
|
||||
// Handle resizes post-PTY allocation
|
||||
if ptyAllocated && ptmx != nil {
|
||||
width := binary.BigEndian.Uint32(req.Payload)
|
||||
height := binary.BigEndian.Uint32(req.Payload[4:])
|
||||
if width > 0 {
|
||||
termWidth = width
|
||||
}
|
||||
if height > 0 {
|
||||
termHeight = height
|
||||
}
|
||||
pty.Setsize(ptmx, &pty.Winsize{Cols: uint16(termWidth), Rows: uint16(termHeight)})
|
||||
}
|
||||
req.Reply(true, nil)
|
||||
case "shell":
|
||||
// Only handle if no PTY (fallback, but PTY should be allocated first)
|
||||
if !ptyAllocated {
|
||||
req.Reply(true, nil)
|
||||
// Fallback to original shell logic if needed
|
||||
log.Println("Shell requested without PTY")
|
||||
return m, nil
|
||||
}
|
||||
switch msg.String() {
|
||||
case "h", "left":
|
||||
m.focus = (m.focus + 3) % numBoxes
|
||||
case "l", "right":
|
||||
m.focus = (m.focus + 1) % numBoxes
|
||||
case "j", "down":
|
||||
if m.focus < 2 {
|
||||
m.focus += 2
|
||||
} else {
|
||||
req.Reply(false, nil)
|
||||
m.focus -= 2
|
||||
}
|
||||
case "exec":
|
||||
// Handle exec separately if no PTY
|
||||
if ptyAllocated {
|
||||
req.Reply(false, nil)
|
||||
continue
|
||||
case "k", "up":
|
||||
if m.focus >= 2 {
|
||||
m.focus -= 2
|
||||
} else {
|
||||
m.focus += 2
|
||||
}
|
||||
req.Reply(true, nil)
|
||||
command := string(req.Payload[4:])
|
||||
runCommand(channel, command)
|
||||
return
|
||||
default:
|
||||
req.Reply(false, nil)
|
||||
case "?":
|
||||
m.showHelp = true
|
||||
case "q", "esc":
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func runCommand(channel ssh.Channel, command string) {
|
||||
defer channel.Close()
|
||||
cmd := exec.Command("/bin/sh", "-c", command)
|
||||
cmd.Env = []string{"PATH=/bin"}
|
||||
cmd.Dir = "/"
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
log.Println("StdinPipe error:", err)
|
||||
return
|
||||
func (m model) View() string {
|
||||
if m.showWelcome {
|
||||
fg := lipgloss.NewStyle().Foreground(m.theme.Fg).Background(m.theme.Bg)
|
||||
accentFg := lipgloss.NewStyle().Foreground(m.theme.Blue).Background(m.theme.Bg)
|
||||
if m.blink {
|
||||
accentFg = lipgloss.NewStyle().Foreground(m.theme.Bg).Background(m.theme.Bg)
|
||||
}
|
||||
welcome := fg.Render("Welcome to ") +
|
||||
lipgloss.NewStyle().Foreground(m.theme.Orange).Background(m.theme.Bg).Render("dcorral.com") +
|
||||
fg.Render("!") + "\n" +
|
||||
fg.Render("press ") + accentFg.Render("any key") + fg.Render(" to continue")
|
||||
return welcome
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
log.Println("StdoutPipe error:", err)
|
||||
return
|
||||
if m.showHelp {
|
||||
return lipgloss.NewStyle().Foreground(m.theme.Fg).Background(m.theme.Bg).Render(
|
||||
"Navigation:\n" +
|
||||
"h / ←: Move left\n" +
|
||||
"l / →: Move right\n" +
|
||||
"j / ↓: Move down\n" +
|
||||
"k / ↑: Move up\n" +
|
||||
"? : Show help\n" +
|
||||
"q / Esc: Quit",
|
||||
)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
log.Println("StderrPipe error:", err)
|
||||
return
|
||||
// Render the boxes
|
||||
boxes := make([]string, numBoxes)
|
||||
for i := range boxes {
|
||||
var style lipgloss.Style
|
||||
if i == m.focus {
|
||||
style = lipgloss.NewStyle().Background(m.theme.Blue).Foreground(m.theme.Bg).Border(lipgloss.NormalBorder()).Padding(0, 1)
|
||||
} else {
|
||||
style = lipgloss.NewStyle().Background(m.theme.Bg).Foreground(m.theme.Fg).Border(lipgloss.NormalBorder()).Padding(0, 1)
|
||||
}
|
||||
boxes[i] = style.Render("Box " + string(rune('0'+i)))
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
log.Println("Start error:", err)
|
||||
return
|
||||
}
|
||||
go func() { io.Copy(stdin, channel); stdin.Close() }()
|
||||
go func() { io.Copy(channel, stdout); stdout.Close() }()
|
||||
go func() { io.Copy(channel, stderr); stderr.Close() }()
|
||||
cmd.Wait()
|
||||
grid := lipgloss.JoinHorizontal(lipgloss.Top, boxes[0], " ", boxes[1]) + "\n" +
|
||||
lipgloss.JoinHorizontal(lipgloss.Top, boxes[2], " ", boxes[3]) + "\n" +
|
||||
lipgloss.NewStyle().Foreground(m.theme.Fg).Background(m.theme.Bg).Render("Use hjkl or arrows to navigate, ? for help, q to quit")
|
||||
return grid
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Generate host key
|
||||
_, key, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
log.Error("Failed to generate host key", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
block, err := gossh.MarshalPrivateKey(key, "")
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal host key", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
privateKeyBytes := pem.EncodeToMemory(block)
|
||||
log.Info("Generated host key")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
signal.Notify(signalChan, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signalChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
port := os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = "22"
|
||||
}
|
||||
|
||||
s, err := wish.NewServer(
|
||||
wish.WithAddress(net.JoinHostPort("0.0.0.0", port)),
|
||||
wish.WithHostKeyPEM(privateKeyBytes),
|
||||
wish.WithMiddleware(
|
||||
recover.Middleware(
|
||||
bubbletea.Middleware(teaHandler),
|
||||
activeterm.Middleware(),
|
||||
logging.Middleware(),
|
||||
),
|
||||
),
|
||||
wish.WithPublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
return true
|
||||
}),
|
||||
wish.WithKeyboardInteractiveAuth(
|
||||
func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
|
||||
return true
|
||||
},
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error("Could not start server", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
log.Info("Starting SSH server", "port", port)
|
||||
go func() {
|
||||
if err = s.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
||||
log.Error("Could not start server", "error", err)
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
s.Shutdown(ctx)
|
||||
slog.Info("Shutting down server")
|
||||
}
|
||||
|
||||
func teaHandler(s ssh.Session) (tea.Model, []tea.ProgramOption) {
|
||||
return model{
|
||||
showWelcome: true,
|
||||
focus: 0,
|
||||
showHelp: false,
|
||||
theme: gruvboxDark(),
|
||||
blink: false,
|
||||
blinkCount: 0,
|
||||
}, []tea.ProgramOption{tea.WithAltScreen()}
|
||||
}
|
||||
Reference in New Issue
Block a user