Add watchlist command
This commit is contained in:
parent
99e0887505
commit
06abde6f6f
@ -4,8 +4,12 @@ from typing import Annotated
|
||||
from rich.console import Console
|
||||
from nixprstatus.pr import pr_merge_status
|
||||
from nixprstatus.pr import commits_since
|
||||
from nixprstatus.watchlist import Watchlist
|
||||
from nixprstatus.watchlist import OutputFormat
|
||||
|
||||
app = typer.Typer(rich_markup_mode=None)
|
||||
watchlist_app = typer.Typer()
|
||||
app.add_typer(watchlist_app, name="watchlist", help="Manage watchlist.")
|
||||
|
||||
DEFAULT_HEADERS = {
|
||||
"Accept": "application/vnd.github.text+json",
|
||||
@ -61,6 +65,28 @@ def since(
|
||||
return
|
||||
typer.echo(count)
|
||||
|
||||
@watchlist_app.command()
|
||||
def list(watchlist: str|None = None, format: OutputFormat = OutputFormat.CONSOLE):
|
||||
"""List PRs in watchlist."""
|
||||
wl = Watchlist.from_file()
|
||||
wl.print(format=format)
|
||||
|
||||
@watchlist_app.command()
|
||||
def add(pr: int):
|
||||
"""Add PR to watchlist."""
|
||||
wl = Watchlist.from_file()
|
||||
wl.add_pr(pr)
|
||||
wl.to_file()
|
||||
info = wl.pr(pr)
|
||||
print(f"Added #{info.pr}: {info.title} to watchlist.")
|
||||
|
||||
@watchlist_app.command()
|
||||
def remove(pr: int):
|
||||
"""Remove PR from watchlist."""
|
||||
wl = Watchlist.from_file()
|
||||
wl.remove(pr)
|
||||
wl.to_file()
|
||||
print(f"Removed #{pr} from watchlist.")
|
||||
|
||||
def main():
|
||||
app()
|
||||
|
@ -33,6 +33,14 @@ def commits_since(first_ref: str, last_ref: str) -> int:
|
||||
return commit_response.json()["behind_by"]
|
||||
|
||||
|
||||
def get_pr(pr: int) -> dict:
|
||||
url = f"https://api.github.com/repos/NixOS/nixpkgs/pulls/{pr}"
|
||||
pr_response = requests.get(url, headers=DEFAULT_HEADERS)
|
||||
pr_response.raise_for_status()
|
||||
|
||||
return pr_response.json()
|
||||
|
||||
|
||||
def pr_merge_status(
|
||||
pr: int, branches: list[str] = DEFAULT_BRANCHES, check_backport: bool = True
|
||||
) -> PRStatus:
|
||||
|
83
nixprstatus/watchlist.py
Normal file
83
nixprstatus/watchlist.py
Normal file
@ -0,0 +1,83 @@
|
||||
import json
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel
|
||||
from rich.console import Console
|
||||
|
||||
from nixprstatus.pr import get_pr
|
||||
|
||||
class OutputFormat(str, Enum):
|
||||
CONSOLE = "console"
|
||||
JSON = "json"
|
||||
|
||||
|
||||
class PRInfo(BaseModel):
|
||||
pr: int
|
||||
title: str
|
||||
|
||||
|
||||
class Watchlist(BaseModel):
|
||||
prs: list[PRInfo]
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: str | None = None) -> "Watchlist":
|
||||
if not path:
|
||||
path = _default_path()
|
||||
|
||||
p = Path(path).expanduser()
|
||||
|
||||
if not p.exists():
|
||||
return cls(prs=[])
|
||||
|
||||
with open(p, "r") as f:
|
||||
data = json.load(f)
|
||||
return cls(**data)
|
||||
|
||||
def to_file(self, path: str | None = None):
|
||||
if not path:
|
||||
_ensure_default_path()
|
||||
path = _default_path()
|
||||
|
||||
p = Path(path).expanduser()
|
||||
|
||||
with open(p, "w") as f:
|
||||
f.write(self.model_dump_json())
|
||||
|
||||
def add_pr(self, pr: int):
|
||||
# Lookup PR info
|
||||
info = get_pr(pr)
|
||||
|
||||
title = info["title"]
|
||||
self.prs.append(PRInfo(pr=pr, title=title))
|
||||
|
||||
def remove(self, pr: int):
|
||||
self.prs = [p for p in self.prs if p.pr != pr]
|
||||
|
||||
def print(self, format: OutputFormat = OutputFormat.CONSOLE):
|
||||
match format:
|
||||
case OutputFormat.CONSOLE:
|
||||
console = Console()
|
||||
for pr in self.prs:
|
||||
console.print(f"{pr.pr}: {pr.title}")
|
||||
case OutputFormat.JSON:
|
||||
print(self.model_dump_json())
|
||||
case _:
|
||||
raise ValueError(f"Unknown format: {format}")
|
||||
|
||||
def pr(self, pr: int) -> PRInfo | None:
|
||||
for p in self.prs:
|
||||
if p.pr == pr:
|
||||
return p
|
||||
return None
|
||||
|
||||
|
||||
def _default_path() -> str:
|
||||
if "XDG_STATE_HOME" in os.environ:
|
||||
return f"{os.environ['XDG_STATE_HOME']}/nixprstatus/watchlist.json"
|
||||
return "~/.config/nixprstatus/watchlist.json"
|
||||
|
||||
|
||||
def _ensure_default_path() -> str:
|
||||
p = Path(_default_path()).expanduser()
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
32
tests/helpers/mocks.py
Normal file
32
tests/helpers/mocks.py
Normal file
@ -0,0 +1,32 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
|
||||
def mocked_requests_get(*args, **kwargs):
|
||||
class MockedResponse:
|
||||
def __init__(self, json_data, status_code):
|
||||
self.json_data = json_data
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
return json.loads(self.json_data)
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code not in [200, 201]:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
if "pulls" in args[0]:
|
||||
pr = args[0].split("/")[-1]
|
||||
with open(f"tests/fixtures/pulls_{pr}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
||||
elif "compare" in args[0]:
|
||||
branch, commit_sha = args[0].split("/")[-1].split("...")
|
||||
with open(f"tests/fixtures/compare_{branch}_{commit_sha}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
||||
elif "comments" in args[0]:
|
||||
pr = args[0].split("/")[-2]
|
||||
with open(f"tests/fixtures/comments_{pr}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
@ -1,39 +1,8 @@
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import requests
|
||||
import json
|
||||
|
||||
from nixprstatus.pr import commit_in_branch, pr_merge_status, commits_since
|
||||
|
||||
|
||||
def mocked_requests_get(*args, **kwargs):
|
||||
class MockedResponse:
|
||||
def __init__(self, json_data, status_code):
|
||||
self.json_data = json_data
|
||||
self.status_code = status_code
|
||||
|
||||
def json(self):
|
||||
return json.loads(self.json_data)
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code not in [200, 201]:
|
||||
raise requests.exceptions.HTTPError()
|
||||
|
||||
if "pulls" in args[0]:
|
||||
pr = args[0].split("/")[-1]
|
||||
with open(f"tests/fixtures/pulls_{pr}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
||||
elif "compare" in args[0]:
|
||||
branch, commit_sha = args[0].split("/")[-1].split("...")
|
||||
with open(f"tests/fixtures/compare_{branch}_{commit_sha}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
||||
elif "comments" in args[0]:
|
||||
pr = args[0].split("/")[-2]
|
||||
with open(f"tests/fixtures/comments_{pr}.json") as f:
|
||||
data = f.read()
|
||||
return MockedResponse(data, 200)
|
||||
from tests.helpers.mocks import mocked_requests_get
|
||||
|
||||
|
||||
class TestPRMergeStatus(unittest.TestCase):
|
||||
|
34
tests/test_watchlist.py
Normal file
34
tests/test_watchlist.py
Normal file
@ -0,0 +1,34 @@
|
||||
from nixprstatus.watchlist import Watchlist, PRInfo
|
||||
from tempfile import TemporaryDirectory
|
||||
import unittest
|
||||
|
||||
from tests.helpers.mocks import mocked_requests_get
|
||||
|
||||
|
||||
class TestWatchlist(unittest.TestCase):
|
||||
def test_save_load(self):
|
||||
with TemporaryDirectory() as d:
|
||||
filename = f"{d}/test.json"
|
||||
|
||||
watchlist = Watchlist(prs=[PRInfo(pr=1, title="PR 1")])
|
||||
watchlist.to_file(filename)
|
||||
|
||||
# Check that the file was written correctly
|
||||
with open(filename, "r") as f:
|
||||
self.assertEqual(watchlist.model_dump_json(), f.read())
|
||||
|
||||
# Check that the file can be read back
|
||||
loaded = Watchlist.from_file(filename)
|
||||
self.assertEqual(watchlist, loaded)
|
||||
|
||||
@unittest.mock.patch("requests.get", side_effect=mocked_requests_get)
|
||||
def test_add_pr(self, mock_get):
|
||||
w = Watchlist(prs=[])
|
||||
w.add_pr(345583)
|
||||
self.assertEqual(len(w.prs), 1)
|
||||
self.assertEqual(w.prs[0].title, "wireshark: 4.2.6 -> 4.2.7")
|
||||
|
||||
def test_get_pr(self):
|
||||
w = Watchlist(prs=[PRInfo(pr=1, title="PR 1")])
|
||||
self.assertEqual(w.pr(1), PRInfo(pr=1, title="PR 1"))
|
||||
self.assertEqual(w.pr(2), None)
|
Loading…
Reference in New Issue
Block a user