diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index 9aaf413..6819050 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "path/filepath" "strings" "time" @@ -177,10 +178,17 @@ func (s *Server) handleGetFile(ctx context.Context, args map[string]interface{}) return ErrorContent(fmt.Errorf("path is required")), nil } - // Security: validate path - if strings.Contains(path, "..") { + // Security: validate path to prevent traversal attacks + // Clean the path and check for dangerous patterns + cleanPath := filepath.Clean(path) + if filepath.IsAbs(cleanPath) { + return ErrorContent(fmt.Errorf("invalid path: absolute paths not allowed")), nil + } + if strings.HasPrefix(cleanPath, "..") { return ErrorContent(fmt.Errorf("invalid path: directory traversal not allowed")), nil } + // Use the cleaned path for lookup + path = cleanPath revision, _ := args["revision"].(string) rev, err := s.resolveRevision(ctx, revision) diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 7a49d2a..71cd5bb 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "testing" "git.t-juice.club/torjus/labmcp/internal/database" + "git.t-juice.club/torjus/labmcp/internal/nixos" ) func TestServerInitialize(t *testing.T) { @@ -144,6 +145,91 @@ func TestServerNotification(t *testing.T) { } } +func TestGetFilePathValidation(t *testing.T) { + store := setupTestStore(t) + server := setupTestServer(t, store) + + // Create a test revision and file + ctx := context.Background() + rev := &database.Revision{ + GitHash: "abc123", + OptionCount: 0, + } + if err := store.CreateRevision(ctx, rev); err != nil { + t.Fatalf("Failed to create revision: %v", err) + } + + file := &database.File{ + RevisionID: rev.ID, + FilePath: "nixos/modules/test.nix", + Extension: ".nix", + Content: "{ }", + } + if err := store.CreateFile(ctx, file); err != nil { + t.Fatalf("Failed to create file: %v", err) + } + + tests := []struct { + name string + path string + wantError bool + errorMsg string + }{ + // Valid paths + {"valid relative path", "nixos/modules/test.nix", false, ""}, + {"valid simple path", "test.nix", false, ""}, + + // Path traversal attempts + {"dotdot traversal", "../etc/passwd", true, "directory traversal"}, + {"dotdot in middle", "nixos/../../../etc/passwd", true, "directory traversal"}, + {"multiple dotdot", "../../etc/passwd", true, "directory traversal"}, + + // Absolute paths + {"absolute unix path", "/etc/passwd", true, "absolute paths"}, + + // Cleaned paths that become traversal + {"dot slash dotdot", "./../../etc/passwd", true, "directory traversal"}, + + // Paths that clean to valid (no error expected, but file won't exist) + {"dotdot at end cleans to valid", "nixos/modules/..", false, ""}, // Cleans to "nixos", which is safe + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := `{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"get_file","arguments":{"path":"` + tt.path + `","revision":"abc123"}}}` + resp := runRequest(t, server, input) + + result, ok := resp.Result.(map[string]interface{}) + if !ok { + t.Fatalf("Expected map result, got %T", resp.Result) + } + + isError, _ := result["isError"].(bool) + + if tt.wantError { + if !isError { + t.Errorf("Expected error for path %q, got success", tt.path) + } else { + content := result["content"].([]interface{}) + text := content[0].(map[string]interface{})["text"].(string) + if tt.errorMsg != "" && !strings.Contains(text, tt.errorMsg) { + t.Errorf("Error message %q doesn't contain %q", text, tt.errorMsg) + } + } + } else { + // For valid paths that don't exist, we expect a "not found" error, not a security error + if isError { + content := result["content"].([]interface{}) + text := content[0].(map[string]interface{})["text"].(string) + if strings.Contains(text, "traversal") || strings.Contains(text, "absolute") { + t.Errorf("Got security error for valid path %q: %s", tt.path, text) + } + } + } + }) + } +} + // Helper functions func setupTestStore(t *testing.T) database.Store { @@ -165,6 +251,16 @@ func setupTestStore(t *testing.T) database.Store { return store } +func setupTestServer(t *testing.T, store database.Store) *Server { + t.Helper() + + server := NewServer(store, nil) + indexer := nixos.NewIndexer(store) + server.RegisterHandlers(indexer) + + return server +} + func runRequest(t *testing.T, server *Server, input string) *Response { t.Helper()