package server import ( "context" "crypto/ed25519" "crypto/rand" "encoding/pem" "errors" "fmt" "log/slog" "net" "os" "time" "git.t-juice.club/torjus/oubliette/internal/auth" "git.t-juice.club/torjus/oubliette/internal/config" "git.t-juice.club/torjus/oubliette/internal/storage" "golang.org/x/crypto/ssh" ) const sessionTimeout = 30 * time.Second type Server struct { cfg config.Config store storage.Store authenticator *auth.Authenticator sshConfig *ssh.ServerConfig logger *slog.Logger connSem chan struct{} // semaphore limiting concurrent connections } func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) { s := &Server{ cfg: cfg, store: store, authenticator: auth.NewAuthenticator(cfg.Auth), logger: logger, connSem: make(chan struct{}, cfg.SSH.MaxConnections), } hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath) if err != nil { return nil, fmt.Errorf("host key: %w", err) } s.sshConfig = &ssh.ServerConfig{ PasswordCallback: s.passwordCallback, ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6", } s.sshConfig.AddHostKey(hostKey) return s, nil } func (s *Server) ListenAndServe(ctx context.Context) error { listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr) if err != nil { return fmt.Errorf("listen: %w", err) } defer listener.Close() s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr) go func() { <-ctx.Done() listener.Close() }() for { conn, err := listener.Accept() if err != nil { if ctx.Err() != nil { return nil } s.logger.Error("accept error", "err", err) continue } // Enforce max concurrent connections. select { case s.connSem <- struct{}{}: go func() { defer func() { <-s.connSem }() s.handleConn(conn) }() default: s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr()) conn.Close() } } } func (s *Server) handleConn(conn net.Conn) { defer conn.Close() sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig) if err != nil { s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err) return } defer sshConn.Close() s.logger.Info("SSH connection established", "remote_addr", sshConn.RemoteAddr(), "user", sshConn.User(), ) go ssh.DiscardRequests(reqs) for newChan := range chans { if newChan.ChannelType() != "session" { newChan.Reject(ssh.UnknownChannelType, "unknown channel type") continue } channel, requests, err := newChan.Accept() if err != nil { s.logger.Error("channel accept error", "err", err) return } go s.handleSession(channel, requests, sshConn) } } func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { defer channel.Close() ip := extractIP(conn.RemoteAddr()) sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "") if err != nil { s.logger.Error("failed to create session", "err", err) } else { defer func() { if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil { s.logger.Error("failed to end session", "err", err) } }() } // Handle session requests (pty-req, shell, etc.) go func() { for req := range requests { switch req.Type { case "pty-req", "shell": if req.WantReply { req.Reply(true, nil) } default: if req.WantReply { req.Reply(false, nil) } } } }() // Write a fake banner. fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n") fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006")) fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User()) // Hold connection open until timeout or client disconnect. timer := time.NewTimer(sessionTimeout) defer timer.Stop() done := make(chan struct{}) go func() { buf := make([]byte, 256) for { _, err := channel.Read(buf) if err != nil { close(done) return } } }() select { case <-timer.C: s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User()) case <-done: s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User()) } } func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { ip := extractIP(conn.RemoteAddr()) d := s.authenticator.Authenticate(ip, conn.User(), string(password)) s.logger.Info("auth attempt", "remote_addr", conn.RemoteAddr(), "username", conn.User(), "accepted", d.Accepted, "reason", d.Reason, ) if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil { s.logger.Error("failed to record login attempt", "err", err) } if d.Accepted { return nil, nil } return nil, fmt.Errorf("rejected") } func extractIP(addr net.Addr) string { host, _, err := net.SplitHostPort(addr.String()) if err != nil { // Might not have a port, try using the string directly. return addr.String() } // Normalize IPv4-mapped IPv6 addresses. ip := net.ParseIP(host) if ip == nil { return host } if v4 := ip.To4(); v4 != nil { return v4.String() } return ip.String() } func loadOrGenerateHostKey(path string) (ssh.Signer, error) { data, err := os.ReadFile(path) if err == nil { signer, err := ssh.ParsePrivateKey(data) if err != nil { return nil, fmt.Errorf("parsing host key: %w", err) } return signer, nil } if !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("reading host key: %w", err) } // Generate new Ed25519 key. _, priv, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, fmt.Errorf("generating key: %w", err) } privBytes, err := ssh.MarshalPrivateKey(priv, "") if err != nil { return nil, fmt.Errorf("marshaling key: %w", err) } pemData := pem.EncodeToMemory(privBytes) if err := os.WriteFile(path, pemData, 0600); err != nil { return nil, fmt.Errorf("writing host key: %w", err) } signer, err := ssh.ParsePrivateKey(pemData) if err != nil { return nil, fmt.Errorf("parsing generated key: %w", err) } slog.Info("generated new host key", "path", path) return signer, nil }