From 06abde6f6f6569983af75e4fba503ca6ce75ba36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torjus=20H=C3=A5kestad?= Date: Wed, 9 Oct 2024 22:21:28 +0200 Subject: [PATCH] Add watchlist command --- nixprstatus/__main__.py | 26 +++++++++++++ nixprstatus/pr.py | 8 ++++ nixprstatus/watchlist.py | 83 ++++++++++++++++++++++++++++++++++++++++ tests/__init__.py | 0 tests/helpers/mocks.py | 32 ++++++++++++++++ tests/test_pr.py | 33 +--------------- tests/test_watchlist.py | 34 ++++++++++++++++ 7 files changed, 184 insertions(+), 32 deletions(-) create mode 100644 nixprstatus/watchlist.py create mode 100644 tests/__init__.py create mode 100644 tests/helpers/mocks.py create mode 100644 tests/test_watchlist.py diff --git a/nixprstatus/__main__.py b/nixprstatus/__main__.py index b3c88fa..1ff7154 100644 --- a/nixprstatus/__main__.py +++ b/nixprstatus/__main__.py @@ -4,8 +4,12 @@ from typing import Annotated from rich.console import Console from nixprstatus.pr import pr_merge_status from nixprstatus.pr import commits_since +from nixprstatus.watchlist import Watchlist +from nixprstatus.watchlist import OutputFormat app = typer.Typer(rich_markup_mode=None) +watchlist_app = typer.Typer() +app.add_typer(watchlist_app, name="watchlist", help="Manage watchlist.") DEFAULT_HEADERS = { "Accept": "application/vnd.github.text+json", @@ -61,6 +65,28 @@ def since( return typer.echo(count) +@watchlist_app.command() +def list(watchlist: str|None = None, format: OutputFormat = OutputFormat.CONSOLE): + """List PRs in watchlist.""" + wl = Watchlist.from_file() + wl.print(format=format) + +@watchlist_app.command() +def add(pr: int): + """Add PR to watchlist.""" + wl = Watchlist.from_file() + wl.add_pr(pr) + wl.to_file() + info = wl.pr(pr) + print(f"Added #{info.pr}: {info.title} to watchlist.") + +@watchlist_app.command() +def remove(pr: int): + """Remove PR from watchlist.""" + wl = Watchlist.from_file() + wl.remove(pr) + wl.to_file() + print(f"Removed #{pr} from watchlist.") def main(): app() diff --git a/nixprstatus/pr.py b/nixprstatus/pr.py index 6eb0bc5..f7988d8 100644 --- a/nixprstatus/pr.py +++ b/nixprstatus/pr.py @@ -33,6 +33,14 @@ def commits_since(first_ref: str, last_ref: str) -> int: return commit_response.json()["behind_by"] +def get_pr(pr: int) -> dict: + url = f"https://api.github.com/repos/NixOS/nixpkgs/pulls/{pr}" + pr_response = requests.get(url, headers=DEFAULT_HEADERS) + pr_response.raise_for_status() + + return pr_response.json() + + def pr_merge_status( pr: int, branches: list[str] = DEFAULT_BRANCHES, check_backport: bool = True ) -> PRStatus: diff --git a/nixprstatus/watchlist.py b/nixprstatus/watchlist.py new file mode 100644 index 0000000..88b504f --- /dev/null +++ b/nixprstatus/watchlist.py @@ -0,0 +1,83 @@ +import json +import os +from enum import Enum +from pathlib import Path +from pydantic import BaseModel +from rich.console import Console + +from nixprstatus.pr import get_pr + +class OutputFormat(str, Enum): + CONSOLE = "console" + JSON = "json" + + +class PRInfo(BaseModel): + pr: int + title: str + + +class Watchlist(BaseModel): + prs: list[PRInfo] + + @classmethod + def from_file(cls, path: str | None = None) -> "Watchlist": + if not path: + path = _default_path() + + p = Path(path).expanduser() + + if not p.exists(): + return cls(prs=[]) + + with open(p, "r") as f: + data = json.load(f) + return cls(**data) + + def to_file(self, path: str | None = None): + if not path: + _ensure_default_path() + path = _default_path() + + p = Path(path).expanduser() + + with open(p, "w") as f: + f.write(self.model_dump_json()) + + def add_pr(self, pr: int): + # Lookup PR info + info = get_pr(pr) + + title = info["title"] + self.prs.append(PRInfo(pr=pr, title=title)) + + def remove(self, pr: int): + self.prs = [p for p in self.prs if p.pr != pr] + + def print(self, format: OutputFormat = OutputFormat.CONSOLE): + match format: + case OutputFormat.CONSOLE: + console = Console() + for pr in self.prs: + console.print(f"{pr.pr}: {pr.title}") + case OutputFormat.JSON: + print(self.model_dump_json()) + case _: + raise ValueError(f"Unknown format: {format}") + + def pr(self, pr: int) -> PRInfo | None: + for p in self.prs: + if p.pr == pr: + return p + return None + + +def _default_path() -> str: + if "XDG_STATE_HOME" in os.environ: + return f"{os.environ['XDG_STATE_HOME']}/nixprstatus/watchlist.json" + return "~/.config/nixprstatus/watchlist.json" + + +def _ensure_default_path() -> str: + p = Path(_default_path()).expanduser() + p.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/helpers/mocks.py b/tests/helpers/mocks.py new file mode 100644 index 0000000..be4b28e --- /dev/null +++ b/tests/helpers/mocks.py @@ -0,0 +1,32 @@ +import requests +import json + + +def mocked_requests_get(*args, **kwargs): + class MockedResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return json.loads(self.json_data) + + def raise_for_status(self): + if self.status_code not in [200, 201]: + raise requests.exceptions.HTTPError() + + if "pulls" in args[0]: + pr = args[0].split("/")[-1] + with open(f"tests/fixtures/pulls_{pr}.json") as f: + data = f.read() + return MockedResponse(data, 200) + elif "compare" in args[0]: + branch, commit_sha = args[0].split("/")[-1].split("...") + with open(f"tests/fixtures/compare_{branch}_{commit_sha}.json") as f: + data = f.read() + return MockedResponse(data, 200) + elif "comments" in args[0]: + pr = args[0].split("/")[-2] + with open(f"tests/fixtures/comments_{pr}.json") as f: + data = f.read() + return MockedResponse(data, 200) diff --git a/tests/test_pr.py b/tests/test_pr.py index 27ab5ce..de44523 100644 --- a/tests/test_pr.py +++ b/tests/test_pr.py @@ -1,39 +1,8 @@ import unittest import unittest.mock -import requests -import json from nixprstatus.pr import commit_in_branch, pr_merge_status, commits_since - - -def mocked_requests_get(*args, **kwargs): - class MockedResponse: - def __init__(self, json_data, status_code): - self.json_data = json_data - self.status_code = status_code - - def json(self): - return json.loads(self.json_data) - - def raise_for_status(self): - if self.status_code not in [200, 201]: - raise requests.exceptions.HTTPError() - - if "pulls" in args[0]: - pr = args[0].split("/")[-1] - with open(f"tests/fixtures/pulls_{pr}.json") as f: - data = f.read() - return MockedResponse(data, 200) - elif "compare" in args[0]: - branch, commit_sha = args[0].split("/")[-1].split("...") - with open(f"tests/fixtures/compare_{branch}_{commit_sha}.json") as f: - data = f.read() - return MockedResponse(data, 200) - elif "comments" in args[0]: - pr = args[0].split("/")[-2] - with open(f"tests/fixtures/comments_{pr}.json") as f: - data = f.read() - return MockedResponse(data, 200) +from tests.helpers.mocks import mocked_requests_get class TestPRMergeStatus(unittest.TestCase): diff --git a/tests/test_watchlist.py b/tests/test_watchlist.py new file mode 100644 index 0000000..d8c2168 --- /dev/null +++ b/tests/test_watchlist.py @@ -0,0 +1,34 @@ +from nixprstatus.watchlist import Watchlist, PRInfo +from tempfile import TemporaryDirectory +import unittest + +from tests.helpers.mocks import mocked_requests_get + + +class TestWatchlist(unittest.TestCase): + def test_save_load(self): + with TemporaryDirectory() as d: + filename = f"{d}/test.json" + + watchlist = Watchlist(prs=[PRInfo(pr=1, title="PR 1")]) + watchlist.to_file(filename) + + # Check that the file was written correctly + with open(filename, "r") as f: + self.assertEqual(watchlist.model_dump_json(), f.read()) + + # Check that the file can be read back + loaded = Watchlist.from_file(filename) + self.assertEqual(watchlist, loaded) + + @unittest.mock.patch("requests.get", side_effect=mocked_requests_get) + def test_add_pr(self, mock_get): + w = Watchlist(prs=[]) + w.add_pr(345583) + self.assertEqual(len(w.prs), 1) + self.assertEqual(w.prs[0].title, "wireshark: 4.2.6 -> 4.2.7") + + def test_get_pr(self): + w = Watchlist(prs=[PRInfo(pr=1, title="PR 1")]) + self.assertEqual(w.pr(1), PRInfo(pr=1, title="PR 1")) + self.assertEqual(w.pr(2), None) \ No newline at end of file