Adds persistent storage using modernc.org/sqlite (pure Go). Login attempts are deduplicated by (username, password, ip) with counts. Sessions and session logs are tracked with UUID IDs. Includes embedded SQL migrations, configurable retention with background pruning, and an in-memory store for tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
263 lines
6.2 KiB
Go
263 lines
6.2 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"git.t-juice.club/torjus/oubliette/internal/auth"
|
|
"git.t-juice.club/torjus/oubliette/internal/config"
|
|
"git.t-juice.club/torjus/oubliette/internal/storage"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// sessionTimeout is how long handleSession keeps an accepted session channel
// open before closing it, unless the client disconnects first.
const sessionTimeout time.Duration = 30 * time.Second
|
|
|
|
type Server struct {
|
|
cfg config.Config
|
|
store storage.Store
|
|
authenticator *auth.Authenticator
|
|
sshConfig *ssh.ServerConfig
|
|
logger *slog.Logger
|
|
connSem chan struct{} // semaphore limiting concurrent connections
|
|
}
|
|
|
|
func New(cfg config.Config, store storage.Store, logger *slog.Logger) (*Server, error) {
|
|
s := &Server{
|
|
cfg: cfg,
|
|
store: store,
|
|
authenticator: auth.NewAuthenticator(cfg.Auth),
|
|
logger: logger,
|
|
connSem: make(chan struct{}, cfg.SSH.MaxConnections),
|
|
}
|
|
|
|
hostKey, err := loadOrGenerateHostKey(cfg.SSH.HostKeyPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("host key: %w", err)
|
|
}
|
|
|
|
s.sshConfig = &ssh.ServerConfig{
|
|
PasswordCallback: s.passwordCallback,
|
|
ServerVersion: "SSH-2.0-OpenSSH_8.9p1 Ubuntu-3ubuntu0.6",
|
|
}
|
|
s.sshConfig.AddHostKey(hostKey)
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *Server) ListenAndServe(ctx context.Context) error {
|
|
listener, err := net.Listen("tcp", s.cfg.SSH.ListenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen: %w", err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
s.logger.Info("SSH server listening", "addr", s.cfg.SSH.ListenAddr)
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
listener.Close()
|
|
}()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil
|
|
}
|
|
s.logger.Error("accept error", "err", err)
|
|
continue
|
|
}
|
|
|
|
// Enforce max concurrent connections.
|
|
select {
|
|
case s.connSem <- struct{}{}:
|
|
go func() {
|
|
defer func() { <-s.connSem }()
|
|
s.handleConn(conn)
|
|
}()
|
|
default:
|
|
s.logger.Warn("max connections reached, rejecting", "remote_addr", conn.RemoteAddr())
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConn(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
if err != nil {
|
|
s.logger.Debug("SSH handshake failed", "remote_addr", conn.RemoteAddr(), "err", err)
|
|
return
|
|
}
|
|
defer sshConn.Close()
|
|
|
|
s.logger.Info("SSH connection established",
|
|
"remote_addr", sshConn.RemoteAddr(),
|
|
"user", sshConn.User(),
|
|
)
|
|
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
for newChan := range chans {
|
|
if newChan.ChannelType() != "session" {
|
|
newChan.Reject(ssh.UnknownChannelType, "unknown channel type")
|
|
continue
|
|
}
|
|
|
|
channel, requests, err := newChan.Accept()
|
|
if err != nil {
|
|
s.logger.Error("channel accept error", "err", err)
|
|
return
|
|
}
|
|
|
|
go s.handleSession(channel, requests, sshConn)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleSession(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
|
defer channel.Close()
|
|
|
|
ip := extractIP(conn.RemoteAddr())
|
|
sessionID, err := s.store.CreateSession(context.Background(), ip, conn.User(), "")
|
|
if err != nil {
|
|
s.logger.Error("failed to create session", "err", err)
|
|
} else {
|
|
defer func() {
|
|
if err := s.store.EndSession(context.Background(), sessionID, time.Now()); err != nil {
|
|
s.logger.Error("failed to end session", "err", err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Handle session requests (pty-req, shell, etc.)
|
|
go func() {
|
|
for req := range requests {
|
|
switch req.Type {
|
|
case "pty-req", "shell":
|
|
if req.WantReply {
|
|
req.Reply(true, nil)
|
|
}
|
|
default:
|
|
if req.WantReply {
|
|
req.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Write a fake banner.
|
|
fmt.Fprint(channel, "Welcome to Ubuntu 22.04.3 LTS (GNU/Linux 5.15.0-89-generic x86_64)\r\n\r\n")
|
|
fmt.Fprintf(channel, "Last login: %s from 10.0.0.1\r\n", time.Now().Add(-2*time.Hour).Format("Mon Jan 2 15:04:05 2006"))
|
|
fmt.Fprintf(channel, "%s@ubuntu:~$ ", conn.User())
|
|
|
|
// Hold connection open until timeout or client disconnect.
|
|
timer := time.NewTimer(sessionTimeout)
|
|
defer timer.Stop()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
buf := make([]byte, 256)
|
|
for {
|
|
_, err := channel.Read(buf)
|
|
if err != nil {
|
|
close(done)
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-timer.C:
|
|
s.logger.Info("session timed out", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
|
case <-done:
|
|
s.logger.Info("session closed by client", "remote_addr", conn.RemoteAddr(), "user", conn.User())
|
|
}
|
|
}
|
|
|
|
func (s *Server) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
ip := extractIP(conn.RemoteAddr())
|
|
d := s.authenticator.Authenticate(ip, conn.User(), string(password))
|
|
|
|
s.logger.Info("auth attempt",
|
|
"remote_addr", conn.RemoteAddr(),
|
|
"username", conn.User(),
|
|
"accepted", d.Accepted,
|
|
"reason", d.Reason,
|
|
)
|
|
|
|
if err := s.store.RecordLoginAttempt(context.Background(), conn.User(), string(password), ip); err != nil {
|
|
s.logger.Error("failed to record login attempt", "err", err)
|
|
}
|
|
|
|
if d.Accepted {
|
|
return nil, nil
|
|
}
|
|
return nil, fmt.Errorf("rejected")
|
|
}
|
|
|
|
// extractIP returns the host portion of addr's string form, without the port.
//
// If the string cannot be split into host and port (for example it has no
// port), the whole string is returned unchanged. A host that is not an IP
// literal is returned as-is. IP literals are returned in canonical form, and
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are reduced to plain dotted-quad
// IPv4.
func extractIP(addr net.Addr) string {
	full := addr.String()

	host, _, err := net.SplitHostPort(full)
	if err != nil {
		return full
	}

	parsed := net.ParseIP(host)
	switch {
	case parsed == nil:
		return host
	case parsed.To4() != nil:
		return parsed.To4().String()
	default:
		return parsed.String()
	}
}
|
|
|
|
func loadOrGenerateHostKey(path string) (ssh.Signer, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err == nil {
|
|
signer, err := ssh.ParsePrivateKey(data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing host key: %w", err)
|
|
}
|
|
return signer, nil
|
|
}
|
|
|
|
if !errors.Is(err, os.ErrNotExist) {
|
|
return nil, fmt.Errorf("reading host key: %w", err)
|
|
}
|
|
|
|
// Generate new Ed25519 key.
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating key: %w", err)
|
|
}
|
|
|
|
privBytes, err := ssh.MarshalPrivateKey(priv, "")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling key: %w", err)
|
|
}
|
|
|
|
pemData := pem.EncodeToMemory(privBytes)
|
|
if err := os.WriteFile(path, pemData, 0600); err != nil {
|
|
return nil, fmt.Errorf("writing host key: %w", err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(pemData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing generated key: %w", err)
|
|
}
|
|
|
|
slog.Info("generated new host key", "path", path)
|
|
return signer, nil
|
|
}
|
|
|