diff --git a/nixprstatus/__main__.py b/nixprstatus/__main__.py index 8a1a347..5cde8f0 100644 --- a/nixprstatus/__main__.py +++ b/nixprstatus/__main__.py @@ -1,7 +1,6 @@ import typer import json from typing import Annotated -from rich.console import Console from nixprstatus.pr import pr_merge_status from nixprstatus.pr import commits_since from nixprstatus.watchlist import Watchlist @@ -24,26 +23,17 @@ def pr( branches: Annotated[ list[str] | None, typer.Option(help="Check specific branch") ] = None, + format: Annotated[ + OutputFormat, typer.Option(help="Output format") + ] = OutputFormat.CONSOLE, ): """Get merge status of pull request.""" - console = Console() - if branches: status = pr_merge_status(pr, branches) else: status = pr_merge_status(pr) - console.print(f"{status.title}\n", highlight=False) - merged = ":white_check_mark: merged" if status.merged else ":x: merged" - console.print(merged, highlight=False) - - for branch in status.branches: - output = ( - f":white_check_mark: {branch}" - if status.branches[branch] - else f":x: {branch}" - ) - console.print(output, highlight=False) + status.print(format=format) @app.command() @@ -69,17 +59,16 @@ def since( @watchlist_app.command() def list(watchlist: str | None = None, format: OutputFormat = OutputFormat.CONSOLE): """List PRs in watchlist.""" - wl = Watchlist.from_file() + wl = Watchlist.from_file(path=watchlist) wl.print(format=format) @watchlist_app.command() -def add(pr: int): +def add(pr: int, watchlist: str | None = None): """Add PR to watchlist.""" - wl = Watchlist.from_file() - wl.add_pr(pr) - wl.to_file() - info = wl.pr(pr) + wl = Watchlist.from_file(path=watchlist) + info = wl.add_pr(pr) + wl.to_file(path=watchlist) print(f"Added #{info.pr}: {info.title} to watchlist.") @@ -87,6 +76,9 @@ def add(pr: int): def remove(pr: int): """Remove PR from watchlist.""" wl = Watchlist.from_file() + if pr not in wl: + print(f"#{pr} not in watchlist.") + return wl.remove(pr) wl.to_file() print(f"Removed #{pr} from watchlist.") diff --git a/nixprstatus/output.py b/nixprstatus/output.py new file mode 100644 index 0000000..949b4e4 --- /dev/null +++ b/nixprstatus/output.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class OutputFormat(str, Enum): + CONSOLE = "console" + JSON = "json" diff --git a/nixprstatus/pr.py b/nixprstatus/pr.py index f7988d8..1211d57 100644 --- a/nixprstatus/pr.py +++ b/nixprstatus/pr.py @@ -1,5 +1,8 @@ import requests from pydantic import BaseModel +from rich.console import Console + +from nixprstatus.output import OutputFormat DEFAULT_HEADERS = { "Accept": "application/vnd.github.text+json", @@ -13,6 +16,26 @@ class PRStatus(BaseModel): merged: bool branches: dict[str, bool] + def print(self, format: OutputFormat = OutputFormat.CONSOLE): + match format: + case OutputFormat.JSON: + print(self.model_dump_json()) + case OutputFormat.CONSOLE: + console = Console(highlight=False) + console.print(f"{self.title}\n") + merged = ":white_check_mark: merged" if self.merged else ":x: merged" + console.print(merged) + + for branch in self.branches: + output = ( + f":white_check_mark: {branch}" + if self.branches[branch] + else f":x: {branch}" + ) + console.print(output) + case _: + raise ValueError(f"Unknown format: {format}") + def commit_in_branch(commit_sha: str, branch: str) -> bool: url = f"https://api.github.com/repos/NixOS/nixpkgs/compare/{branch}...{commit_sha}" diff --git a/nixprstatus/watchlist.py b/nixprstatus/watchlist.py index 31a175a..facf029 100644 --- a/nixprstatus/watchlist.py +++ b/nixprstatus/watchlist.py @@ -1,16 +1,11 @@ import json import os -from enum import Enum from pathlib import Path from pydantic import BaseModel from rich.console import Console from nixprstatus.pr import get_pr - - -class OutputFormat(str, Enum): - CONSOLE = "console" - JSON = "json" +from nixprstatus.output import OutputFormat class PRInfo(BaseModel): @@ -45,12 +40,14 @@ class Watchlist(BaseModel): with open(p, "w") as f: f.write(self.model_dump_json()) - def add_pr(self, pr: int): + def add_pr(self, pr: int) -> PRInfo: # Lookup PR info info = get_pr(pr) title = info["title"] - self.prs.append(PRInfo(pr=pr, title=title)) + info = PRInfo(pr=pr, title=title) + self.prs.append(info) + return info def remove(self, pr: int): self.prs = [p for p in self.prs if p.pr != pr] @@ -72,6 +69,13 @@ class Watchlist(BaseModel): return p return None + def __contains__(self, item: PRInfo | int): + match item: + case PRInfo(): + return any([x == item for x in self.prs]) + case int(): + return any([x.pr == item for x in self.prs]) + def _default_path() -> str: if "XDG_STATE_HOME" in os.environ: @@ -79,6 +83,6 @@ def _default_path() -> str: return "~/.config/nixprstatus/watchlist.json" -def _ensure_default_path() -> str: +def _ensure_default_path(): p = Path(_default_path()).expanduser() p.parent.mkdir(parents=True, exist_ok=True) diff --git a/pyproject.toml b/pyproject.toml index e8c90f9..cacab20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nixprstatus" -version = "0.1.4" +version = "0.1.5" description = "Nixpkgs PR status checker" authors = ["Torjus HÃ¥kestad "] license = "MIT" diff --git a/tests/test_watchlist.py b/tests/test_watchlist.py index b4db857..27672bb 100644 --- a/tests/test_watchlist.py +++ b/tests/test_watchlist.py @@ -32,3 +32,9 @@ class TestWatchlist(unittest.TestCase): w = Watchlist(prs=[PRInfo(pr=1, title="PR 1")]) self.assertEqual(w.pr(1), PRInfo(pr=1, title="PR 1")) self.assertEqual(w.pr(2), None) + + def test_contains(self): + w = Watchlist(prs=[PRInfo(pr=1, title="PR 1")]) + self.assertIn(PRInfo(pr=1, title="PR 1"), w) + self.assertIn(1, w) + self.assertNotIn(2, w)