package bash import ( "bytes" "context" "errors" "io" "strings" "testing" "time" "git.t-juice.club/torjus/oubliette/internal/shell" "git.t-juice.club/torjus/oubliette/internal/storage" ) type rwCloser struct { io.Reader io.Writer closed bool } func (r *rwCloser) Close() error { r.closed = true return nil } func TestFormatPrompt(t *testing.T) { tests := []struct { cwd string want string }{ {"/root", "root@host:~# "}, {"/root/sub", "root@host:~/sub# "}, {"/tmp", "root@host:/tmp# "}, {"/", "root@host:/# "}, } for _, tt := range tests { state := &shellState{cwd: tt.cwd, username: "root", hostname: "host"} got := formatPrompt(state) if got != tt.want { t.Errorf("formatPrompt(cwd=%q) = %q, want %q", tt.cwd, got, tt.want) } } } func TestReadLineEnter(t *testing.T) { input := bytes.NewBufferString("hello\r") var output bytes.Buffer rw := struct { io.Reader io.Writer }{input, &output} ctx := context.Background() line, err := shell.ReadLine(ctx, rw) if err != nil { t.Fatalf("readLine: %v", err) } if line != "hello" { t.Errorf("line = %q, want %q", line, "hello") } } func TestReadLineBackspace(t *testing.T) { // Type "helo", backspace, then "lo\r" input := bytes.NewBuffer([]byte{'h', 'e', 'l', 'o', 127, 'l', 'o', '\r'}) var output bytes.Buffer rw := struct { io.Reader io.Writer }{input, &output} ctx := context.Background() line, err := shell.ReadLine(ctx, rw) if err != nil { t.Fatalf("readLine: %v", err) } if line != "hello" { t.Errorf("line = %q, want %q", line, "hello") } } func TestReadLineCtrlC(t *testing.T) { input := bytes.NewBuffer([]byte("partial\x03")) var output bytes.Buffer rw := struct { io.Reader io.Writer }{input, &output} ctx := context.Background() line, err := shell.ReadLine(ctx, rw) if err != nil { t.Fatalf("readLine: %v", err) } if line != "" { t.Errorf("line after Ctrl+C = %q, want empty", line) } } func TestReadLineCtrlD(t *testing.T) { input := bytes.NewBuffer([]byte{4}) // Ctrl+D on empty line var output bytes.Buffer rw := struct { io.Reader io.Writer }{input, &output} ctx := context.Background() _, err := shell.ReadLine(ctx, rw) if !errors.Is(err, io.EOF) { t.Fatalf("expected io.EOF, got %v", err) } } func TestBashShellHandle(t *testing.T) { store := storage.NewMemoryStore() sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "root", "bash") sess := &shell.SessionContext{ SessionID: sessID, Username: "root", Store: store, CommonConfig: shell.ShellCommonConfig{ Hostname: "testhost", Banner: "Welcome to Ubuntu 22.04.3 LTS\r\n\r\n", }, } // Simulate typing commands followed by "exit\r" commands := "pwd\rwhoami\rexit\r" clientInput := bytes.NewBufferString(commands) var clientOutput bytes.Buffer rw := &rwCloser{ Reader: clientInput, Writer: &clientOutput, } sh := NewBashShell() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := sh.Handle(ctx, sess, rw) if err != nil { t.Fatalf("Handle: %v", err) } output := clientOutput.String() // Should contain banner. if !strings.Contains(output, "Welcome to Ubuntu") { t.Error("output should contain banner") } // Should contain prompt with hostname. if !strings.Contains(output, "root@testhost") { t.Errorf("output should contain prompt, got: %s", output) } // Check session logs were recorded. if len(store.SessionLogs) < 2 { t.Errorf("expected at least 2 session logs, got %d", len(store.SessionLogs)) } } func TestBashShellFakeUser(t *testing.T) { store := storage.NewMemoryStore() sessID, _ := store.CreateSession(context.Background(), "127.0.0.1", "attacker", "bash") sess := &shell.SessionContext{ SessionID: sessID, Username: "attacker", Store: store, CommonConfig: shell.ShellCommonConfig{ Hostname: "testhost", FakeUser: "admin", }, } commands := "whoami\rexit\r" clientInput := bytes.NewBufferString(commands) var clientOutput bytes.Buffer rw := &rwCloser{ Reader: clientInput, Writer: &clientOutput, } sh := NewBashShell() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() sh.Handle(ctx, sess, rw) output := clientOutput.String() if !strings.Contains(output, "admin") { t.Errorf("output should contain fake user 'admin', got: %s", output) } }