Compare commits
13 Commits
0a4eac188a
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
1b28f10ca8
|
|||
|
664e79fce6
|
|||
|
c74313c195
|
|||
|
9783ae5865
|
|||
|
62de222488
|
|||
| c9d143d84b | |||
|
d18a904ed5
|
|||
|
cb7be28f42
|
|||
|
0908b43724
|
|||
|
52310f588d
|
|||
|
b52216bd2f
|
|||
|
2bc83a17dd
|
|||
|
faf6e2abd7
|
29
PLAN.md
29
PLAN.md
@@ -179,7 +179,12 @@ Goal: Add the entertaining shell implementations.
|
||||
- DDL/DML acknowledgments (CREATE TABLE, INSERT, UPDATE, DELETE, etc.)
|
||||
- Username-to-shell routing: configurable `[shell.username_routes]` maps usernames to shells
|
||||
|
||||
### 3.7 Other Shell Ideas (Future)
|
||||
### 3.7 Roomba Shell ✅
|
||||
- iRobot Roomba j7+ vacuum robot interface
|
||||
- Status, cleaning, scheduling, diagnostics, floor map
|
||||
- Humorous history entries (cat encounters, sock tangles, sticky substances)
|
||||
|
||||
### 3.8 Other Shell Ideas (Future)
|
||||
- **Nuclear launch terminal:** "ENTER LAUNCH AUTHORIZATION CODE"
|
||||
- **ELIZA therapist:** every response is a therapy question
|
||||
- **Pizza ordering terminal:** "Welcome to PizzaNet v2.3"
|
||||
@@ -191,11 +196,11 @@ Goal: Add the entertaining shell implementations.
|
||||
|
||||
Goal: Make the web UI great and add operational niceties.
|
||||
|
||||
### 4.1 Enhanced Web UI
|
||||
- GeoIP lookups and world map visualization of attack sources
|
||||
- Charts: attempts over time, hourly patterns, credential trends
|
||||
- Session detail view with full command log
|
||||
- Filtering and search
|
||||
### 4.1 Enhanced Web UI ✅
|
||||
- GeoIP lookups and world map visualization of attack sources ✅
|
||||
- Charts: attempts over time, hourly patterns, credential trends ✅
|
||||
- Session detail view with full command log ✅
|
||||
- Filtering and search ✅
|
||||
|
||||
### 4.2 Operational ✅
|
||||
- Prometheus metrics endpoint ✅
|
||||
@@ -209,15 +214,15 @@ Goal: Make the web UI great and add operational niceties.
|
||||
- Store country/city with each attempt ✅
|
||||
- Aggregate stats by country ✅
|
||||
|
||||
### 4.4 Capture SSH Exec Commands
|
||||
### 4.4 Capture SSH Exec Commands ✅
|
||||
Many bots send a command directly via `ssh user@host <command>` (an SSH "exec" request) rather than requesting an interactive shell. Currently these are rejected and the command is lost. We should capture them.
|
||||
|
||||
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`)
|
||||
- Parse the command string from the exec payload
|
||||
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration
|
||||
- Store the command on the session record before closing the channel
|
||||
- Handle `"exec"` request type in the server's request loop (alongside `"pty-req"` and `"shell"`) ✅
|
||||
- Parse the command string from the exec payload ✅
|
||||
- Add an `exec_command` column (nullable) to the `sessions` table via a new migration ✅
|
||||
- Store the command on the session record before closing the channel ✅
|
||||
- Optionally return plausible fake output for common commands (e.g. `uname`, `id`, `cat /etc/passwd`) to encourage further interaction
|
||||
- Surface exec commands in the web UI (session detail view)
|
||||
- Surface exec commands in the web UI (session detail view) ✅
|
||||
|
||||
#### 4.4.1 Fake Exec Output
|
||||
Return plausible fake output for exec commands to encourage bots to interact further.
|
||||
|
||||
@@ -34,7 +34,7 @@ Key settings:
|
||||
- `auth.accept_after` — accept login after N failures per IP (default `10`)
|
||||
- `auth.credential_ttl` — how long to remember accepted credentials (default `24h`)
|
||||
- `auth.static_credentials` — always-accepted username/password pairs (optional `shell` field routes to a specific shell)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal)
|
||||
- Available shells: `bash` (fake Linux shell), `fridge` (Samsung Smart Fridge OS), `banking` (80s-style bank terminal TUI), `adventure` (Zork-style text adventure dungeon), `cisco` (Cisco IOS CLI with mode state machine and command abbreviation), `psql` (PostgreSQL psql interactive terminal), `roomba` (iRobot Roomba vacuum robot), `tetris` (Tetris game TUI)
|
||||
- `shell.username_routes` — map usernames to specific shells (e.g. `postgres = "psql"`); credential-specific shell overrides take priority
|
||||
- `storage.db_path` — SQLite database path (default `oubliette.db`)
|
||||
- `storage.retention_days` — auto-prune records older than N days (default `90`)
|
||||
|
||||
@@ -13,14 +13,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/server"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"git.t-juice.club/torjus/oubliette/internal/web"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/server"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/web"
|
||||
)
|
||||
|
||||
const Version = "0.14.0"
|
||||
const Version = "0.18.0"
|
||||
|
||||
func main() {
|
||||
if err := run(); err != nil {
|
||||
@@ -76,12 +76,13 @@ func run() error {
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
go storage.RunRetention(ctx, store, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
m := metrics.New(Version)
|
||||
m.RegisterStoreCollector(store)
|
||||
instrumentedStore := storage.NewInstrumentedStore(store, m.StorageQueryDuration, m.StorageQueryErrors)
|
||||
m.RegisterStoreCollector(instrumentedStore)
|
||||
|
||||
srv, err := server.New(*cfg, store, logger, m)
|
||||
go storage.RunRetention(ctx, instrumentedStore, cfg.Storage.RetentionDays, cfg.Storage.RetentionIntervalDuration, logger)
|
||||
|
||||
srv, err := server.New(*cfg, instrumentedStore, logger, m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create server: %w", err)
|
||||
}
|
||||
@@ -95,7 +96,7 @@ func run() error {
|
||||
metricsHandler = m.Handler()
|
||||
}
|
||||
|
||||
webHandler, err := web.NewServer(store, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
webHandler, err := web.NewServer(instrumentedStore, logger.With("component", "web"), metricsHandler, cfg.Web.MetricsToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create web server: %w", err)
|
||||
}
|
||||
|
||||
4
go.mod
4
go.mod
@@ -1,4 +1,4 @@
|
||||
module git.t-juice.club/torjus/oubliette
|
||||
module code.t-juice.club/torjus/oubliette
|
||||
|
||||
go 1.25.5
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/oschwald/maxminddb-golang v1.13.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/prometheus/client_model v0.6.2
|
||||
golang.org/x/crypto v0.48.0
|
||||
modernc.org/sqlite v1.45.0
|
||||
)
|
||||
@@ -33,7 +34,6 @@ require (
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func newTestAuth(acceptAfter int, ttl time.Duration, statics ...config.Credential) *Authenticator {
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -25,6 +25,8 @@ type Metrics struct {
|
||||
SessionDuration prometheus.Histogram
|
||||
ExecCommandsTotal prometheus.Counter
|
||||
BuildInfo *prometheus.GaugeVec
|
||||
StorageQueryDuration *prometheus.HistogramVec
|
||||
StorageQueryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// New creates a new Metrics instance with all collectors registered.
|
||||
@@ -79,6 +81,15 @@ func New(version string) *Metrics {
|
||||
Name: "oubliette_build_info",
|
||||
Help: "Build information. Always 1.",
|
||||
}, []string{"version"}),
|
||||
StorageQueryDuration: prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "oubliette_storage_query_duration_seconds",
|
||||
Help: "Duration of storage query calls in seconds.",
|
||||
Buckets: []float64{0.001, 0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
}, []string{"method"}),
|
||||
StorageQueryErrors: prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "oubliette_storage_query_errors_total",
|
||||
Help: "Total storage query errors.",
|
||||
}, []string{"method"}),
|
||||
}
|
||||
|
||||
reg.MustRegister(
|
||||
@@ -95,6 +106,8 @@ func New(version string) *Metrics {
|
||||
m.SessionDuration,
|
||||
m.ExecCommandsTotal,
|
||||
m.BuildInfo,
|
||||
m.StorageQueryDuration,
|
||||
m.StorageQueryErrors,
|
||||
)
|
||||
|
||||
m.BuildInfo.WithLabelValues(version).Set(1)
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
// Event types.
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
)
|
||||
|
||||
func testSession() SessionInfo {
|
||||
|
||||
@@ -12,20 +12,22 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"git.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "git.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/detection"
|
||||
"code.t-juice.club/torjus/oubliette/internal/geoip"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/notify"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/adventure"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/banking"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/bash"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/cisco"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/fridge"
|
||||
psqlshell "code.t-juice.club/torjus/oubliette/internal/shell/psql"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/roomba"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell/tetris"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
@@ -62,6 +64,12 @@ func New(cfg config.Config, store storage.Store, logger *slog.Logger, m *metrics
|
||||
if err := registry.Register(psqlshell.NewPsqlShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering psql shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(roomba.NewRoombaShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering roomba shell: %w", err)
|
||||
}
|
||||
if err := registry.Register(tetris.NewTetrisShell(), 1); err != nil {
|
||||
return nil, fmt.Errorf("registering tetris shell: %w", err)
|
||||
}
|
||||
|
||||
geo, err := geoip.New()
|
||||
if err != nil {
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"git.t-juice.club/torjus/oubliette/internal/config"
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/auth"
|
||||
"code.t-juice.club/torjus/oubliette/internal/config"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// EventRecorder buffers I/O events in memory and periodically flushes them to
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func TestEventRecorderFlush(t *testing.T) {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
type rwCloser struct {
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
463
internal/shell/roomba/roomba.go
Normal file
463
internal/shell/roomba/roomba.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package roomba
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 5 * time.Minute
|
||||
|
||||
// RoombaShell emulates an iRobot Roomba vacuum robot interface.
|
||||
type RoombaShell struct{}
|
||||
|
||||
// NewRoombaShell returns a new RoombaShell instance.
|
||||
func NewRoombaShell() *RoombaShell {
|
||||
return &RoombaShell{}
|
||||
}
|
||||
|
||||
func (r *RoombaShell) Name() string { return "roomba" }
|
||||
func (r *RoombaShell) Description() string { return "iRobot Roomba shell emulator" }
|
||||
|
||||
func (r *RoombaShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
state := newRoombaState()
|
||||
|
||||
banner := strings.ReplaceAll(bootBanner(), "\n", "\r\n")
|
||||
fmt.Fprint(rw, banner)
|
||||
|
||||
for {
|
||||
if _, err := fmt.Fprint(rw, "RoombaOS> "); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
line, err := shell.ReadLine(ctx, rw)
|
||||
if errors.Is(err, io.EOF) {
|
||||
fmt.Fprint(rw, "logout\r\n")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
result := state.dispatch(trimmed)
|
||||
|
||||
var output string
|
||||
if result.output != "" {
|
||||
output = result.output
|
||||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||||
output = strings.ReplaceAll(output, "\n", "\r\n")
|
||||
fmt.Fprintf(rw, "%s\r\n", output)
|
||||
}
|
||||
|
||||
if sess.Store != nil {
|
||||
if err := sess.Store.AppendSessionLog(ctx, sess.SessionID, trimmed, output); err != nil {
|
||||
return fmt.Errorf("append session log: %w", err)
|
||||
}
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("roomba")
|
||||
}
|
||||
|
||||
if result.exit {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func bootBanner() string {
|
||||
return `
|
||||
____ _ ___ ____
|
||||
| _ \ ___ ___ _ __ ___ | |__ __ _ / _ \/ ___|
|
||||
| |_) / _ \ / _ \| '_ ` + "`" + ` _ \| '_ \ / _` + "`" + ` | | | \___ \
|
||||
| _ < (_) | (_) | | | | | | |_) | (_| | |_| |___) |
|
||||
|_| \_\___/ \___/|_| |_| |_|_.__/ \__,_|\___/|____/
|
||||
|
||||
iRobot Roomba j7+ | RoombaOS v4.3.7
|
||||
Serial: RMB-7291-J7P-0482 | Firmware: 4.3.7-stable
|
||||
Battery: 73% | WiFi: Connected (SmartHome-5G)
|
||||
|
||||
Type 'help' for available commands.
|
||||
|
||||
`
|
||||
}
|
||||
|
||||
type room struct {
|
||||
name string
|
||||
areaSqFt int
|
||||
lastCleaned time.Time
|
||||
}
|
||||
|
||||
type scheduleEntry struct {
|
||||
day string
|
||||
time string
|
||||
}
|
||||
|
||||
type historyEntry struct {
|
||||
timestamp time.Time
|
||||
room string
|
||||
duration string
|
||||
note string
|
||||
}
|
||||
|
||||
type roombaState struct {
|
||||
battery int
|
||||
dustbin int
|
||||
status string
|
||||
rooms []room
|
||||
schedule []scheduleEntry
|
||||
cleanHistory []historyEntry
|
||||
}
|
||||
|
||||
type commandResult struct {
|
||||
output string
|
||||
exit bool
|
||||
}
|
||||
|
||||
func newRoombaState() *roombaState {
|
||||
now := time.Now()
|
||||
return &roombaState{
|
||||
battery: 73,
|
||||
dustbin: 61,
|
||||
status: "Docked",
|
||||
rooms: []room{
|
||||
{"Kitchen", 180, now.Add(-2 * time.Hour)},
|
||||
{"Living Room", 320, now.Add(-5 * time.Hour)},
|
||||
{"Bedroom", 200, now.Add(-26 * time.Hour)},
|
||||
{"Hallway", 60, now.Add(-5 * time.Hour)},
|
||||
{"Bathroom", 75, now.Add(-50 * time.Hour)},
|
||||
{"Cat's Room", 110, now.Add(-3 * time.Hour)},
|
||||
},
|
||||
schedule: []scheduleEntry{
|
||||
{"Monday", "09:00"},
|
||||
{"Wednesday", "09:00"},
|
||||
{"Friday", "09:00"},
|
||||
{"Saturday", "14:00"},
|
||||
},
|
||||
cleanHistory: []historyEntry{
|
||||
{now.Add(-2 * time.Hour), "Kitchen", "23 min", "Completed normally"},
|
||||
{now.Add(-3 * time.Hour), "Cat's Room", "18 min", "Cat detected - rerouting"},
|
||||
{now.Add(-5 * time.Hour), "Living Room", "34 min", "Encountered sock near couch"},
|
||||
{now.Add(-5*time.Hour - 40*time.Minute), "Hallway", "8 min", "Completed normally"},
|
||||
{now.Add(-26 * time.Hour), "Bedroom", "27 min", "Tangled in phone charger"},
|
||||
{now.Add(-50 * time.Hour), "Bathroom", "14 min", "Unidentified sticky substance detected"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) dispatch(input string) commandResult {
|
||||
parts := strings.Fields(input)
|
||||
if len(parts) == 0 {
|
||||
return commandResult{}
|
||||
}
|
||||
|
||||
cmd := strings.ToLower(parts[0])
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "help":
|
||||
return s.cmdHelp()
|
||||
case "status":
|
||||
return s.cmdStatus()
|
||||
case "clean":
|
||||
return s.cmdClean(args)
|
||||
case "dock":
|
||||
return s.cmdDock()
|
||||
case "map":
|
||||
return s.cmdMap()
|
||||
case "schedule":
|
||||
return s.cmdSchedule(args)
|
||||
case "history":
|
||||
return s.cmdHistory()
|
||||
case "diagnostics":
|
||||
return s.cmdDiagnostics()
|
||||
case "alerts":
|
||||
return s.cmdAlerts()
|
||||
case "reboot":
|
||||
return s.cmdReboot()
|
||||
case "exit", "logout":
|
||||
return commandResult{output: "Disconnecting from RoombaOS. Happy cleaning!", exit: true}
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("RoombaOS: unknown command '%s'. Type 'help' for available commands.", cmd)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHelp() commandResult {
|
||||
help := `Available commands:
|
||||
help - Show this help message
|
||||
status - Show robot status
|
||||
clean - Start full cleaning job
|
||||
clean room <name> - Clean a specific room
|
||||
dock - Return to dock
|
||||
map - Show floor plan and room list
|
||||
schedule - List cleaning schedule
|
||||
schedule add <day> <time> - Add scheduled cleaning
|
||||
schedule remove <day> - Remove scheduled cleaning
|
||||
history - Show recent cleaning history
|
||||
diagnostics - Run system diagnostics
|
||||
alerts - Show active alerts
|
||||
reboot - Reboot RoombaOS
|
||||
exit / logout - Disconnect`
|
||||
return commandResult{output: help}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdStatus() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== RoombaOS System Status ===\n")
|
||||
b.WriteString("Model: iRobot Roomba j7+\n")
|
||||
b.WriteString(fmt.Sprintf("Status: %s\n", s.status))
|
||||
b.WriteString(fmt.Sprintf("Battery: %d%%\n", s.battery))
|
||||
b.WriteString(fmt.Sprintf("Dustbin: %d%% full\n", s.dustbin))
|
||||
b.WriteString("Side brush: OK (142 hrs)\n")
|
||||
b.WriteString("Main brush: OK (98 hrs)\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("WiFi: Connected (SmartHome-5G)\n")
|
||||
b.WriteString("Signal: -38 dBm\n")
|
||||
b.WriteString("Alexa: Linked\n")
|
||||
b.WriteString("Google Home: Linked\n")
|
||||
b.WriteString("iRobot Home App: Connected\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString("Firmware: v4.3.7-stable\n")
|
||||
b.WriteString("LIDAR: Operational\n")
|
||||
b.WriteString("Clean Area Total: 12,847 sq ft (lifetime)")
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdClean(args []string) commandResult {
|
||||
if s.status == "Cleaning" {
|
||||
return commandResult{output: "Already cleaning. Use 'dock' to cancel and return to dock."}
|
||||
}
|
||||
|
||||
if len(args) >= 2 && strings.ToLower(args[0]) == "room" {
|
||||
roomName := strings.Join(args[1:], " ")
|
||||
for _, r := range s.rooms {
|
||||
if strings.EqualFold(r.name, roomName) {
|
||||
s.status = "Cleaning"
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting targeted clean: %s (%d sq ft)\nEstimated time: %d minutes\nUndocking... navigating to %s...",
|
||||
r.name, r.areaSqFt, r.areaSqFt/8, r.name,
|
||||
)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("Room '%s' not found. Use 'map' to see available rooms.", roomName)}
|
||||
}
|
||||
|
||||
if len(args) > 0 {
|
||||
return commandResult{output: "Usage: clean [room <name>]"}
|
||||
}
|
||||
|
||||
s.status = "Cleaning"
|
||||
var totalArea int
|
||||
for _, r := range s.rooms {
|
||||
totalArea += r.areaSqFt
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf(
|
||||
"Starting full house clean\nTotal area: %d sq ft across %d rooms\nEstimated time: %d minutes\nUndocking... beginning clean cycle...",
|
||||
totalArea, len(s.rooms), totalArea/8,
|
||||
)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDock() commandResult {
|
||||
if s.status == "Docked" {
|
||||
return commandResult{output: "Already docked."}
|
||||
}
|
||||
if s.status == "Returning to dock" {
|
||||
return commandResult{output: "Already returning to dock."}
|
||||
}
|
||||
s.status = "Returning to dock"
|
||||
return commandResult{output: "Cancelling current job. Returning to dock...\nEstimated arrival: 2 minutes"}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdMap() commandResult {
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Floor Plan ===\n\n")
|
||||
b.WriteString(" +------------+----------+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Kitchen | Bathroom |\n")
|
||||
b.WriteString(" | 180sqft | 75sqft |\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" +------+-----+----+-----+\n")
|
||||
b.WriteString(" | | | |\n")
|
||||
b.WriteString(" | Hall | Living | Cat |\n")
|
||||
b.WriteString(" | 60sf | Room | Rm |\n")
|
||||
b.WriteString(" | | 320sqft |110sf|\n")
|
||||
b.WriteString(" +------+ +-----+\n")
|
||||
b.WriteString(" | | |\n")
|
||||
b.WriteString(" | Bed +----------+\n")
|
||||
b.WriteString(" | room | [DOCK]\n")
|
||||
b.WriteString(" |200sf |\n")
|
||||
b.WriteString(" +------+\n")
|
||||
b.WriteString("\nRoom Details:\n")
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "ROOM", "AREA", "LAST CLEANED"))
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s\n", "----", "----", "------------"))
|
||||
for _, r := range s.rooms {
|
||||
ago := time.Since(r.lastCleaned).Truncate(time.Minute)
|
||||
b.WriteString(fmt.Sprintf(" %-15s %-10s %s ago\n", r.name, fmt.Sprintf("%d sqft", r.areaSqFt), formatDuration(ago)))
|
||||
}
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdSchedule(args []string) commandResult {
|
||||
if len(args) == 0 {
|
||||
return s.scheduleList()
|
||||
}
|
||||
|
||||
sub := strings.ToLower(args[0])
|
||||
switch sub {
|
||||
case "add":
|
||||
if len(args) < 3 {
|
||||
return commandResult{output: "Usage: schedule add <day> <time>\nExample: schedule add Tuesday 10:00"}
|
||||
}
|
||||
return s.scheduleAdd(args[1], args[2])
|
||||
case "remove":
|
||||
if len(args) < 2 {
|
||||
return commandResult{output: "Usage: schedule remove <day>"}
|
||||
}
|
||||
return s.scheduleRemove(args[1])
|
||||
default:
|
||||
return commandResult{output: fmt.Sprintf("Unknown schedule subcommand '%s'. Try: add, remove", sub)}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleList() commandResult {
|
||||
if len(s.schedule) == 0 {
|
||||
return commandResult{output: "No cleaning schedule configured."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning Schedule ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "DAY", "TIME"))
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", "---", "----"))
|
||||
for _, e := range s.schedule {
|
||||
b.WriteString(fmt.Sprintf(" %-12s %s\n", e.day, e.time))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d scheduled cleaning(s)", len(s.schedule)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleAdd(day, t string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
validDays := []string{"Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"}
|
||||
if !slices.Contains(validDays, day) {
|
||||
return commandResult{output: fmt.Sprintf("Invalid day '%s'. Use a day of the week (e.g. Monday, Tuesday).", day)}
|
||||
}
|
||||
|
||||
for _, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
return commandResult{output: fmt.Sprintf("Schedule for %s already exists. Remove it first.", day)}
|
||||
}
|
||||
}
|
||||
|
||||
s.schedule = append(s.schedule, scheduleEntry{day: day, time: t})
|
||||
return commandResult{output: fmt.Sprintf("Scheduled cleaning added: %s at %s", day, t)}
|
||||
}
|
||||
|
||||
func (s *roombaState) scheduleRemove(day string) commandResult {
|
||||
day = capitalizeFirst(strings.ToLower(day))
|
||||
for i, e := range s.schedule {
|
||||
if strings.EqualFold(e.day, day) {
|
||||
s.schedule = append(s.schedule[:i], s.schedule[i+1:]...)
|
||||
return commandResult{output: fmt.Sprintf("Removed schedule for %s.", day)}
|
||||
}
|
||||
}
|
||||
return commandResult{output: fmt.Sprintf("No schedule found for '%s'.", day)}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdHistory() commandResult {
|
||||
if len(s.cleanHistory) == 0 {
|
||||
return commandResult{output: "No cleaning history."}
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Cleaning History ===\n")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "TIME", "ROOM", "DURATION", "NOTE"))
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", "----", "----", "--------", "----"))
|
||||
for _, h := range s.cleanHistory {
|
||||
ts := h.timestamp.Format("2006-01-02 15:04")
|
||||
b.WriteString(fmt.Sprintf(" %-20s %-15s %-10s %s\n", ts, h.room, h.duration, h.note))
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d session(s) recorded", len(s.cleanHistory)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdDiagnostics() commandResult {
|
||||
diag := `Running RoombaOS diagnostics...
|
||||
|
||||
[1/8] Cliff sensors........... OK
|
||||
[2/8] Bumper sensor........... OK
|
||||
[3/8] Side brush motor........ OK (142 hrs until replacement)
|
||||
[4/8] Main brush motor........ OK (98 hrs until replacement)
|
||||
[5/8] Wheel motors............ OK (L: 1204 hrs, R: 1204 hrs)
|
||||
[6/8] LIDAR module............ OK (last calibrated 3 days ago)
|
||||
[7/8] Dustbin sensor.......... OK
|
||||
[8/8] WiFi module............. OK (signal: -38 dBm)
|
||||
|
||||
ALL SYSTEMS NOMINAL`
|
||||
return commandResult{output: diag}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdAlerts() commandResult {
|
||||
var alerts []string
|
||||
if s.dustbin >= 60 {
|
||||
alerts = append(alerts, fmt.Sprintf("WARNING: Dustbin %d%% full - consider emptying", s.dustbin))
|
||||
}
|
||||
alerts = append(alerts,
|
||||
"WARNING: Side brush replacement due in 12 hours",
|
||||
"INFO: Unidentified sticky substance detected in Kitchen",
|
||||
"INFO: Cat frequently blocking cleaning path in Cat's Room",
|
||||
"INFO: Firmware update available: v4.4.0-beta",
|
||||
"INFO: Filter replacement recommended in 14 days",
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("=== Active Alerts ===\n")
|
||||
for _, a := range alerts {
|
||||
b.WriteString(a + "\n")
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("\n%d alert(s) active", len(alerts)))
|
||||
return commandResult{output: b.String()}
|
||||
}
|
||||
|
||||
func (s *roombaState) cmdReboot() commandResult {
|
||||
reboot := `RoombaOS is rebooting...
|
||||
|
||||
Stopping navigation engine..... done
|
||||
Saving room map data........... done
|
||||
Flushing cleaning logs......... done
|
||||
Disconnecting from WiFi........ done
|
||||
|
||||
Rebooting now. Goodbye!`
|
||||
return commandResult{output: reboot, exit: true}
|
||||
}
|
||||
|
||||
func capitalizeFirst(s string) string {
|
||||
if s == "" {
|
||||
return s
|
||||
}
|
||||
return strings.ToUpper(s[:1]) + s[1:]
|
||||
}
|
||||
|
||||
func formatDuration(d time.Duration) string {
|
||||
hours := int(d.Hours())
|
||||
minutes := int(d.Minutes()) % 60
|
||||
if hours >= 24 {
|
||||
days := hours / 24
|
||||
hours %= 24
|
||||
return fmt.Sprintf("%dd %dh", days, hours)
|
||||
}
|
||||
if hours > 0 {
|
||||
return fmt.Sprintf("%dh %dm", hours, minutes)
|
||||
}
|
||||
return fmt.Sprintf("%dm", minutes)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// Shell is the interface that all honeypot shell implementations must satisfy.
|
||||
|
||||
101
internal/shell/tetris/data.go
Normal file
101
internal/shell/tetris/data.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package tetris
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
// pieceType identifies a tetromino (0–6).
|
||||
type pieceType int
|
||||
|
||||
const (
|
||||
pieceI pieceType = iota
|
||||
pieceO
|
||||
pieceT
|
||||
pieceS
|
||||
pieceZ
|
||||
pieceJ
|
||||
pieceL
|
||||
)
|
||||
|
||||
const numPieceTypes = 7
|
||||
|
||||
// Standard Tetris colors.
|
||||
var pieceColors = [numPieceTypes]lipgloss.Color{
|
||||
lipgloss.Color("#00FFFF"), // I — cyan
|
||||
lipgloss.Color("#FFFF00"), // O — yellow
|
||||
lipgloss.Color("#AA00FF"), // T — purple
|
||||
lipgloss.Color("#00FF00"), // S — green
|
||||
lipgloss.Color("#FF0000"), // Z — red
|
||||
lipgloss.Color("#0000FF"), // J — blue
|
||||
lipgloss.Color("#FF8800"), // L — orange
|
||||
}
|
||||
|
||||
// Each piece has 4 rotations, each rotation is a list of (row, col) offsets
|
||||
// relative to the piece origin.
|
||||
type rotation [4][2]int
|
||||
|
||||
var pieces = [numPieceTypes][4]rotation{
|
||||
// I
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{0, 3}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{3, 0}},
|
||||
},
|
||||
// O
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
},
|
||||
// T
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// S
|
||||
{
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
{[2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
// Z
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 1}, [2]int{1, 0}, [2]int{1, 1}, [2]int{2, 0}},
|
||||
},
|
||||
// J
|
||||
{
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 0}, [2]int{2, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
},
|
||||
// L
|
||||
{
|
||||
{[2]int{0, 2}, [2]int{1, 0}, [2]int{1, 1}, [2]int{1, 2}},
|
||||
{[2]int{0, 0}, [2]int{1, 0}, [2]int{2, 0}, [2]int{2, 1}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{0, 2}, [2]int{1, 0}},
|
||||
{[2]int{0, 0}, [2]int{0, 1}, [2]int{1, 1}, [2]int{2, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
// spawnCol returns the starting column for a piece, centering it on the board.
|
||||
func spawnCol(pt pieceType, rot int) int {
|
||||
shape := pieces[pt][rot]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
width := maxC - minC + 1
|
||||
return (boardCols - width) / 2
|
||||
}
|
||||
210
internal/shell/tetris/game.go
Normal file
210
internal/shell/tetris/game.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tetris
|
||||
|
||||
import "math/rand/v2"
|
||||
|
||||
const (
|
||||
boardRows = 20
|
||||
boardCols = 10
|
||||
)
|
||||
|
||||
// cell represents a single board cell. Zero value is empty.
|
||||
type cell struct {
|
||||
filled bool
|
||||
piece pieceType // which piece type filled this cell (for color)
|
||||
}
|
||||
|
||||
// gameState holds all mutable state for a Tetris game.
|
||||
type gameState struct {
|
||||
board [boardRows][boardCols]cell
|
||||
current pieceType
|
||||
currentRot int
|
||||
currentRow int
|
||||
currentCol int
|
||||
next pieceType
|
||||
score int
|
||||
level int
|
||||
lines int
|
||||
gameOver bool
|
||||
}
|
||||
|
||||
// newGame creates a new game state, optionally starting at a given level.
|
||||
func newGame(startLevel int) *gameState {
|
||||
g := &gameState{
|
||||
level: startLevel,
|
||||
next: pieceType(rand.IntN(numPieceTypes)),
|
||||
}
|
||||
g.spawnPiece()
|
||||
return g
|
||||
}
|
||||
|
||||
// spawnPiece pulls the next piece and generates a new next.
|
||||
func (g *gameState) spawnPiece() {
|
||||
g.current = g.next
|
||||
g.next = pieceType(rand.IntN(numPieceTypes))
|
||||
g.currentRot = 0
|
||||
g.currentRow = 0
|
||||
g.currentCol = spawnCol(g.current, 0)
|
||||
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
g.gameOver = true
|
||||
}
|
||||
}
|
||||
|
||||
// canPlace checks whether the piece fits at the given position.
|
||||
func (g *gameState) canPlace(pt pieceType, rot, row, col int) bool {
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if r < 0 || r >= boardRows || c < 0 || c >= boardCols {
|
||||
return false
|
||||
}
|
||||
if g.board[r][c].filled {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// moveLeft moves the current piece left if possible.
|
||||
func (g *gameState) moveLeft() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol-1) {
|
||||
g.currentCol--
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveRight moves the current piece right if possible.
|
||||
func (g *gameState) moveRight() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol+1) {
|
||||
g.currentCol++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// moveDown moves the current piece down one row. Returns false if it cannot.
|
||||
func (g *gameState) moveDown() bool {
|
||||
if g.canPlace(g.current, g.currentRot, g.currentRow+1, g.currentCol) {
|
||||
g.currentRow++
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// rotate rotates the current piece clockwise with wall kick attempts.
|
||||
func (g *gameState) rotate() bool {
|
||||
newRot := (g.currentRot + 1) % 4
|
||||
|
||||
// Try in-place first.
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol) {
|
||||
g.currentRot = newRot
|
||||
return true
|
||||
}
|
||||
|
||||
// Wall kick: try +-1 column offset.
|
||||
for _, offset := range []int{-1, 1} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// I piece: try +-2.
|
||||
if g.current == pieceI {
|
||||
for _, offset := range []int{-2, 2} {
|
||||
if g.canPlace(g.current, newRot, g.currentRow, g.currentCol+offset) {
|
||||
g.currentRot = newRot
|
||||
g.currentCol += offset
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ghostRow returns the row where the piece would land.
|
||||
func (g *gameState) ghostRow() int {
|
||||
row := g.currentRow
|
||||
for g.canPlace(g.current, g.currentRot, row+1, g.currentCol) {
|
||||
row++
|
||||
}
|
||||
return row
|
||||
}
|
||||
|
||||
// hardDrop drops the piece to the bottom and returns the number of rows dropped.
|
||||
func (g *gameState) hardDrop() int {
|
||||
ghost := g.ghostRow()
|
||||
dropped := ghost - g.currentRow
|
||||
g.currentRow = ghost
|
||||
return dropped
|
||||
}
|
||||
|
||||
// lockPiece writes the current piece into the board.
|
||||
func (g *gameState) lockPiece() {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearLines removes completed rows and returns how many were cleared.
|
||||
func (g *gameState) clearLines() int {
|
||||
cleared := 0
|
||||
for r := boardRows - 1; r >= 0; r-- {
|
||||
full := true
|
||||
for c := range boardCols {
|
||||
if !g.board[r][c].filled {
|
||||
full = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if full {
|
||||
cleared++
|
||||
// Shift everything above down.
|
||||
for rr := r; rr > 0; rr-- {
|
||||
g.board[rr] = g.board[rr-1]
|
||||
}
|
||||
g.board[0] = [boardCols]cell{}
|
||||
r++ // re-check this row since we shifted
|
||||
}
|
||||
}
|
||||
return cleared
|
||||
}
|
||||
|
||||
// NES-style scoring multipliers per lines cleared.
|
||||
var lineScoreMultipliers = [5]int{0, 40, 100, 300, 1200}
|
||||
|
||||
// addScore updates score, lines, and level after clearing rows.
|
||||
func (g *gameState) addScore(linesCleared int) {
|
||||
if linesCleared > 0 && linesCleared <= 4 {
|
||||
g.score += lineScoreMultipliers[linesCleared] * (g.level + 1)
|
||||
}
|
||||
g.lines += linesCleared
|
||||
|
||||
// Level up every 10 lines.
|
||||
newLevel := g.lines / 10
|
||||
if newLevel > g.level {
|
||||
g.level = newLevel
|
||||
}
|
||||
}
|
||||
|
||||
// afterLock locks the piece, clears lines, scores, and spawns the next piece.
|
||||
// Returns the number of lines cleared.
|
||||
func (g *gameState) afterLock() int {
|
||||
g.lockPiece()
|
||||
cleared := g.clearLines()
|
||||
g.addScore(cleared)
|
||||
g.spawnPiece()
|
||||
return cleared
|
||||
}
|
||||
|
||||
// tickInterval returns the gravity interval in milliseconds for the current level.
|
||||
func tickInterval(level int) int {
|
||||
return max(800-level*60, 100)
|
||||
}
|
||||
331
internal/shell/tetris/model.go
Normal file
331
internal/shell/tetris/model.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
type screen int
|
||||
|
||||
const (
|
||||
screenTitle screen = iota
|
||||
screenGame
|
||||
screenGameOver
|
||||
)
|
||||
|
||||
type tickMsg time.Time
|
||||
type lockMsg time.Time
|
||||
|
||||
const lockDelay = 500 * time.Millisecond
|
||||
|
||||
type model struct {
|
||||
sess *shell.SessionContext
|
||||
difficulty string
|
||||
screen screen
|
||||
game *gameState
|
||||
quitting bool
|
||||
height int
|
||||
keypresses int
|
||||
locking bool // true when piece has landed and lock delay is active
|
||||
}
|
||||
|
||||
func newModel(sess *shell.SessionContext, difficulty string) *model {
|
||||
return &model{
|
||||
sess: sess,
|
||||
difficulty: difficulty,
|
||||
screen: screenTitle,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if m.quitting {
|
||||
return m, tea.Quit
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.height = msg.Height
|
||||
return m, nil
|
||||
case tea.KeyMsg:
|
||||
m.keypresses++
|
||||
if msg.Type == tea.KeyCtrlC {
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.gameScore(), m.gameLevel(), m.gameLines(), m.keypresses), "SESSION ENDED"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
return m.updateTitle(msg)
|
||||
case screenGame:
|
||||
return m.updateGame(msg)
|
||||
case screenGameOver:
|
||||
return m.updateGameOver(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) View() string {
|
||||
var content string
|
||||
switch m.screen {
|
||||
case screenTitle:
|
||||
content = m.titleView()
|
||||
case screenGame:
|
||||
content = gameView(m.game)
|
||||
case screenGameOver:
|
||||
content = m.gameOverView()
|
||||
}
|
||||
|
||||
return gameFrame(content, m.height)
|
||||
}
|
||||
|
||||
// --- Title screen ---
|
||||
|
||||
func (m *model) titleView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ████████╗███████╗████████╗██████╗ ██╗███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚══██╔══╝██╔════╝╚══██╔══╝██╔══██╗██║██╔════╝"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ █████╗ ██║ ██████╔╝██║███████╗"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ██╔══╝ ██║ ██╔══██╗██║╚════██║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ██║ ███████╗ ██║ ██║ ██║██║███████║"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═╝╚══════╝"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(" Press any key to start"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateTitle(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if _, ok := msg.(tea.KeyMsg); ok {
|
||||
m.screen = screenGame
|
||||
var startLevel int
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "GAME START", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game screen ---
|
||||
|
||||
func (m *model) scheduleTick() tea.Cmd {
|
||||
ms := tickInterval(m.game.level)
|
||||
if m.difficulty == "easy" {
|
||||
ms = max(1000-m.game.level*60, 150)
|
||||
}
|
||||
return tea.Tick(time.Duration(ms)*time.Millisecond, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *model) scheduleLock() tea.Cmd {
|
||||
return tea.Tick(lockDelay, func(t time.Time) tea.Msg {
|
||||
return lockMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
// performLock locks the piece, clears lines, and returns commands for logging
|
||||
// and scheduling the next tick. Returns nil if game over (goToGameOver is
|
||||
// included in the returned batch).
|
||||
func (m *model) performLock() tea.Cmd {
|
||||
m.locking = false
|
||||
cleared := m.game.afterLock()
|
||||
if m.game.gameOver {
|
||||
return tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("GAME OVER score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "GAME OVER"),
|
||||
m.goToGameOver(),
|
||||
)
|
||||
}
|
||||
var cmds []tea.Cmd
|
||||
cmds = append(cmds, m.scheduleTick())
|
||||
if cleared > 0 {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LINES %d score=%d", cleared, m.game.score), fmt.Sprintf("total=%d", m.game.lines)))
|
||||
prevLevel := (m.game.lines - cleared) / 10
|
||||
if m.game.level > prevLevel {
|
||||
cmds = append(cmds, logAction(m.sess, fmt.Sprintf("LEVEL UP %d", m.game.level), fmt.Sprintf("score=%d", m.game.score)))
|
||||
}
|
||||
}
|
||||
return tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *model) updateGame(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case lockMsg:
|
||||
if m.game.gameOver || !m.locking {
|
||||
return m, nil
|
||||
}
|
||||
// Lock delay expired — lock the piece now.
|
||||
return m, m.performLock()
|
||||
|
||||
case tickMsg:
|
||||
if m.game.gameOver || m.locking {
|
||||
return m, nil
|
||||
}
|
||||
if !m.game.moveDown() {
|
||||
// Piece landed — start lock delay instead of locking immediately.
|
||||
m.locking = true
|
||||
return m, m.scheduleLock()
|
||||
}
|
||||
return m, m.scheduleTick()
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.game.gameOver {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "left":
|
||||
m.game.moveLeft()
|
||||
// If piece can now drop further, cancel lock delay.
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "right":
|
||||
m.game.moveRight()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case "down":
|
||||
if m.game.moveDown() {
|
||||
m.game.score++ // soft drop bonus
|
||||
if m.locking {
|
||||
m.locking = false
|
||||
}
|
||||
}
|
||||
case "up", "z":
|
||||
m.game.rotate()
|
||||
if m.locking && m.game.canPlace(m.game.current, m.game.currentRot, m.game.currentRow+1, m.game.currentCol) {
|
||||
m.locking = false
|
||||
}
|
||||
case " ":
|
||||
m.locking = false
|
||||
dropped := m.game.hardDrop()
|
||||
m.game.score += dropped * 2
|
||||
return m, m.performLock()
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// --- Game over screen ---
|
||||
|
||||
func (m *model) goToGameOver() tea.Cmd {
|
||||
m.screen = screenGameOver
|
||||
return tea.ClearScreen
|
||||
}
|
||||
|
||||
func (m *model) gameOverView() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(titleStyle.Render(" GAME OVER"))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Score: %s", formatScore(m.game.score))))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Level: %d", m.game.level)))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(baseStyle.Render(fmt.Sprintf(" Lines: %d", m.game.lines)))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(dimStyle.Render(" R - Play again"))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(dimStyle.Render(" Q - Quit"))
|
||||
b.WriteString("\n")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) updateGameOver(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if keyMsg, ok := msg.(tea.KeyMsg); ok {
|
||||
switch keyMsg.String() {
|
||||
case "r":
|
||||
startLevel := 0
|
||||
if m.difficulty == "hard" {
|
||||
startLevel = 5
|
||||
}
|
||||
m.game = newGame(startLevel)
|
||||
m.screen = screenGame
|
||||
m.keypresses = 0
|
||||
return m, tea.Batch(
|
||||
tea.ClearScreen,
|
||||
m.scheduleTick(),
|
||||
logAction(m.sess, "RESTART", fmt.Sprintf("difficulty=%s", m.difficulty)),
|
||||
)
|
||||
case "q":
|
||||
m.quitting = true
|
||||
return m, tea.Batch(
|
||||
logAction(m.sess, fmt.Sprintf("QUIT score=%d level=%d lines=%d keys=%d", m.game.score, m.game.level, m.game.lines, m.keypresses), "PLAYER QUIT"),
|
||||
tea.Quit,
|
||||
)
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Helper methods for safe access when game may be nil.
|
||||
func (m *model) gameScore() int {
|
||||
if m.game != nil {
|
||||
return m.game.score
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLevel() int {
|
||||
if m.game != nil {
|
||||
return m.game.level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *model) gameLines() int {
|
||||
if m.game != nil {
|
||||
return m.game.lines
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// logAction returns a tea.Cmd that logs an action to the session store.
|
||||
func logAction(sess *shell.SessionContext, input, output string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if sess.Store != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = sess.Store.AppendSessionLog(ctx, sess.SessionID, input, output)
|
||||
}
|
||||
if sess.OnCommand != nil {
|
||||
sess.OnCommand("tetris")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
286
internal/shell/tetris/style.go
Normal file
286
internal/shell/tetris/style.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
const termWidth = 80
|
||||
|
||||
var (
|
||||
colorWhite = lipgloss.Color("#FFFFFF")
|
||||
colorDim = lipgloss.Color("#555555")
|
||||
colorBlack = lipgloss.Color("#000000")
|
||||
colorGhost = lipgloss.Color("#333333")
|
||||
)
|
||||
|
||||
var (
|
||||
baseStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack)
|
||||
|
||||
dimStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#00FFFF")).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
|
||||
sidebarLabelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorDim).
|
||||
Background(colorBlack)
|
||||
|
||||
sidebarValueStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWhite).
|
||||
Background(colorBlack).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// cellStyle returns a style for a filled cell of a given piece type.
|
||||
func cellStyle(pt pieceType) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(pieceColors[pt]).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// ghostStyle returns a dimmed style for the ghost piece.
|
||||
func ghostCellStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(colorGhost).
|
||||
Background(colorBlack)
|
||||
}
|
||||
|
||||
// renderBoard renders the board, current piece, and ghost piece as a string.
|
||||
func renderBoard(g *gameState) string {
|
||||
// Build a display grid that includes the current piece and ghost.
|
||||
type displayCell struct {
|
||||
filled bool
|
||||
ghost bool
|
||||
piece pieceType
|
||||
}
|
||||
var grid [boardRows][boardCols]displayCell
|
||||
|
||||
// Copy locked cells.
|
||||
for r := range boardRows {
|
||||
for c := range boardCols {
|
||||
if g.board[r][c].filled {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.board[r][c].piece}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ghost piece.
|
||||
ghostR := g.ghostRow()
|
||||
if ghostR != g.currentRow {
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := ghostR+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols && !grid[r][c].filled {
|
||||
grid[r][c] = displayCell{ghost: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Current piece.
|
||||
shape := pieces[g.current][g.currentRot]
|
||||
for _, off := range shape {
|
||||
r, c := g.currentRow+off[0], g.currentCol+off[1]
|
||||
if r >= 0 && r < boardRows && c >= 0 && c < boardCols {
|
||||
grid[r][c] = displayCell{filled: true, piece: g.current}
|
||||
}
|
||||
}
|
||||
|
||||
// Render grid.
|
||||
var b strings.Builder
|
||||
borderStyle := dimStyle
|
||||
|
||||
for _, row := range grid {
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
for _, dc := range row {
|
||||
switch {
|
||||
case dc.filled:
|
||||
b.WriteString(cellStyle(dc.piece).Render("[]"))
|
||||
case dc.ghost:
|
||||
b.WriteString(ghostCellStyle().Render("::"))
|
||||
default:
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(borderStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
b.WriteString(borderStyle.Render("+" + strings.Repeat("--", boardCols) + "+"))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// renderNextPiece renders the "next piece" preview box.
|
||||
func renderNextPiece(pt pieceType) string {
|
||||
shape := pieces[pt][0]
|
||||
// Determine bounding box.
|
||||
minR, maxR := shape[0][0], shape[0][0]
|
||||
minC, maxC := shape[0][1], shape[0][1]
|
||||
for _, off := range shape {
|
||||
if off[0] < minR {
|
||||
minR = off[0]
|
||||
}
|
||||
if off[0] > maxR {
|
||||
maxR = off[0]
|
||||
}
|
||||
if off[1] < minC {
|
||||
minC = off[1]
|
||||
}
|
||||
if off[1] > maxC {
|
||||
maxC = off[1]
|
||||
}
|
||||
}
|
||||
|
||||
rows := maxR - minR + 1
|
||||
cols := maxC - minC + 1
|
||||
|
||||
// Build a small grid.
|
||||
grid := make([][]bool, rows)
|
||||
for i := range grid {
|
||||
grid[i] = make([]bool, cols)
|
||||
}
|
||||
for _, off := range shape {
|
||||
grid[off[0]-minR][off[1]-minC] = true
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
boxWidth := 8 // chars for the box interior
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
b.WriteString("\n")
|
||||
|
||||
for r := range rows {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
// Center the piece in the box.
|
||||
pieceWidth := cols * 2
|
||||
leftPad := (boxWidth - pieceWidth) / 2
|
||||
rightPad := boxWidth - pieceWidth - leftPad
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", leftPad)))
|
||||
for c := range cols {
|
||||
if grid[r][c] {
|
||||
b.WriteString(cellStyle(pt).Render("[]"))
|
||||
} else {
|
||||
b.WriteString(baseStyle.Render(" "))
|
||||
}
|
||||
}
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", rightPad)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Fill remaining rows in the box (max 4 rows for I piece).
|
||||
for r := rows; r < 2; r++ {
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString(baseStyle.Render(strings.Repeat(" ", boxWidth)))
|
||||
b.WriteString(dimStyle.Render("|"))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(dimStyle.Render("+" + strings.Repeat("-", boxWidth) + "+"))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// formatScore formats a score with comma separators.
|
||||
func formatScore(n int) string {
|
||||
s := fmt.Sprintf("%d", n)
|
||||
if len(s) <= 3 {
|
||||
return s
|
||||
}
|
||||
var parts []string
|
||||
for len(s) > 3 {
|
||||
parts = append([]string{s[len(s)-3:]}, parts...)
|
||||
s = s[:len(s)-3]
|
||||
}
|
||||
parts = append([]string{s}, parts...)
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// gameView combines the board and sidebar into the game screen.
|
||||
func gameView(g *gameState) string {
|
||||
boardStr := renderBoard(g)
|
||||
boardLines := strings.Split(boardStr, "\n")
|
||||
|
||||
nextStr := renderNextPiece(g.next)
|
||||
nextLines := strings.Split(nextStr, "\n")
|
||||
|
||||
// Build sidebar lines.
|
||||
var sidebar []string
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" NEXT:"))
|
||||
sidebar = append(sidebar, nextLines...)
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" SCORE: ")+sidebarValueStyle.Render(formatScore(g.score)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LEVEL: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.level)))
|
||||
sidebar = append(sidebar, sidebarLabelStyle.Render(" LINES: ")+sidebarValueStyle.Render(fmt.Sprintf("%d", g.lines)))
|
||||
sidebar = append(sidebar, "")
|
||||
sidebar = append(sidebar, dimStyle.Render(" Controls:"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" <- -> Move"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Up/Z Rotate"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Down Soft drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Space Hard drop"))
|
||||
sidebar = append(sidebar, dimStyle.Render(" Q Quit"))
|
||||
|
||||
// Combine board and sidebar side by side.
|
||||
var b strings.Builder
|
||||
maxLines := max(len(boardLines), len(sidebar))
|
||||
|
||||
for i := range maxLines {
|
||||
boardLine := ""
|
||||
if i < len(boardLines) {
|
||||
boardLine = boardLines[i]
|
||||
}
|
||||
sidebarLine := ""
|
||||
if i < len(sidebar) {
|
||||
sidebarLine = sidebar[i]
|
||||
}
|
||||
|
||||
// Pad board to fixed width (| + 10*2 + | = 22 chars visual).
|
||||
b.WriteString(boardLine)
|
||||
b.WriteString(sidebarLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// padLine pads a single line to termWidth.
|
||||
func padLine(line string) string {
|
||||
w := lipgloss.Width(line)
|
||||
if w >= termWidth {
|
||||
return line
|
||||
}
|
||||
return line + baseStyle.Render(strings.Repeat(" ", termWidth-w))
|
||||
}
|
||||
|
||||
// padLines pads every line in a multi-line string to termWidth.
|
||||
func padLines(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = padLine(line)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// gameFrame wraps content with padding to fill the terminal.
|
||||
func gameFrame(content string, height int) string {
|
||||
var b strings.Builder
|
||||
b.WriteString(content)
|
||||
|
||||
// Pad with blank lines to fill terminal height.
|
||||
if height > 0 {
|
||||
contentLines := strings.Count(content, "\n") + 1
|
||||
blankLine := baseStyle.Render(strings.Repeat(" ", termWidth))
|
||||
for i := contentLines; i < height; i++ {
|
||||
b.WriteString(blankLine)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
return padLines(b.String())
|
||||
}
|
||||
66
internal/shell/tetris/tetris.go
Normal file
66
internal/shell/tetris/tetris.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
)
|
||||
|
||||
const sessionTimeout = 10 * time.Minute
|
||||
|
||||
// TetrisShell is a Tetris game TUI for the honeypot.
|
||||
type TetrisShell struct{}
|
||||
|
||||
// NewTetrisShell returns a new TetrisShell instance.
|
||||
func NewTetrisShell() *TetrisShell {
|
||||
return &TetrisShell{}
|
||||
}
|
||||
|
||||
func (t *TetrisShell) Name() string { return "tetris" }
|
||||
func (t *TetrisShell) Description() string { return "Tetris game TUI" }
|
||||
|
||||
func (t *TetrisShell) Handle(ctx context.Context, sess *shell.SessionContext, rw io.ReadWriteCloser) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, sessionTimeout)
|
||||
defer cancel()
|
||||
|
||||
difficulty := configString(sess.ShellConfig, "difficulty", "normal")
|
||||
|
||||
m := newModel(sess, difficulty)
|
||||
p := tea.NewProgram(m,
|
||||
tea.WithInput(rw),
|
||||
tea.WithOutput(rw),
|
||||
tea.WithAltScreen(),
|
||||
)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := p.Run()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
p.Quit()
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// configString reads a string from the shell config map with a default.
|
||||
func configString(cfg map[string]any, key, defaultVal string) string {
|
||||
if cfg == nil {
|
||||
return defaultVal
|
||||
}
|
||||
if v, ok := cfg[key]; ok {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
582
internal/shell/tetris/tetris_test.go
Normal file
582
internal/shell/tetris/tetris_test.go
Normal file
@@ -0,0 +1,582 @@
|
||||
package tetris
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"code.t-juice.club/torjus/oubliette/internal/shell"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// newTestModel creates a model with a test session context.
|
||||
func newTestModel(t *testing.T) (*model, *storage.MemoryStore) {
|
||||
t.Helper()
|
||||
store := storage.NewMemoryStore()
|
||||
sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "player", "tetris", "")
|
||||
sess := &shell.SessionContext{
|
||||
SessionID: sessID,
|
||||
Username: "player",
|
||||
Store: store,
|
||||
}
|
||||
m := newModel(sess, "normal")
|
||||
return m, store
|
||||
}
|
||||
|
||||
// sendKey sends a single key message to the model and returns the command.
|
||||
func sendKey(m *model, key string) tea.Cmd {
|
||||
var msg tea.KeyMsg
|
||||
switch key {
|
||||
case "enter":
|
||||
msg = tea.KeyMsg{Type: tea.KeyEnter}
|
||||
case "up":
|
||||
msg = tea.KeyMsg{Type: tea.KeyUp}
|
||||
case "down":
|
||||
msg = tea.KeyMsg{Type: tea.KeyDown}
|
||||
case "left":
|
||||
msg = tea.KeyMsg{Type: tea.KeyLeft}
|
||||
case "right":
|
||||
msg = tea.KeyMsg{Type: tea.KeyRight}
|
||||
case "space":
|
||||
msg = tea.KeyMsg{Type: tea.KeySpace}
|
||||
case "ctrl+c":
|
||||
msg = tea.KeyMsg{Type: tea.KeyCtrlC}
|
||||
default:
|
||||
if len(key) == 1 {
|
||||
msg = tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}
|
||||
}
|
||||
}
|
||||
_, cmd := m.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// sendTick sends a tick message to the model.
|
||||
func sendTick(m *model) tea.Cmd {
|
||||
_, cmd := m.Update(tickMsg(time.Now()))
|
||||
return cmd
|
||||
}
|
||||
|
||||
// execCmds recursively executes tea.Cmd functions (including batches).
|
||||
func execCmds(cmd tea.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
msg := cmd()
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
for _, c := range batch {
|
||||
execCmds(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTetrisShellName(t *testing.T) {
|
||||
sh := NewTetrisShell()
|
||||
if sh.Name() != "tetris" {
|
||||
t.Errorf("Name() = %q, want %q", sh.Name(), "tetris")
|
||||
}
|
||||
if sh.Description() == "" {
|
||||
t.Error("Description() should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigString(t *testing.T) {
|
||||
cfg := map[string]any{
|
||||
"difficulty": "hard",
|
||||
}
|
||||
if got := configString(cfg, "difficulty", "normal"); got != "hard" {
|
||||
t.Errorf("configString() = %q, want %q", got, "hard")
|
||||
}
|
||||
if got := configString(cfg, "missing", "normal"); got != "normal" {
|
||||
t.Errorf("configString() = %q, want %q", got, "normal")
|
||||
}
|
||||
if got := configString(nil, "difficulty", "normal"); got != "normal" {
|
||||
t.Errorf("configString(nil) = %q, want %q", got, "normal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleScreenRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "████") {
|
||||
t.Error("title screen should show TETRIS logo")
|
||||
}
|
||||
if !strings.Contains(view, "Press any key") {
|
||||
t.Error("title screen should show 'Press any key'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTitleToGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
if m.screen != screenTitle {
|
||||
t.Fatalf("expected screenTitle, got %d", m.screen)
|
||||
}
|
||||
|
||||
sendKey(m, "enter")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after keypress, got %d", m.screen)
|
||||
}
|
||||
if m.game == nil {
|
||||
t.Fatal("game should be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameRenders(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "|") {
|
||||
t.Error("game view should contain board borders")
|
||||
}
|
||||
if !strings.Contains(view, "SCORE") {
|
||||
t.Error("game view should show SCORE")
|
||||
}
|
||||
if !strings.Contains(view, "LEVEL") {
|
||||
t.Error("game view should show LEVEL")
|
||||
}
|
||||
if !strings.Contains(view, "LINES") {
|
||||
t.Error("game view should show LINES")
|
||||
}
|
||||
if !strings.Contains(view, "NEXT") {
|
||||
t.Error("game view should show NEXT")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pure game logic tests ---
|
||||
|
||||
func TestNewGame(t *testing.T) {
|
||||
g := newGame(0)
|
||||
if g.gameOver {
|
||||
t.Error("new game should not be game over")
|
||||
}
|
||||
if g.score != 0 {
|
||||
t.Errorf("initial score = %d, want 0", g.score)
|
||||
}
|
||||
if g.level != 0 {
|
||||
t.Errorf("initial level = %d, want 0", g.level)
|
||||
}
|
||||
if g.lines != 0 {
|
||||
t.Errorf("initial lines = %d, want 0", g.lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGameHardLevel(t *testing.T) {
|
||||
g := newGame(5)
|
||||
if g.level != 5 {
|
||||
t.Errorf("hard start level = %d, want 5", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveLeft(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveLeft()
|
||||
if g.currentCol != startCol-1 {
|
||||
t.Errorf("after moveLeft: col = %d, want %d", g.currentCol, startCol-1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveRight(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startCol := g.currentCol
|
||||
g.moveRight()
|
||||
if g.currentCol != startCol+1 {
|
||||
t.Errorf("after moveRight: col = %d, want %d", g.currentCol, startCol+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMoveDown(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
moved := g.moveDown()
|
||||
if !moved {
|
||||
t.Error("moveDown should succeed from starting position")
|
||||
}
|
||||
if g.currentRow != startRow+1 {
|
||||
t.Errorf("after moveDown: row = %d, want %d", g.currentRow, startRow+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveLeftBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way left.
|
||||
for range boardCols {
|
||||
g.moveLeft()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveLeft() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past left wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCannotMoveRightBeyondWall(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Move all the way right.
|
||||
for range boardCols {
|
||||
g.moveRight()
|
||||
}
|
||||
col := g.currentCol
|
||||
g.moveRight() // should not move further
|
||||
if g.currentCol != col {
|
||||
t.Errorf("should not move past right wall: col = %d, was %d", g.currentCol, col)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRot := g.currentRot
|
||||
g.rotate()
|
||||
// Rotation should change (possibly with wall kick).
|
||||
if g.currentRot == startRot {
|
||||
// Rotation might legitimately fail in some edge cases, so just check
|
||||
// that the game state is valid.
|
||||
if !g.canPlace(g.current, g.currentRot, g.currentRow, g.currentCol) {
|
||||
t.Error("piece should be in a valid position after rotate attempt")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDrop(t *testing.T) {
|
||||
g := newGame(0)
|
||||
startRow := g.currentRow
|
||||
dropped := g.hardDrop()
|
||||
if dropped == 0 {
|
||||
t.Error("hard drop should move piece down at least some rows from top")
|
||||
}
|
||||
if g.currentRow <= startRow {
|
||||
t.Errorf("after hardDrop: row = %d should be > %d", g.currentRow, startRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGhostRow(t *testing.T) {
|
||||
g := newGame(0)
|
||||
ghost := g.ghostRow()
|
||||
if ghost < g.currentRow {
|
||||
t.Errorf("ghost row %d should be >= current row %d", ghost, g.currentRow)
|
||||
}
|
||||
// Ghost should be at a position where moving down one more is impossible.
|
||||
if g.canPlace(g.current, g.currentRot, ghost+1, g.currentCol) {
|
||||
t.Error("ghost row should be the lowest valid position")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockPiece(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.hardDrop()
|
||||
pt := g.current
|
||||
row, col, rot := g.currentRow, g.currentCol, g.currentRot
|
||||
g.lockPiece()
|
||||
|
||||
// Verify that the piece's cells are now filled.
|
||||
shape := pieces[pt][rot]
|
||||
for _, off := range shape {
|
||||
r, c := row+off[0], col+off[1]
|
||||
if !g.board[r][c].filled {
|
||||
t.Errorf("cell (%d, %d) should be filled after lockPiece", r, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom row completely.
|
||||
for c := range boardCols {
|
||||
g.board[boardRows-1][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 1 {
|
||||
t.Errorf("clearLines() = %d, want 1", cleared)
|
||||
}
|
||||
// Bottom row should now be empty (shifted from above).
|
||||
for c := range boardCols {
|
||||
if g.board[boardRows-1][c].filled {
|
||||
t.Errorf("bottom row col %d should be empty after clearing", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearMultipleLines(t *testing.T) {
|
||||
g := newGame(0)
|
||||
// Fill the bottom 4 rows.
|
||||
for r := boardRows - 4; r < boardRows; r++ {
|
||||
for c := range boardCols {
|
||||
g.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
cleared := g.clearLines()
|
||||
if cleared != 4 {
|
||||
t.Errorf("clearLines() = %d, want 4", cleared)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScoring(t *testing.T) {
|
||||
tests := []struct {
|
||||
lines int
|
||||
level int
|
||||
want int
|
||||
}{
|
||||
{1, 0, 40},
|
||||
{2, 0, 100},
|
||||
{3, 0, 300},
|
||||
{4, 0, 1200},
|
||||
{1, 1, 80},
|
||||
{4, 2, 3600},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
g := newGame(tt.level)
|
||||
g.addScore(tt.lines)
|
||||
if g.score != tt.want {
|
||||
t.Errorf("score for %d lines at level %d = %d, want %d", tt.lines, tt.level, g.score, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLevelUp(t *testing.T) {
|
||||
g := newGame(0)
|
||||
g.lines = 9
|
||||
g.addScore(1) // This should push lines to 10, triggering level 1.
|
||||
if g.level != 1 {
|
||||
t.Errorf("level = %d, want 1 after 10 lines", g.level)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickInterval(t *testing.T) {
|
||||
if got := tickInterval(0); got != 800 {
|
||||
t.Errorf("tickInterval(0) = %d, want 800", got)
|
||||
}
|
||||
if got := tickInterval(5); got != 500 {
|
||||
t.Errorf("tickInterval(5) = %d, want 500", got)
|
||||
}
|
||||
// Floor at 100ms.
|
||||
if got := tickInterval(20); got != 100 {
|
||||
t.Errorf("tickInterval(20) = %d, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatScore(t *testing.T) {
|
||||
tests := []struct {
|
||||
n int
|
||||
want string
|
||||
}{
|
||||
{0, "0"},
|
||||
{100, "100"},
|
||||
{1250, "1,250"},
|
||||
{1000000, "1,000,000"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := formatScore(tt.n); got != tt.want {
|
||||
t.Errorf("formatScore(%d) = %q, want %q", tt.n, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGameOverScreen(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Force game over.
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
view := m.View()
|
||||
if !strings.Contains(view, "GAME OVER") {
|
||||
t.Error("game over screen should show GAME OVER")
|
||||
}
|
||||
if !strings.Contains(view, "Score") {
|
||||
t.Error("game over screen should show score")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestartFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "r")
|
||||
if m.screen != screenGame {
|
||||
t.Errorf("expected screenGame after restart, got %d", m.screen)
|
||||
}
|
||||
if m.game.gameOver {
|
||||
t.Error("game should not be over after restart")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGame(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuitFromGameOver(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
m.game.gameOver = true
|
||||
m.screen = screenGameOver
|
||||
|
||||
sendKey(m, "q")
|
||||
if !m.quitting {
|
||||
t.Error("should be quitting after pressing q in game over")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
scoreBefore := m.game.score
|
||||
sendKey(m, "down")
|
||||
if m.game.score != scoreBefore+1 {
|
||||
t.Errorf("score after soft drop = %d, want %d", m.game.score, scoreBefore+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHardDropScoring(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Hard drop gives 2 points per row dropped.
|
||||
sendKey(m, "space")
|
||||
if m.game.score < 2 {
|
||||
t.Errorf("score after hard drop = %d, should be at least 2", m.game.score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTickMovesDown(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
rowBefore := m.game.currentRow
|
||||
sendTick(m)
|
||||
// Piece should either move down by 1, or lock and spawn a new piece at top.
|
||||
movedDown := m.game.currentRow == rowBefore+1
|
||||
respawned := m.game.currentRow < rowBefore
|
||||
if !movedDown && !respawned && !m.game.gameOver {
|
||||
t.Errorf("tick should move piece down or lock+respawn: row was %d, now %d", rowBefore, m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionLogs(t *testing.T) {
|
||||
m, store := newTestModel(t)
|
||||
|
||||
// Press key to start game — returns a logAction cmd.
|
||||
cmd := sendKey(m, "enter")
|
||||
if cmd != nil {
|
||||
execCmds(cmd)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
found := false
|
||||
for _, log := range store.SessionLogs {
|
||||
if strings.Contains(log.Input, "GAME START") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected GAME START in session logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeypressCounter(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
sendKey(m, "left")
|
||||
sendKey(m, "right")
|
||||
sendKey(m, "down")
|
||||
|
||||
if m.keypresses != 4 { // enter + 3 game keys
|
||||
t.Errorf("keypresses = %d, want 4", m.keypresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelay(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Drop piece to the bottom via ticks until it can't move down.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
if !m.locking {
|
||||
t.Fatal("piece should be in locking state after hitting bottom")
|
||||
}
|
||||
|
||||
// During lock delay, we should still be able to move left/right.
|
||||
colBefore := m.game.currentCol
|
||||
sendKey(m, "left")
|
||||
if m.game.currentCol >= colBefore {
|
||||
// Might not have moved if against wall, try right.
|
||||
sendKey(m, "right")
|
||||
}
|
||||
|
||||
// Sending a lockMsg should finalize the piece.
|
||||
m.Update(lockMsg(time.Now()))
|
||||
// After lock, a new piece should have spawned (row near top).
|
||||
if m.game.currentRow > 1 && !m.game.gameOver {
|
||||
t.Errorf("after lock delay, new piece should spawn near top, got row %d", m.game.currentRow)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLockDelayCancelledByDrop(t *testing.T) {
|
||||
m, _ := newTestModel(t)
|
||||
sendKey(m, "enter") // start game
|
||||
|
||||
// Build a ledge: fill rows 18-19 but leave column 0 empty.
|
||||
for r := boardRows - 2; r < boardRows; r++ {
|
||||
for c := 1; c < boardCols; c++ {
|
||||
m.game.board[r][c] = cell{filled: true, piece: pieceI}
|
||||
}
|
||||
}
|
||||
|
||||
// Move piece to column 0 area and drop it onto the ledge.
|
||||
for range boardCols {
|
||||
m.game.moveLeft()
|
||||
}
|
||||
// Tick down until locking.
|
||||
for range boardRows + 5 {
|
||||
if m.locking {
|
||||
break
|
||||
}
|
||||
sendTick(m)
|
||||
}
|
||||
|
||||
// If piece is on the ledge and we slide it to col 0 (open column),
|
||||
// the lock delay should cancel since it can fall further.
|
||||
// This test just validates the locking flag logic works.
|
||||
if m.locking {
|
||||
// Try moving — if piece can drop further, locking should cancel.
|
||||
sendKey(m, "left")
|
||||
// Whether locking cancels depends on the board state; just verify no crash.
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnCol(t *testing.T) {
|
||||
// All pieces should spawn roughly centered.
|
||||
for pt := range pieceType(numPieceTypes) {
|
||||
col := spawnCol(pt, 0)
|
||||
if col < 0 || col > boardCols-1 {
|
||||
t.Errorf("spawnCol(%d, 0) = %d, out of range", pt, col)
|
||||
}
|
||||
// Verify piece fits at spawn position.
|
||||
shape := pieces[pt][0]
|
||||
for _, off := range shape {
|
||||
c := col + off[1]
|
||||
if c < 0 || c >= boardCols {
|
||||
t.Errorf("piece %d overflows board at spawn: col+offset = %d", pt, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
217
internal/storage/instrumented.go
Normal file
217
internal/storage/instrumented.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// InstrumentedStore wraps a Store and records query duration and errors
|
||||
// as Prometheus metrics for each method call.
|
||||
type InstrumentedStore struct {
|
||||
store Store
|
||||
queryDuration *prometheus.HistogramVec
|
||||
queryErrors *prometheus.CounterVec
|
||||
}
|
||||
|
||||
// NewInstrumentedStore returns a new InstrumentedStore wrapping the given store.
|
||||
func NewInstrumentedStore(store Store, queryDuration *prometheus.HistogramVec, queryErrors *prometheus.CounterVec) *InstrumentedStore {
|
||||
return &InstrumentedStore{
|
||||
store: store,
|
||||
queryDuration: queryDuration,
|
||||
queryErrors: queryErrors,
|
||||
}
|
||||
}
|
||||
|
||||
func observe[T any](s *InstrumentedStore, method string, fn func() (T, error)) (T, error) {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
v, err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func observeErr(s *InstrumentedStore, method string, fn func() error) error {
|
||||
timer := prometheus.NewTimer(s.queryDuration.WithLabelValues(method))
|
||||
err := fn()
|
||||
timer.ObserveDuration()
|
||||
if err != nil {
|
||||
s.queryErrors.WithLabelValues(method).Inc()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) RecordLoginAttempt(ctx context.Context, username, password, ip, country string) error {
|
||||
return observeErr(s, "RecordLoginAttempt", func() error {
|
||||
return s.store.RecordLoginAttempt(ctx, username, password, ip, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CreateSession(ctx context.Context, ip, username, shellName, country string) (string, error) {
|
||||
return observe(s, "CreateSession", func() (string, error) {
|
||||
return s.store.CreateSession(ctx, ip, username, shellName, country)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) EndSession(ctx context.Context, sessionID string, disconnectedAt time.Time) error {
|
||||
return observeErr(s, "EndSession", func() error {
|
||||
return s.store.EndSession(ctx, sessionID, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) UpdateHumanScore(ctx context.Context, sessionID string, score float64) error {
|
||||
return observeErr(s, "UpdateHumanScore", func() error {
|
||||
return s.store.UpdateHumanScore(ctx, sessionID, score)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) SetExecCommand(ctx context.Context, sessionID string, command string) error {
|
||||
return observeErr(s, "SetExecCommand", func() error {
|
||||
return s.store.SetExecCommand(ctx, sessionID, command)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionLog(ctx context.Context, sessionID, input, output string) error {
|
||||
return observeErr(s, "AppendSessionLog", func() error {
|
||||
return s.store.AppendSessionLog(ctx, sessionID, input, output)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) DeleteRecordsBefore(ctx context.Context, cutoff time.Time) (int64, error) {
|
||||
return observe(s, "DeleteRecordsBefore", func() (int64, error) {
|
||||
return s.store.DeleteRecordsBefore(ctx, cutoff)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
return observe(s, "GetDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetDashboardStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopUsernames(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopUsernames(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopPasswords(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopPasswords(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopIPs(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopIPs(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopCountries(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopCountries(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetTopExecCommands(ctx context.Context, limit int) ([]TopEntry, error) {
|
||||
return observe(s, "GetTopExecCommands", func() ([]TopEntry, error) {
|
||||
return s.store.GetTopExecCommands(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
return observe(s, "GetRecentSessions", func() ([]Session, error) {
|
||||
return s.store.GetRecentSessions(ctx, limit, activeOnly)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
return observe(s, "GetFilteredSessions", func() ([]Session, error) {
|
||||
return s.store.GetFilteredSessions(ctx, limit, activeOnly, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSession(ctx context.Context, sessionID string) (*Session, error) {
|
||||
return observe(s, "GetSession", func() (*Session, error) {
|
||||
return s.store.GetSession(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionLogs(ctx context.Context, sessionID string) ([]SessionLog, error) {
|
||||
return observe(s, "GetSessionLogs", func() ([]SessionLog, error) {
|
||||
return s.store.GetSessionLogs(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) AppendSessionEvents(ctx context.Context, events []SessionEvent) error {
|
||||
return observeErr(s, "AppendSessionEvents", func() error {
|
||||
return s.store.AppendSessionEvents(ctx, events)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetSessionEvents(ctx context.Context, sessionID string) ([]SessionEvent, error) {
|
||||
return observe(s, "GetSessionEvents", func() ([]SessionEvent, error) {
|
||||
return s.store.GetSessionEvents(ctx, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) CloseActiveSessions(ctx context.Context, disconnectedAt time.Time) (int64, error) {
|
||||
return observe(s, "CloseActiveSessions", func() (int64, error) {
|
||||
return s.store.CloseActiveSessions(ctx, disconnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetAttemptsOverTime(ctx context.Context, days int, since, until *time.Time) ([]TimeSeriesPoint, error) {
|
||||
return observe(s, "GetAttemptsOverTime", func() ([]TimeSeriesPoint, error) {
|
||||
return s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetHourlyPattern(ctx context.Context, since, until *time.Time) ([]HourlyCount, error) {
|
||||
return observe(s, "GetHourlyPattern", func() ([]HourlyCount, error) {
|
||||
return s.store.GetHourlyPattern(ctx, since, until)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetCountryStats(ctx context.Context) ([]CountryCount, error) {
|
||||
return observe(s, "GetCountryStats", func() ([]CountryCount, error) {
|
||||
return s.store.GetCountryStats(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredDashboardStats(ctx context.Context, f DashboardFilter) (*DashboardStats, error) {
|
||||
return observe(s, "GetFilteredDashboardStats", func() (*DashboardStats, error) {
|
||||
return s.store.GetFilteredDashboardStats(ctx, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopUsernames(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopUsernames", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopUsernames(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopPasswords(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopPasswords", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopPasswords(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopIPs(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopIPs", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopIPs(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) GetFilteredTopCountries(ctx context.Context, limit int, f DashboardFilter) ([]TopEntry, error) {
|
||||
return observe(s, "GetFilteredTopCountries", func() ([]TopEntry, error) {
|
||||
return s.store.GetFilteredTopCountries(ctx, limit, f)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *InstrumentedStore) Close() error {
|
||||
return s.store.Close()
|
||||
}
|
||||
163
internal/storage/instrumented_test.go
Normal file
163
internal/storage/instrumented_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
)
|
||||
|
||||
func newTestInstrumented() (*InstrumentedStore, *prometheus.HistogramVec, *prometheus.CounterVec) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
store := NewMemoryStore()
|
||||
return NewInstrumentedStore(store, dur, errs), dur, errs
|
||||
}
|
||||
|
||||
func getHistogramCount(h *prometheus.HistogramVec, method string) uint64 {
|
||||
m := &dto.Metric{}
|
||||
h.WithLabelValues(method).(prometheus.Histogram).Write(m)
|
||||
return m.GetHistogram().GetSampleCount()
|
||||
}
|
||||
|
||||
func getCounterValue(c *prometheus.CounterVec, method string) float64 {
|
||||
m := &dto.Metric{}
|
||||
c.WithLabelValues(method).Write(m)
|
||||
return m.GetCounter().GetValue()
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreDelegation(t *testing.T) {
|
||||
s, dur, _ := newTestInstrumented()
|
||||
ctx := context.Background()
|
||||
|
||||
// RecordLoginAttempt should delegate and record duration.
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// CreateSession should delegate and return a valid ID.
|
||||
id, err := s.CreateSession(ctx, "1.2.3.4", "root", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Fatal("CreateSession returned empty ID")
|
||||
}
|
||||
if c := getHistogramCount(dur, "CreateSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
|
||||
// GetDashboardStats should delegate.
|
||||
stats, err := s.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDashboardStats: %v", err)
|
||||
}
|
||||
if stats == nil {
|
||||
t.Fatal("GetDashboardStats returned nil")
|
||||
}
|
||||
if c := getHistogramCount(dur, "GetDashboardStats"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreErrorCounting(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test_ec_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test_ec_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
// Error should be counted.
|
||||
err := s.EndSession(ctx, "nonexistent", time.Now())
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getHistogramCount(dur, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
if c := getCounterValue(errs, "EndSession"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
|
||||
// Successful call should not increment error counter.
|
||||
s2, _, errs2 := newTestInstrumented()
|
||||
err = s2.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLoginAttempt: %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs2, "RecordLoginAttempt"); c != 0 {
|
||||
t.Fatalf("expected error count 0, got %f", c)
|
||||
}
|
||||
}
|
||||
|
||||
// errorStore is a Store that returns errors for all methods.
|
||||
type errorStore struct {
|
||||
MemoryStore
|
||||
}
|
||||
|
||||
var errFake = errors.New("fake error")
|
||||
|
||||
func (s *errorStore) RecordLoginAttempt(context.Context, string, string, string, string) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func (s *errorStore) EndSession(context.Context, string, time.Time) error {
|
||||
return errFake
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreObserveErr(t *testing.T) {
|
||||
dur := prometheus.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "test2_query_duration_seconds",
|
||||
Help: "test",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 1},
|
||||
}, []string{"method"})
|
||||
errs := prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "test2_query_errors_total",
|
||||
Help: "test",
|
||||
}, []string{"method"})
|
||||
|
||||
es := &errorStore{}
|
||||
s := NewInstrumentedStore(es, dur, errs)
|
||||
ctx := context.Background()
|
||||
|
||||
err := s.RecordLoginAttempt(ctx, "root", "pass", "1.2.3.4", "US")
|
||||
if !errors.Is(err, errFake) {
|
||||
t.Fatalf("expected errFake, got %v", err)
|
||||
}
|
||||
if c := getCounterValue(errs, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected error count 1, got %f", c)
|
||||
}
|
||||
if c := getHistogramCount(dur, "RecordLoginAttempt"); c != 1 {
|
||||
t.Fatalf("expected 1 observation, got %d", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstrumentedStoreClose(t *testing.T) {
|
||||
s, _, _ := newTestInstrumented()
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -336,10 +336,26 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Count events per session.
|
||||
return m.collectSessions(limit, activeOnly, DashboardFilter{}), nil
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetFilteredSessions(_ context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.collectSessions(limit, activeOnly, f), nil
|
||||
}
|
||||
|
||||
// collectSessions gathers sessions matching filter criteria. Must be called with m.mu held.
|
||||
func (m *MemoryStore) collectSessions(limit int, activeOnly bool, f DashboardFilter) []Session {
|
||||
// Compute event counts and input bytes per session.
|
||||
eventCounts := make(map[string]int)
|
||||
inputBytes := make(map[string]int64)
|
||||
for _, e := range m.SessionEvents {
|
||||
eventCounts[e.SessionID]++
|
||||
if e.Direction == 0 {
|
||||
inputBytes[e.SessionID] += int64(len(e.Data))
|
||||
}
|
||||
}
|
||||
|
||||
var sessions []Session
|
||||
@@ -347,17 +363,54 @@ func (m *MemoryStore) GetRecentSessions(_ context.Context, limit int, activeOnly
|
||||
if activeOnly && s.DisconnectedAt != nil {
|
||||
continue
|
||||
}
|
||||
if !matchesSessionFilter(s, f) {
|
||||
continue
|
||||
}
|
||||
sess := *s
|
||||
sess.EventCount = eventCounts[s.ID]
|
||||
sess.InputBytes = inputBytes[s.ID]
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
|
||||
if f.SortBy == "input_bytes" {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].InputBytes > sessions[j].InputBytes
|
||||
})
|
||||
} else {
|
||||
sort.Slice(sessions, func(i, j int) bool {
|
||||
return sessions[i].ConnectedAt.After(sessions[j].ConnectedAt)
|
||||
})
|
||||
}
|
||||
|
||||
if limit > 0 && len(sessions) > limit {
|
||||
sessions = sessions[:limit]
|
||||
}
|
||||
return sessions, nil
|
||||
return sessions
|
||||
}
|
||||
|
||||
// matchesSessionFilter returns true if the session matches the given filter.
|
||||
func matchesSessionFilter(s *Session, f DashboardFilter) bool {
|
||||
if f.Since != nil && s.ConnectedAt.Before(*f.Since) {
|
||||
return false
|
||||
}
|
||||
if f.Until != nil && s.ConnectedAt.After(*f.Until) {
|
||||
return false
|
||||
}
|
||||
if f.IP != "" && s.IP != f.IP {
|
||||
return false
|
||||
}
|
||||
if f.Country != "" && s.Country != f.Country {
|
||||
return false
|
||||
}
|
||||
if f.Username != "" && s.Username != f.Username {
|
||||
return false
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
if s.HumanScore == nil || *s.HumanScore <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MemoryStore) GetTopExecCommands(_ context.Context, limit int) ([]TopEntry, error) {
|
||||
|
||||
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
3
internal/storage/migrations/005_add_query_indexes.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
CREATE INDEX idx_login_attempts_username ON login_attempts(username);
|
||||
CREATE INDEX idx_login_attempts_password ON login_attempts(password);
|
||||
CREATE INDEX idx_sessions_disconnected_at ON sessions(disconnected_at);
|
||||
@@ -25,8 +25,8 @@ func TestMigrateCreatesTablesAndVersion(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 4 {
|
||||
t.Errorf("version = %d, want 4", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d, want 5", version)
|
||||
}
|
||||
|
||||
// Verify tables exist by inserting into them.
|
||||
@@ -64,8 +64,8 @@ func TestMigrateIdempotent(t *testing.T) {
|
||||
if err := db.QueryRow(`SELECT version FROM schema_version`).Scan(&version); err != nil {
|
||||
t.Fatalf("query version: %v", err)
|
||||
}
|
||||
if version != 4 {
|
||||
t.Errorf("version = %d after double migrate, want 4", version)
|
||||
if version != 5 {
|
||||
t.Errorf("version = %d after double migrate, want 5", version)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -383,40 +383,104 @@ func (s *SQLiteStore) queryTopN(ctx context.Context, column string, limit int) (
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error) {
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id`
|
||||
if activeOnly {
|
||||
query += ` WHERE s.disconnected_at IS NULL`
|
||||
}
|
||||
query += ` GROUP BY s.id ORDER BY s.connected_at DESC LIMIT ?`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, limit)
|
||||
return s.scanSessions(ctx, query, limit)
|
||||
}
|
||||
|
||||
// buildSessionWhereClause builds a dynamic WHERE clause for session filtering.
|
||||
func buildSessionWhereClause(f DashboardFilter, activeOnly bool) (string, []any) {
|
||||
var clauses []string
|
||||
var args []any
|
||||
|
||||
if activeOnly {
|
||||
clauses = append(clauses, "s.disconnected_at IS NULL")
|
||||
}
|
||||
if f.Since != nil {
|
||||
clauses = append(clauses, "s.connected_at >= ?")
|
||||
args = append(args, f.Since.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.Until != nil {
|
||||
clauses = append(clauses, "s.connected_at <= ?")
|
||||
args = append(args, f.Until.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if f.IP != "" {
|
||||
clauses = append(clauses, "s.ip = ?")
|
||||
args = append(args, f.IP)
|
||||
}
|
||||
if f.Country != "" {
|
||||
clauses = append(clauses, "s.country = ?")
|
||||
args = append(args, f.Country)
|
||||
}
|
||||
if f.Username != "" {
|
||||
clauses = append(clauses, "s.username = ?")
|
||||
args = append(args, f.Username)
|
||||
}
|
||||
if f.HumanScoreAboveZero {
|
||||
clauses = append(clauses, "s.human_score > 0")
|
||||
}
|
||||
|
||||
if len(clauses) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
return " WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
// validSessionSorts maps allowed SortBy values to SQL ORDER BY clauses.
|
||||
var validSessionSorts = map[string]string{
|
||||
"connected_at": "s.connected_at DESC",
|
||||
"input_bytes": "input_bytes DESC",
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error) {
|
||||
where, args := buildSessionWhereClause(f, activeOnly)
|
||||
args = append(args, limit)
|
||||
|
||||
orderBy := validSessionSorts["connected_at"]
|
||||
if mapped, ok := validSessionSorts[f.SortBy]; ok {
|
||||
orderBy = mapped
|
||||
}
|
||||
|
||||
//nolint:gosec // where/order clauses built from allowlisted constants, not raw user input
|
||||
query := `SELECT s.id, s.ip, s.country, s.username, s.shell_name, s.connected_at, s.disconnected_at, s.human_score, s.exec_command, COUNT(e.id) as event_count, COALESCE(SUM(CASE WHEN e.direction = 0 THEN LENGTH(e.data) ELSE 0 END), 0) as input_bytes FROM sessions s LEFT JOIN session_events e ON s.id = e.session_id` + where + ` GROUP BY s.id ORDER BY ` + orderBy + ` LIMIT ?`
|
||||
|
||||
return s.scanSessions(ctx, query, args...)
|
||||
}
|
||||
|
||||
// scanSessions executes a session query and scans the results.
|
||||
func (s *SQLiteStore) scanSessions(ctx context.Context, query string, args ...any) ([]Session, error) {
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying recent sessions: %w", err)
|
||||
return nil, fmt.Errorf("querying sessions: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var sessions []Session
|
||||
for rows.Next() {
|
||||
var s Session
|
||||
var sess Session
|
||||
var connectedAt string
|
||||
var disconnectedAt sql.NullString
|
||||
var humanScore sql.NullFloat64
|
||||
var execCommand sql.NullString
|
||||
if err := rows.Scan(&s.ID, &s.IP, &s.Country, &s.Username, &s.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &s.EventCount); err != nil {
|
||||
if err := rows.Scan(&sess.ID, &sess.IP, &sess.Country, &sess.Username, &sess.ShellName, &connectedAt, &disconnectedAt, &humanScore, &execCommand, &sess.EventCount, &sess.InputBytes); err != nil {
|
||||
return nil, fmt.Errorf("scanning session: %w", err)
|
||||
}
|
||||
s.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
sess.ConnectedAt, _ = time.Parse(time.RFC3339, connectedAt)
|
||||
if disconnectedAt.Valid {
|
||||
t, _ := time.Parse(time.RFC3339, disconnectedAt.String)
|
||||
s.DisconnectedAt = &t
|
||||
sess.DisconnectedAt = &t
|
||||
}
|
||||
if humanScore.Valid {
|
||||
s.HumanScore = &humanScore.Float64
|
||||
sess.HumanScore = &humanScore.Float64
|
||||
}
|
||||
if execCommand.Valid {
|
||||
s.ExecCommand = &execCommand.String
|
||||
sess.ExecCommand = &execCommand.String
|
||||
}
|
||||
sessions = append(sessions, s)
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
return sessions, rows.Err()
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type Session struct {
|
||||
HumanScore *float64
|
||||
ExecCommand *string
|
||||
EventCount int
|
||||
InputBytes int64
|
||||
}
|
||||
|
||||
// SessionLog represents a single log entry for a session.
|
||||
@@ -76,11 +77,13 @@ type CountryCount struct {
|
||||
|
||||
// DashboardFilter contains optional filters for dashboard queries.
|
||||
type DashboardFilter struct {
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
IP string
|
||||
Country string
|
||||
Username string
|
||||
HumanScoreAboveZero bool
|
||||
SortBy string
|
||||
}
|
||||
|
||||
// TopEntry represents a value and its count for top-N queries.
|
||||
@@ -137,6 +140,10 @@ type Store interface {
|
||||
// If activeOnly is true, only sessions with no disconnected_at are returned.
|
||||
GetRecentSessions(ctx context.Context, limit int, activeOnly bool) ([]Session, error)
|
||||
|
||||
// GetFilteredSessions returns sessions matching the given filter, ordered
|
||||
// by the filter's SortBy field (default: connected_at DESC).
|
||||
GetFilteredSessions(ctx context.Context, limit int, activeOnly bool, f DashboardFilter) ([]Session, error)
|
||||
|
||||
// GetSession returns a single session by ID.
|
||||
GetSession(ctx context.Context, sessionID string) (*Session, error)
|
||||
|
||||
|
||||
@@ -700,3 +700,192 @@ func TestGetRecentSessions(t *testing.T) {
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestInputBytes(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("counts only input direction", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
events := []SessionEvent{
|
||||
{SessionID: id, Timestamp: now, Direction: 0, Data: []byte("ls\n")}, // 3 bytes input
|
||||
{SessionID: id, Timestamp: now.Add(100 * time.Millisecond), Direction: 1, Data: []byte("file1\nfile2\n")}, // 11 bytes output
|
||||
{SessionID: id, Timestamp: now.Add(200 * time.Millisecond), Direction: 0, Data: []byte("pwd\n")}, // 4 bytes input
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, events); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
// Only direction=0 data: "ls\n" (3) + "pwd\n" (4) = 7
|
||||
if sessions[0].InputBytes != 7 {
|
||||
t.Errorf("InputBytes = %d, want 7", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("zero when no events", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetRecentSessions(ctx, 10, false)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRecentSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].InputBytes != 0 {
|
||||
t.Errorf("InputBytes = %d, want 0", sessions[0].InputBytes)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetFilteredSessions(t *testing.T) {
|
||||
testStores(t, func(t *testing.T, newStore storeFactory) {
|
||||
t.Run("filter by human score", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create two sessions, one with human score > 0.
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.75); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{HumanScoreAboveZero: true})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sort by input bytes", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Session with more input (created first).
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id1, Timestamp: now, Direction: 0, Data: []byte("ls -la /tmp\n")},
|
||||
{SessionID: id1, Timestamp: now.Add(time.Millisecond), Direction: 0, Data: []byte("cat /etc/passwd\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Session with less input (created after id1, so would be first by connected_at).
|
||||
// Sleep >1s to ensure different RFC3339 timestamps in SQLite.
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.AppendSessionEvents(ctx, []SessionEvent{
|
||||
{SessionID: id2, Timestamp: now.Add(2 * time.Second), Direction: 0, Data: []byte("x\n")},
|
||||
}); err != nil {
|
||||
t.Fatalf("AppendSessionEvents: %v", err)
|
||||
}
|
||||
|
||||
// Default sort (connected_at DESC) should show id2 first.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id2 {
|
||||
t.Errorf("default sort: expected %s first, got %s", id2, sessions[0].ID)
|
||||
}
|
||||
|
||||
// Sort by input_bytes should show id1 first (more input).
|
||||
sessions, err = store.GetFilteredSessions(ctx, 50, false, DashboardFilter{SortBy: "input_bytes"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 2 {
|
||||
t.Fatalf("len = %d, want 2", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("input_bytes sort: expected %s first, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("combined filters", func(t *testing.T) {
|
||||
store := newStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
id1, err := store.CreateSession(ctx, "10.0.0.1", "root", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id1, 0.5); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Different country, also has score.
|
||||
id2, err := store.CreateSession(ctx, "10.0.0.2", "admin", "bash", "US")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
if err := store.UpdateHumanScore(ctx, id2, 0.8); err != nil {
|
||||
t.Fatalf("UpdateHumanScore: %v", err)
|
||||
}
|
||||
|
||||
// Same country CN but no score.
|
||||
_, err = store.CreateSession(ctx, "10.0.0.3", "test", "bash", "CN")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSession: %v", err)
|
||||
}
|
||||
|
||||
// Filter: CN + human score > 0 -> only id1.
|
||||
sessions, err := store.GetFilteredSessions(ctx, 50, false, DashboardFilter{
|
||||
Country: "CN",
|
||||
HumanScoreAboveZero: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFilteredSessions: %v", err)
|
||||
}
|
||||
if len(sessions) != 1 {
|
||||
t.Fatalf("len = %d, want 1", len(sessions))
|
||||
}
|
||||
if sessions[0].ID != id1 {
|
||||
t.Errorf("expected session %s, got %s", id1, sessions[0].ID)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
// dbContext returns a context detached from the HTTP request lifecycle with a
|
||||
// 30-second timeout. This prevents HTMX polling from canceling in-flight DB
|
||||
// queries when the browser aborts the previous XHR.
|
||||
func dbContext(r *http.Request) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.WithoutCancel(r.Context()), 30*time.Second)
|
||||
}
|
||||
|
||||
type dashboardData struct {
|
||||
Stats *storage.DashboardStats
|
||||
TopUsernames []storage.TopEntry
|
||||
@@ -22,7 +30,8 @@ type dashboardData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
@@ -98,7 +107,10 @@ func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
stats, err := s.store.GetDashboardStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.store.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get dashboard stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -112,7 +124,10 @@ func (s *Server) handleFragmentStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Request) {
|
||||
sessions, err := s.store.GetRecentSessions(r.Context(), 50, true)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
sessions, err := s.store.GetRecentSessions(ctx, 50, true)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get active sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -125,6 +140,24 @@ func (s *Server) handleFragmentActiveSessions(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentRecentSessions(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
f := parseDashboardFilter(r)
|
||||
sessions, err := s.store.GetFilteredSessions(ctx, 50, false, f)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get filtered sessions", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := s.tmpl.dashboard.ExecuteTemplate(w, "recent_sessions", sessions); err != nil {
|
||||
s.logger.Error("failed to render recent sessions fragment", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
type sessionDetailData struct {
|
||||
Session *storage.Session
|
||||
Logs []storage.SessionLog
|
||||
@@ -132,7 +165,8 @@ type sessionDetailData struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionDetail(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
session, err := s.store.GetSession(ctx, sessionID)
|
||||
@@ -201,11 +235,13 @@ func parseDateParam(r *http.Request, name string) *time.Time {
|
||||
|
||||
func parseDashboardFilter(r *http.Request) storage.DashboardFilter {
|
||||
return storage.DashboardFilter{
|
||||
Since: parseDateParam(r, "since"),
|
||||
Until: parseDateParam(r, "until"),
|
||||
IP: r.URL.Query().Get("ip"),
|
||||
Country: r.URL.Query().Get("country"),
|
||||
Username: r.URL.Query().Get("username"),
|
||||
Since: parseDateParam(r, "since"),
|
||||
Until: parseDateParam(r, "until"),
|
||||
IP: r.URL.Query().Get("ip"),
|
||||
Country: r.URL.Query().Get("country"),
|
||||
Username: r.URL.Query().Get("username"),
|
||||
HumanScoreAboveZero: r.URL.Query().Get("human_score") == "1",
|
||||
SortBy: r.URL.Query().Get("sort"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,7 +265,10 @@ func (s *Server) handleAPIAttemptsOverTime(w http.ResponseWriter, r *http.Reques
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(r.Context(), days, since, until)
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
points, err := s.store.GetAttemptsOverTime(ctx, days, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get attempts over time", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -260,10 +299,13 @@ type apiHourlyPatternResponse struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleAPIHourlyPattern(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
since := parseDateParam(r, "since")
|
||||
until := parseDateParam(r, "until")
|
||||
|
||||
counts, err := s.store.GetHourlyPattern(r.Context(), since, until)
|
||||
counts, err := s.store.GetHourlyPattern(ctx, since, until)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get hourly pattern", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -291,7 +333,10 @@ type apiCountryStatsResponse struct {
|
||||
}
|
||||
|
||||
func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
counts, err := s.store.GetCountryStats(r.Context())
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
|
||||
counts, err := s.store.GetCountryStats(ctx)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to get country stats", "err", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
@@ -310,7 +355,8 @@ func (s *Server) handleAPICountryStats(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
f := parseDashboardFilter(r)
|
||||
|
||||
stats, err := s.store.GetFilteredDashboardStats(ctx, f)
|
||||
@@ -363,7 +409,8 @@ func (s *Server) handleFragmentDashboardContent(w http.ResponseWriter, r *http.R
|
||||
}
|
||||
|
||||
func (s *Server) handleAPISessionEvents(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx, cancel := dbContext(r)
|
||||
defer cancel()
|
||||
sessionID := r.PathValue("id")
|
||||
|
||||
events, err := s.store.GetSessionEvents(ctx, sessionID)
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
var until = form.elements['until'].value;
|
||||
if (since) params.set('since', since);
|
||||
if (until) params.set('until', until);
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
return params.toString();
|
||||
}
|
||||
|
||||
@@ -158,9 +162,14 @@
|
||||
var svg = container.querySelector('svg');
|
||||
if (!svg) return;
|
||||
|
||||
var paths = svg.querySelectorAll('path[id]');
|
||||
paths.forEach(function(path) {
|
||||
var id = path.id.toLowerCase();
|
||||
// Remove SVG title to prevent browser native tooltip
|
||||
var svgTitle = svg.querySelector('title');
|
||||
if (svgTitle) svgTitle.remove();
|
||||
|
||||
// Select both <path id="xx"> and <g id="xx"> country elements
|
||||
var elements = svg.querySelectorAll('path[id], g[id]');
|
||||
elements.forEach(function(el) {
|
||||
var id = el.id.toLowerCase();
|
||||
if (id.charAt(0) === '_') return; // skip non-country paths
|
||||
|
||||
var count = lookup[id];
|
||||
@@ -169,26 +178,34 @@
|
||||
var r = Math.round(30 + intensity * 69); // 30 -> 99
|
||||
var g = Math.round(30 + intensity * 72); // 30 -> 102
|
||||
var b = Math.round(62 + intensity * 179); // 62 -> 241
|
||||
path.style.fill = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||
var color = 'rgb(' + r + ',' + g + ',' + b + ')';
|
||||
// For <g> elements, color child paths; for <path>, color directly
|
||||
if (el.tagName.toLowerCase() === 'g') {
|
||||
el.querySelectorAll('path').forEach(function(p) {
|
||||
p.style.fill = color;
|
||||
});
|
||||
} else {
|
||||
el.style.fill = color;
|
||||
}
|
||||
}
|
||||
|
||||
path.addEventListener('mouseenter', function(e) {
|
||||
el.addEventListener('mouseenter', function(e) {
|
||||
var cc = id.toUpperCase();
|
||||
var n = lookup[id] || 0;
|
||||
tooltip.textContent = cc + ': ' + n.toLocaleString() + ' attempts';
|
||||
tooltip.style.display = 'block';
|
||||
});
|
||||
|
||||
path.addEventListener('mousemove', function(e) {
|
||||
el.addEventListener('mousemove', function(e) {
|
||||
tooltip.style.left = (e.clientX + 12) + 'px';
|
||||
tooltip.style.top = (e.clientY - 10) + 'px';
|
||||
});
|
||||
|
||||
path.addEventListener('mouseleave', function() {
|
||||
el.addEventListener('mouseleave', function() {
|
||||
tooltip.style.display = 'none';
|
||||
});
|
||||
|
||||
path.addEventListener('click', function() {
|
||||
el.addEventListener('click', function() {
|
||||
var input = document.querySelector('#filter-form input[name="country"]');
|
||||
if (input) {
|
||||
input.value = id.toUpperCase();
|
||||
@@ -196,7 +213,7 @@
|
||||
}
|
||||
});
|
||||
|
||||
path.style.cursor = 'pointer';
|
||||
el.style.cursor = 'pointer';
|
||||
});
|
||||
}
|
||||
|
||||
@@ -215,33 +232,20 @@
|
||||
if (val) params.set(name, val);
|
||||
});
|
||||
|
||||
var humanScore = form.elements['human_score'];
|
||||
if (humanScore && humanScore.checked) params.set('human_score', '1');
|
||||
var sortBy = form.elements['sort'];
|
||||
if (sortBy && sortBy.value) params.set('sort', sortBy.value);
|
||||
|
||||
var target = document.getElementById('dashboard-content');
|
||||
if (target) {
|
||||
var url = '/fragments/dashboard-content?' + params.toString();
|
||||
htmx.ajax('GET', url, {target: '#dashboard-content', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
// Client-side filter for recent sessions table
|
||||
filterSessionsTable(form);
|
||||
}
|
||||
|
||||
function filterSessionsTable(form) {
|
||||
var ip = form.elements['ip'].value.toLowerCase();
|
||||
var country = form.elements['country'].value.toLowerCase();
|
||||
var username = form.elements['username'].value.toLowerCase();
|
||||
|
||||
var rows = document.querySelectorAll('#recent-sessions-table tbody tr');
|
||||
rows.forEach(function(row) {
|
||||
var cells = row.querySelectorAll('td');
|
||||
if (cells.length < 4) { row.style.display = ''; return; }
|
||||
|
||||
var show = true;
|
||||
if (ip && cells[1].textContent.toLowerCase().indexOf(ip) === -1) show = false;
|
||||
if (country && cells[2].textContent.toLowerCase().indexOf(country) === -1) show = false;
|
||||
if (username && cells[3].textContent.toLowerCase().indexOf(username) === -1) show = false;
|
||||
|
||||
row.style.display = show ? '' : 'none';
|
||||
});
|
||||
// Server-side filter for recent sessions table
|
||||
var sessionsUrl = '/fragments/recent-sessions?' + params.toString();
|
||||
htmx.ajax('GET', sessionsUrl, {target: '#recent-sessions-table tbody', swap: 'innerHTML'});
|
||||
}
|
||||
|
||||
window.clearFilters = function() {
|
||||
|
||||
@@ -56,6 +56,20 @@ func templateFuncMap() template.FuncMap {
|
||||
}
|
||||
return s
|
||||
},
|
||||
"formatBytes": func(b int64) string {
|
||||
const (
|
||||
kb = 1024
|
||||
mb = 1024 * kb
|
||||
)
|
||||
switch {
|
||||
case b >= mb:
|
||||
return fmt.Sprintf("%.1f MB", float64(b)/float64(mb))
|
||||
case b >= kb:
|
||||
return fmt.Sprintf("%.1f KB", float64(b)/float64(kb))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +81,7 @@ func loadTemplates() (*templateSet, error) {
|
||||
"templates/dashboard.html",
|
||||
"templates/fragments/stats.html",
|
||||
"templates/fragments/active_sessions.html",
|
||||
"templates/fragments/recent_sessions.html",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing dashboard templates: %w", err)
|
||||
|
||||
@@ -13,6 +13,10 @@
|
||||
<label>Country <input type="text" name="country" placeholder="CN" maxlength="2"></label>
|
||||
<label>Username <input type="text" name="username" placeholder="root"></label>
|
||||
</div>
|
||||
<div class="grid">
|
||||
<label><input type="checkbox" name="human_score" value="1"> Human score > 0</label>
|
||||
<label>Sort by <select name="sort"><option value="connected_at">Recent</option><option value="input_bytes">Input Bytes</option></select></label>
|
||||
</div>
|
||||
<button type="submit">Apply</button>
|
||||
<button type="button" class="secondary" onclick="clearFilters()">Clear</button>
|
||||
</form>
|
||||
@@ -61,25 +65,13 @@
|
||||
<th>Username</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
<th>Disconnected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{range .RecentSessions}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="8">No sessions</td></tr>
|
||||
{{end}}
|
||||
{{template "recent_sessions" .RecentSessions}}
|
||||
</tbody>
|
||||
</table>
|
||||
</section>
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
<th>Username</th>
|
||||
<th>Type</th>
|
||||
<th>Score</th>
|
||||
<th>Input</th>
|
||||
<th>Connected</th>
|
||||
</tr>
|
||||
</thead>
|
||||
@@ -20,10 +21,11 @@
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="7">No active sessions</td></tr>
|
||||
<tr><td colspan="8">No active sessions</td></tr>
|
||||
{{end}}
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
17
internal/web/templates/fragments/recent_sessions.html
Normal file
17
internal/web/templates/fragments/recent_sessions.html
Normal file
@@ -0,0 +1,17 @@
|
||||
{{define "recent_sessions"}}
|
||||
{{range .}}
|
||||
<tr>
|
||||
<td><a href="/sessions/{{.ID}}"><code>{{truncateID .ID}}</code></a>{{if gt .EventCount 0}} <mark>replay</mark>{{end}}</td>
|
||||
<td>{{.IP}}</td>
|
||||
<td>{{.Country}}</td>
|
||||
<td>{{.Username}}</td>
|
||||
<td>{{if .ExecCommand}}<mark>exec</mark>{{else}}{{.ShellName}}{{end}}</td>
|
||||
<td>{{if .HumanScore}}{{if gt (derefFloat .HumanScore) 0.6}}<mark>{{formatScore .HumanScore}}</mark>{{else}}{{formatScore .HumanScore}}{{end}}{{else}}-{{end}}</td>
|
||||
<td>{{formatBytes .InputBytes}}</td>
|
||||
<td>{{formatTime .ConnectedAt}}</td>
|
||||
<td>{{if .DisconnectedAt}}{{formatTime (derefTime .DisconnectedAt)}}{{else}}<mark>active</mark>{{end}}</td>
|
||||
</tr>
|
||||
{{else}}
|
||||
<tr><td colspan="9">No sessions</td></tr>
|
||||
{{end}}
|
||||
{{end}}
|
||||
@@ -38,7 +38,7 @@
|
||||
}
|
||||
#world-map svg { width: 100%; height: auto; }
|
||||
#world-map svg path { fill: #2a2a3e; stroke: #555; stroke-width: 0.5; transition: fill 0.2s; }
|
||||
#world-map svg path:hover { stroke: #fff; stroke-width: 1; }
|
||||
#world-map svg path:hover, #world-map svg g:hover path { stroke: #fff; stroke-width: 1; }
|
||||
nav h1 {
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
//go:embed static/*
|
||||
@@ -47,6 +47,7 @@ func NewServer(store storage.Store, logger *slog.Logger, metricsHandler http.Han
|
||||
s.mux.HandleFunc("GET /fragments/stats", s.handleFragmentStats)
|
||||
s.mux.HandleFunc("GET /fragments/active-sessions", s.handleFragmentActiveSessions)
|
||||
s.mux.HandleFunc("GET /fragments/dashboard-content", s.handleFragmentDashboardContent)
|
||||
s.mux.HandleFunc("GET /fragments/recent-sessions", s.handleFragmentRecentSessions)
|
||||
|
||||
if metricsHandler != nil {
|
||||
h := metricsHandler
|
||||
|
||||
@@ -10,8 +10,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"git.t-juice.club/torjus/oubliette/internal/storage"
|
||||
"code.t-juice.club/torjus/oubliette/internal/metrics"
|
||||
"code.t-juice.club/torjus/oubliette/internal/storage"
|
||||
)
|
||||
|
||||
func newTestServer(t *testing.T) *Server {
|
||||
@@ -54,6 +54,30 @@ func newSeededTestServer(t *testing.T) *Server {
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestDbContextNotCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
dbCtx, dbCancel := dbContext(req)
|
||||
defer dbCancel()
|
||||
|
||||
// Cancel the original request context.
|
||||
cancel()
|
||||
|
||||
// The DB context should still be usable.
|
||||
select {
|
||||
case <-dbCtx.Done():
|
||||
t.Fatal("dbContext should not be canceled when request context is canceled")
|
||||
default:
|
||||
}
|
||||
|
||||
// Verify the DB context has a deadline (from the timeout).
|
||||
if _, ok := dbCtx.Deadline(); !ok {
|
||||
t.Error("dbContext should have a deadline")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardHandler(t *testing.T) {
|
||||
t.Run("empty store", func(t *testing.T) {
|
||||
srv := newTestServer(t)
|
||||
|
||||
@@ -34,6 +34,16 @@ password = "admin"
|
||||
# password = "cisco"
|
||||
# shell = "cisco"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "irobot"
|
||||
# password = "roomba"
|
||||
# shell = "roomba"
|
||||
|
||||
# [[auth.static_credentials]]
|
||||
# username = "player"
|
||||
# password = "tetris"
|
||||
# shell = "tetris"
|
||||
|
||||
[storage]
|
||||
db_path = "oubliette.db"
|
||||
retention_days = 90
|
||||
@@ -75,6 +85,12 @@ hostname = "ubuntu-server"
|
||||
# db_name = "postgres"
|
||||
# pg_version = "15.4"
|
||||
|
||||
# [shell.roomba]
|
||||
# No configuration options currently.
|
||||
|
||||
# [shell.tetris]
|
||||
# difficulty = "normal" # "easy" (slower start), "normal" (standard), "hard" (start at level 5)
|
||||
|
||||
# [detection]
|
||||
# enabled = true
|
||||
# threshold = 0.6 # 0.0–1.0, sessions above this trigger notifications
|
||||
|
||||
Reference in New Issue
Block a user