package mcp import ( "context" "encoding/json" "fmt" "path/filepath" "strings" "time" "git.t-juice.club/torjus/labmcp/internal/database" "git.t-juice.club/torjus/labmcp/internal/options" ) // RegisterHandlers registers all tool handlers on the server. func (s *Server) RegisterHandlers(indexer options.Indexer) { s.tools["search_options"] = s.handleSearchOptions s.tools["get_option"] = s.handleGetOption s.tools["get_file"] = s.handleGetFile s.tools["index_revision"] = s.makeIndexHandler(indexer) s.tools["list_revisions"] = s.handleListRevisions s.tools["delete_revision"] = s.handleDeleteRevision } // handleSearchOptions handles the search_options tool. func (s *Server) handleSearchOptions(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { query, _ := args["query"].(string) if query == "" { return ErrorContent(fmt.Errorf("query is required")), nil } revision, _ := args["revision"].(string) rev, err := s.resolveRevision(ctx, revision) if err != nil { return ErrorContent(err), nil } if rev == nil { return ErrorContent(fmt.Errorf("no indexed revision available")), nil } filters := database.SearchFilters{ Limit: 50, } if t, ok := args["type"].(string); ok && t != "" { filters.Type = t } if ns, ok := args["namespace"].(string); ok && ns != "" { filters.Namespace = ns } if limit, ok := args["limit"].(float64); ok && limit > 0 { filters.Limit = int(limit) } options, err := s.store.SearchOptions(ctx, rev.ID, query, filters) if err != nil { return ErrorContent(fmt.Errorf("search failed: %w", err)), nil } // Format results var sb strings.Builder sb.WriteString(fmt.Sprintf("Found %d options matching '%s' in revision %s:\n\n", len(options), query, rev.GitHash[:8])) for _, opt := range options { sb.WriteString(fmt.Sprintf("## %s\n", opt.Name)) sb.WriteString(fmt.Sprintf("Type: %s\n", opt.Type)) if opt.Description != "" { desc := opt.Description if len(desc) > 200 { desc = desc[:200] + "..." 
} sb.WriteString(fmt.Sprintf("Description: %s\n", desc)) } sb.WriteString("\n") } return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } // handleGetOption handles the get_option tool. func (s *Server) handleGetOption(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { name, _ := args["name"].(string) if name == "" { return ErrorContent(fmt.Errorf("name is required")), nil } revision, _ := args["revision"].(string) rev, err := s.resolveRevision(ctx, revision) if err != nil { return ErrorContent(err), nil } if rev == nil { return ErrorContent(fmt.Errorf("no indexed revision available")), nil } option, err := s.store.GetOption(ctx, rev.ID, name) if err != nil { return ErrorContent(fmt.Errorf("failed to get option: %w", err)), nil } if option == nil { return ErrorContent(fmt.Errorf("option '%s' not found", name)), nil } // Get declarations with file metadata declarations, err := s.store.GetDeclarationsWithMetadata(ctx, rev.ID, option.ID) if err != nil { s.logger.Printf("Failed to get declarations: %v", err) } // Format result var sb strings.Builder sb.WriteString(fmt.Sprintf("# %s\n\n", option.Name)) sb.WriteString(fmt.Sprintf("**Type:** %s\n", option.Type)) if option.Description != "" { sb.WriteString(fmt.Sprintf("\n**Description:**\n%s\n", option.Description)) } if option.DefaultValue != "" && option.DefaultValue != "null" { sb.WriteString(fmt.Sprintf("\n**Default:** `%s`\n", formatJSON(option.DefaultValue))) } if option.Example != "" && option.Example != "null" { sb.WriteString(fmt.Sprintf("\n**Example:** `%s`\n", formatJSON(option.Example))) } if option.ReadOnly { sb.WriteString("\n**Read-only:** Yes\n") } if len(declarations) > 0 { sb.WriteString("\n**Declared in:**\n") for _, decl := range declarations { if decl.Line > 0 { sb.WriteString(fmt.Sprintf("- %s:%d", decl.FilePath, decl.Line)) } else { sb.WriteString(fmt.Sprintf("- %s", decl.FilePath)) } // Add file metadata if available if decl.HasFile && decl.ByteSize > 
0 { sb.WriteString(fmt.Sprintf(" (%d bytes, %d lines)", decl.ByteSize, decl.LineCount)) } sb.WriteString("\n") } } // Include children if requested (default: true) includeChildren := true if ic, ok := args["include_children"].(bool); ok { includeChildren = ic } if includeChildren { children, err := s.store.GetChildren(ctx, rev.ID, option.Name) if err != nil { s.logger.Printf("Failed to get children: %v", err) } if len(children) > 0 { sb.WriteString("\n**Sub-options:**\n") for _, child := range children { // Show just the last part of the name shortName := child.Name if strings.HasPrefix(child.Name, option.Name+".") { shortName = child.Name[len(option.Name)+1:] } sb.WriteString(fmt.Sprintf("- `%s` (%s)\n", shortName, child.Type)) } } } return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } // handleGetFile handles the get_file tool. func (s *Server) handleGetFile(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { path, _ := args["path"].(string) if path == "" { return ErrorContent(fmt.Errorf("path is required")), nil } // Security: validate path to prevent traversal attacks // Clean the path and check for dangerous patterns cleanPath := filepath.Clean(path) if filepath.IsAbs(cleanPath) { return ErrorContent(fmt.Errorf("invalid path: absolute paths not allowed")), nil } if strings.HasPrefix(cleanPath, "..") { return ErrorContent(fmt.Errorf("invalid path: directory traversal not allowed")), nil } // Use the cleaned path for lookup path = cleanPath revision, _ := args["revision"].(string) rev, err := s.resolveRevision(ctx, revision) if err != nil { return ErrorContent(err), nil } if rev == nil { return ErrorContent(fmt.Errorf("no indexed revision available")), nil } // Parse range parameters var offset, limit int if o, ok := args["offset"].(float64); ok { offset = int(o) } if l, ok := args["limit"].(float64); ok { limit = int(l) } // Use GetFileWithRange fileRange := database.FileRange{Offset: offset, Limit: limit} 
result, err := s.store.GetFileWithRange(ctx, rev.ID, path, fileRange) if err != nil { return ErrorContent(fmt.Errorf("failed to get file: %w", err)), nil } if result == nil { return ErrorContent(fmt.Errorf("file '%s' not found (files may not be indexed for this revision)", path)), nil } // Format output with range metadata var sb strings.Builder if result.TotalLines > 0 && (result.StartLine > 1 || result.EndLine < result.TotalLines) { sb.WriteString(fmt.Sprintf("Showing lines %d-%d of %d total\n\n", result.StartLine, result.EndLine, result.TotalLines)) } sb.WriteString(fmt.Sprintf("```%s\n%s\n```", strings.TrimPrefix(result.Extension, "."), result.Content)) return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } // makeIndexHandler creates the index_revision handler with the indexer. func (s *Server) makeIndexHandler(indexer options.Indexer) ToolHandler { return func(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { revision, _ := args["revision"].(string) if revision == "" { return ErrorContent(fmt.Errorf("revision is required")), nil } result, err := indexer.IndexRevision(ctx, revision) if err != nil { return ErrorContent(fmt.Errorf("indexing failed: %w", err)), nil } // If already indexed, return early with info if result.AlreadyIndexed { var sb strings.Builder sb.WriteString(fmt.Sprintf("Revision already indexed: %s\n", result.Revision.GitHash)) if result.Revision.ChannelName != "" { sb.WriteString(fmt.Sprintf("Channel: %s\n", result.Revision.ChannelName)) } sb.WriteString(fmt.Sprintf("Options: %d\n", result.OptionCount)) sb.WriteString(fmt.Sprintf("Indexed at: %s\n", result.Revision.IndexedAt.Format("2006-01-02 15:04"))) return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } // Index files by default fileCount, err := indexer.IndexFiles(ctx, result.Revision.ID, result.Revision.GitHash) if err != nil { s.logger.Printf("Warning: file indexing failed: %v", err) } var sb strings.Builder 
sb.WriteString(fmt.Sprintf("Indexed revision: %s\n", result.Revision.GitHash)) if result.Revision.ChannelName != "" { sb.WriteString(fmt.Sprintf("Channel: %s\n", result.Revision.ChannelName)) } sb.WriteString(fmt.Sprintf("Options: %d\n", result.OptionCount)) sb.WriteString(fmt.Sprintf("Files: %d\n", fileCount)) // Handle Duration which may be time.Duration or interface{} if dur, ok := result.Duration.(time.Duration); ok { sb.WriteString(fmt.Sprintf("Duration: %s\n", dur.Round(time.Millisecond))) } return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } } // handleListRevisions handles the list_revisions tool. func (s *Server) handleListRevisions(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { revisions, err := s.store.ListRevisions(ctx) if err != nil { return ErrorContent(fmt.Errorf("failed to list revisions: %w", err)), nil } if len(revisions) == 0 { return CallToolResult{ Content: []Content{TextContent("No revisions indexed. Use index_revision to index a nixpkgs version.")}, }, nil } var sb strings.Builder sb.WriteString(fmt.Sprintf("Indexed revisions (%d):\n\n", len(revisions))) for _, rev := range revisions { sb.WriteString(fmt.Sprintf("- **%s**", rev.GitHash[:12])) if rev.ChannelName != "" { sb.WriteString(fmt.Sprintf(" (%s)", rev.ChannelName)) } sb.WriteString(fmt.Sprintf("\n Options: %d, Indexed: %s\n", rev.OptionCount, rev.IndexedAt.Format("2006-01-02 15:04"))) } return CallToolResult{ Content: []Content{TextContent(sb.String())}, }, nil } // handleDeleteRevision handles the delete_revision tool. 
func (s *Server) handleDeleteRevision(ctx context.Context, args map[string]interface{}) (CallToolResult, error) { revision, _ := args["revision"].(string) if revision == "" { return ErrorContent(fmt.Errorf("revision is required")), nil } rev, err := s.resolveRevision(ctx, revision) if err != nil { return ErrorContent(err), nil } if rev == nil { return ErrorContent(fmt.Errorf("revision '%s' not found", revision)), nil } if err := s.store.DeleteRevision(ctx, rev.ID); err != nil { return ErrorContent(fmt.Errorf("failed to delete revision: %w", err)), nil } return CallToolResult{ Content: []Content{TextContent(fmt.Sprintf("Deleted revision %s", rev.GitHash))}, }, nil } // resolveRevision resolves a revision string to a Revision object. func (s *Server) resolveRevision(ctx context.Context, revision string) (*database.Revision, error) { if revision == "" { // Try to find a default revision using config defaultChannel := s.config.DefaultChannel if defaultChannel == "" { defaultChannel = "nixos-stable" // fallback for backwards compatibility } rev, err := s.store.GetRevisionByChannel(ctx, defaultChannel) if err != nil { return nil, err } if rev != nil { return rev, nil } // Fall back to any available revision revs, err := s.store.ListRevisions(ctx) if err != nil { return nil, err } if len(revs) > 0 { return revs[0], nil } return nil, nil } // Try by git hash first rev, err := s.store.GetRevision(ctx, revision) if err != nil { return nil, err } if rev != nil { return rev, nil } // Try by channel name rev, err = s.store.GetRevisionByChannel(ctx, revision) if err != nil { return nil, err } return rev, nil } // formatJSON formats a JSON string for display, handling compact representation. 
//
// The input is returned unchanged when it is:
//   - empty or "null";
//   - not a single valid JSON value;
//   - a scalar (bool, number, string);
//   - an array or object with at most three elements.
//
// Larger arrays and objects are re-indented. If the result exceeds 500
// bytes, it is cut on a rune boundary and "..." is appended.
func formatJSON(s string) string {
	if s == "" || s == "null" {
		return s
	}

	// Decode with UseNumber so numbers keep their original literal text.
	// Decoding into float64 would rewrite large integers, e.g.
	// 12345678901234567890 would become 1.2345678901234567e+19.
	dec := json.NewDecoder(strings.NewReader(s))
	dec.UseNumber()
	var v interface{}
	if err := dec.Decode(&v); err != nil {
		return s
	}
	// Like json.Unmarshal, treat trailing non-whitespace data as invalid.
	if rest := s[int(dec.InputOffset()):]; strings.Trim(rest, " \t\r\n") != "" {
		return s
	}

	// For simple values, return as-is.
	switch val := v.(type) {
	case bool, json.Number, string:
		return s
	case []interface{}:
		if len(val) <= 3 {
			return s
		}
	case map[string]interface{}:
		if len(val) <= 3 {
			return s
		}
	}

	// For complex values, pretty print. HTML escaping is disabled so that
	// values such as "<nixpkgs>" are not shown as "\u003cnixpkgs\u003e".
	var sb strings.Builder
	enc := json.NewEncoder(&sb)
	enc.SetEscapeHTML(false)
	enc.SetIndent("", " ")
	if err := enc.Encode(v); err != nil {
		return s
	}
	result := strings.TrimSuffix(sb.String(), "\n") // Encode appends a newline

	const maxLen = 500
	if len(result) > maxLen {
		// Cut at the last rune start <= maxLen so the output stays valid
		// UTF-8. Ranging over a string yields rune-start byte offsets.
		cut := 0
		for i := range result {
			if i > maxLen {
				break
			}
			cut = i
		}
		result = result[:cut] + "..."
	}
	return result
}