fix: some lint reported by Ruff
This commit is contained in:
parent
e9a58ed40e
commit
8fc594b947
11 changed files with 73 additions and 49 deletions
|
|
@ -1,4 +1,5 @@
|
||||||
[project]
|
[project]
|
||||||
|
name = "unwind"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
|
|
|
||||||
|
|
@ -164,14 +164,14 @@ async def test_find_ratings(conn: db.Connection):
|
||||||
u1 = models.User(
|
u1 = models.User(
|
||||||
imdb_id="u00001",
|
imdb_id="u00001",
|
||||||
name="User1",
|
name="User1",
|
||||||
secret="secret1",
|
secret="secret1", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u1)
|
await db.add(conn, u1)
|
||||||
|
|
||||||
u2 = models.User(
|
u2 = models.User(
|
||||||
imdb_id="u00002",
|
imdb_id="u00002",
|
||||||
name="User2",
|
name="User2",
|
||||||
secret="secret2",
|
secret="secret2", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u2)
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
|
@ -271,14 +271,14 @@ async def test_ratings_for_movies(conn: db.Connection):
|
||||||
u1 = models.User(
|
u1 = models.User(
|
||||||
imdb_id="u00001",
|
imdb_id="u00001",
|
||||||
name="User1",
|
name="User1",
|
||||||
secret="secret1",
|
secret="secret1", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u1)
|
await db.add(conn, u1)
|
||||||
|
|
||||||
u2 = models.User(
|
u2 = models.User(
|
||||||
imdb_id="u00002",
|
imdb_id="u00002",
|
||||||
name="User2",
|
name="User2",
|
||||||
secret="secret2",
|
secret="secret2", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u2)
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
|
@ -296,7 +296,7 @@ async def test_ratings_for_movies(conn: db.Connection):
|
||||||
|
|
||||||
movie_ids = [m1.id]
|
movie_ids = [m1.id]
|
||||||
user_ids = []
|
user_ids = []
|
||||||
assert tuple() == tuple(
|
assert () == tuple(
|
||||||
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -308,7 +308,7 @@ async def test_ratings_for_movies(conn: db.Connection):
|
||||||
|
|
||||||
movie_ids = [m2.id]
|
movie_ids = [m2.id]
|
||||||
user_ids = [u2.id]
|
user_ids = [u2.id]
|
||||||
assert tuple() == tuple(
|
assert () == tuple(
|
||||||
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
await db.ratings_for_movies(conn, movie_ids=movie_ids, user_ids=user_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -336,14 +336,14 @@ async def test_find_movies(conn: db.Connection):
|
||||||
u1 = models.User(
|
u1 = models.User(
|
||||||
imdb_id="u00001",
|
imdb_id="u00001",
|
||||||
name="User1",
|
name="User1",
|
||||||
secret="secret1",
|
secret="secret1", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u1)
|
await db.add(conn, u1)
|
||||||
|
|
||||||
u2 = models.User(
|
u2 = models.User(
|
||||||
imdb_id="u00002",
|
imdb_id="u00002",
|
||||||
name="User2",
|
name="User2",
|
||||||
secret="secret2",
|
secret="secret2", # noqa: S106
|
||||||
)
|
)
|
||||||
await db.add(conn, u2)
|
await db.add(conn, u2)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
@ -24,7 +24,7 @@ def authorized_client() -> TestClient:
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def admin_client() -> TestClient:
|
def admin_client() -> TestClient:
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
for token in config.api_credentials.values():
|
for token in config.api_credentials.values(): # noqa: B007
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("No bearer tokens configured.")
|
raise RuntimeError("No bearer tokens configured.")
|
||||||
|
|
@ -39,7 +39,7 @@ async def test_get_ratings_for_group(
|
||||||
user = models.User(
|
user = models.User(
|
||||||
imdb_id="ur12345678",
|
imdb_id="ur12345678",
|
||||||
name="user-1",
|
name="user-1",
|
||||||
secret="secret-1",
|
secret="secret-1", # noqa: S106
|
||||||
groups=[],
|
groups=[],
|
||||||
)
|
)
|
||||||
group = models.Group(
|
group = models.Group(
|
||||||
|
|
@ -69,7 +69,7 @@ async def test_get_ratings_for_group(
|
||||||
await db.add(conn, movie)
|
await db.add(conn, movie)
|
||||||
|
|
||||||
rating = models.Rating(
|
rating = models.Rating(
|
||||||
movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now()
|
movie_id=movie.id, user_id=user.id, score=66, rating_date=datetime.now(tz=UTC)
|
||||||
)
|
)
|
||||||
await db.add(conn, rating)
|
await db.add(conn, rating)
|
||||||
|
|
||||||
|
|
@ -190,7 +190,7 @@ async def test_list_users(
|
||||||
m = models.User(
|
m = models.User(
|
||||||
imdb_id="ur12345678",
|
imdb_id="ur12345678",
|
||||||
name="user-1",
|
name="user-1",
|
||||||
secret="secret-1",
|
secret="secret-1", # noqa: S106
|
||||||
groups=[],
|
groups=[],
|
||||||
)
|
)
|
||||||
await db.add(conn, m)
|
await db.add(conn, m)
|
||||||
|
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
from .web import create_app
|
from .web import create_app as create_app
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,7 @@ def main():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
args = getargs()
|
args = getargs()
|
||||||
except:
|
except Exception:
|
||||||
return
|
return
|
||||||
|
|
||||||
if args.mode == "load-user-ratings-from-imdb":
|
if args.mode == "load-user-ratings-from-imdb":
|
||||||
|
|
|
||||||
|
|
@ -618,7 +618,7 @@ async def find_movies(
|
||||||
limit_rows: int = 10,
|
limit_rows: int = 10,
|
||||||
skip_rows: int = 0,
|
skip_rows: int = 0,
|
||||||
include_unrated: bool = False,
|
include_unrated: bool = False,
|
||||||
user_ids: list[ULID] = [],
|
user_ids: list[ULID] | None = None,
|
||||||
) -> Iterable[tuple[Movie, list[Rating]]]:
|
) -> Iterable[tuple[Movie, list[Rating]]]:
|
||||||
conditions = []
|
conditions = []
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -105,10 +105,10 @@ find_runtime = re.compile(r"((?P<h>\d+) hr)? ?((?P<m>\d+) min)?").fullmatch
|
||||||
find_runtime_2 = re.compile(r"((?P<h>\d+)h )?((?P<m>\d+)m)?").fullmatch
|
find_runtime_2 = re.compile(r"((?P<h>\d+)h )?((?P<m>\d+)m)?").fullmatch
|
||||||
# find_year: e.g. "(1992)"
|
# find_year: e.g. "(1992)"
|
||||||
find_year = re.compile(
|
find_year = re.compile(
|
||||||
r"(\([IVX]+\) )?\((?P<year>\d{4})(–( |\d{4})| (?P<type>[^)]+))?\)"
|
r"(\([IVX]+\) )?\((?P<year>\d{4})(–( |\d{4})| (?P<type>[^)]+))?\)" # noqa: RUF001
|
||||||
).fullmatch
|
).fullmatch
|
||||||
# find_year_2: e.g. "2024", "1971–2003", "2024–"
|
# find_year_2: e.g. "2024", "1971–2003", "2024–" # noqa: RUF003
|
||||||
find_year_2 = re.compile(r"(?P<year>\d{4})(–(?P<end_year>\d{4})?)?").fullmatch
|
find_year_2 = re.compile(r"(?P<year>\d{4})(–(?P<end_year>\d{4})?)?").fullmatch # noqa: RUF001
|
||||||
find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
|
find_movie_id = re.compile(r"/title/(?P<id>tt\d+)/").search
|
||||||
find_movie_name = re.compile(r"\d+\. (?P<name>.+)").fullmatch
|
find_movie_name = re.compile(r"\d+\. (?P<name>.+)").fullmatch
|
||||||
# find_vote_count: e.g. "(5.9K)", "(1K)", "(8)"
|
# find_vote_count: e.g. "(5.9K)", "(1K)", "(8)"
|
||||||
|
|
@ -129,7 +129,7 @@ def _movie_and_rating_from_item_legacy(item: bs4.Tag) -> tuple[Movie, Rating]:
|
||||||
genres = (genre := item.find("span", "genre")) and genre.string or ""
|
genres = (genre := item.find("span", "genre")) and genre.string or ""
|
||||||
movie = Movie(
|
movie = Movie(
|
||||||
title=item.h3.a.string.strip(),
|
title=item.h3.a.string.strip(),
|
||||||
genres=set(s.strip() for s in genres.split(",")),
|
genres={s.strip() for s in genres.split(",")},
|
||||||
)
|
)
|
||||||
|
|
||||||
episode_br = item.h3.br
|
episode_br = item.h3.br
|
||||||
|
|
|
||||||
|
|
@ -348,9 +348,9 @@ class Movie:
|
||||||
if not self._is_lazy:
|
if not self._is_lazy:
|
||||||
return
|
return
|
||||||
|
|
||||||
for field in fields(Movie):
|
for f in fields(Movie):
|
||||||
if getattr(self, field.name) is None and callable(field.default_factory):
|
if getattr(self, f.name) is None and callable(f.default_factory):
|
||||||
setattr(self, field.name, field.default_factory())
|
setattr(self, f.name, f.default_factory())
|
||||||
|
|
||||||
self._is_lazy = False
|
self._is_lazy = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ def _throttle(
|
||||||
calls: deque[float] = deque(maxlen=times)
|
calls: deque[float] = deque(maxlen=times)
|
||||||
|
|
||||||
if jitter is None:
|
if jitter is None:
|
||||||
jitter = lambda: 0.0
|
jitter = lambda: 0.0 # noqa: E731
|
||||||
|
|
||||||
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
|
|
@ -125,12 +125,12 @@ def cache_path(req) -> Path | None:
|
||||||
if not config.cachedir:
|
if not config.cachedir:
|
||||||
return
|
return
|
||||||
sig = repr(req.url) # + repr(sorted(req.headers.items()))
|
sig = repr(req.url) # + repr(sorted(req.headers.items()))
|
||||||
return config.cachedir / md5(sig.encode()).hexdigest()
|
return config.cachedir / md5(sig.encode()).hexdigest() # noqa: S324
|
||||||
|
|
||||||
|
|
||||||
@_throttle(1, 1, random)
|
@_throttle(1, 1, random)
|
||||||
async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
|
async def _ahttp_get(s: _ASession_T, url: str, *args, **kwds) -> _Response_T:
|
||||||
req = s.build_request(method="GET", url=url, *args, **kwds)
|
req = s.build_request(*args, method="GET", url=url, **kwds)
|
||||||
|
|
||||||
cachefile = cache_path(req) if config.debug else None
|
cachefile = cache_path(req) if config.debug else None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Literal
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
|
||||||
def b64encode(b: bytes) -> str:
|
def b64encode(b: bytes) -> str:
|
||||||
|
|
@ -16,11 +16,21 @@ def b64padded(s: str) -> str:
|
||||||
return s + "=" * (4 - len(s) % 4)
|
return s + "=" * (4 - len(s) % 4)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_params(params: dict[str, Any]) -> str:
|
||||||
|
return ",".join(f"{k}={v}" for k, v in params.items())
|
||||||
|
|
||||||
|
|
||||||
|
class _PhcScryptParams(TypedDict, total=False):
|
||||||
|
n: int
|
||||||
|
r: int
|
||||||
|
p: int
|
||||||
|
|
||||||
|
|
||||||
def phc_scrypt(
|
def phc_scrypt(
|
||||||
secret: bytes,
|
secret: bytes,
|
||||||
*,
|
*,
|
||||||
salt: bytes | None = None,
|
salt: bytes | None = None,
|
||||||
params: dict[Literal["n", "r", "p"], int] = {},
|
params: _PhcScryptParams = {}, # noqa: B006
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return the scrypt expanded secret in PHC string format.
|
"""Return the scrypt expanded secret in PHC string format.
|
||||||
|
|
||||||
|
|
@ -39,10 +49,14 @@ def phc_scrypt(
|
||||||
# maxmem = 2 * 128 * n * r * p
|
# maxmem = 2 * 128 * n * r * p
|
||||||
hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p)
|
hashed_secret = hashlib.scrypt(secret, salt=salt, n=n, r=r, p=p)
|
||||||
|
|
||||||
encoded_params = ",".join(f"{k}={v}" for k, v in {"n": n, "r": r, "p": p}.items())
|
|
||||||
phc = "".join(
|
phc = "".join(
|
||||||
f"${x}"
|
f"${x}"
|
||||||
for x in ["scrypt", encoded_params, b64encode(salt), b64encode(hashed_secret)]
|
for x in [
|
||||||
|
"scrypt",
|
||||||
|
_encode_params({"n": n, "r": r, "p": p}),
|
||||||
|
b64encode(salt),
|
||||||
|
b64encode(hashed_secret),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return phc
|
return phc
|
||||||
|
|
@ -54,19 +68,27 @@ def phc_compare(*, secret: str, phc_string: str) -> bool:
|
||||||
if args["id"] != "scrypt":
|
if args["id"] != "scrypt":
|
||||||
raise ValueError(f"Algorithm not supported: {args['id']}")
|
raise ValueError(f"Algorithm not supported: {args['id']}")
|
||||||
|
|
||||||
assert type(args["params"]) is dict
|
|
||||||
encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"])
|
encoded = phc_scrypt(b64decode(secret), salt=args["salt"], params=args["params"])
|
||||||
|
|
||||||
return secrets.compare_digest(encoded, phc_string)
|
return secrets.compare_digest(encoded, phc_string)
|
||||||
|
|
||||||
|
|
||||||
def parse_phc(s: str):
|
class _PhcParts(TypedDict):
|
||||||
parts = dict.fromkeys(["id", "version", "params", "salt", "hash"])
|
# $<id>[$v=<version>][$<param>=<value>(,<param>=<value>)*][$<salt>[$<hash>]]
|
||||||
|
id: str # the symbolic name for the function
|
||||||
|
version: int | None # the algorithm version
|
||||||
|
params: dict[str, int]
|
||||||
|
salt: bytes | None
|
||||||
|
hash: bytes | None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_phc(s: str) -> _PhcParts:
|
||||||
|
parts = _PhcParts(id="", version=None, params={}, salt=None, hash=None)
|
||||||
|
|
||||||
_, parts["id"], *rest = s.split("$")
|
_, parts["id"], *rest = s.split("$")
|
||||||
|
|
||||||
if rest and rest[0].startswith("v="):
|
if rest and rest[0].startswith("v="):
|
||||||
parts["version"] = rest.pop(0)
|
parts["version"] = int(rest.pop(0))
|
||||||
if rest and "=" in rest[0]:
|
if rest and "=" in rest[0]:
|
||||||
parts["params"] = {
|
parts["params"] = {
|
||||||
kv[0]: int(kv[1])
|
kv[0]: int(kv[1])
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from starlette.middleware import Middleware
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.middleware.gzip import GZipMiddleware
|
from starlette.middleware.gzip import GZipMiddleware
|
||||||
|
from starlette.requests import HTTPConnection
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Mount, Route
|
||||||
|
|
||||||
|
|
@ -47,17 +48,17 @@ class BearerAuthBackend(AuthenticationBackend):
|
||||||
def __init__(self, credentials: dict[str, str]):
|
def __init__(self, credentials: dict[str, str]):
|
||||||
self.admin_tokens = {v: k for k, v in credentials.items()}
|
self.admin_tokens = {v: k for k, v in credentials.items()}
|
||||||
|
|
||||||
async def authenticate(self, request):
|
async def authenticate(self, conn: HTTPConnection):
|
||||||
if "Authorization" not in request.headers:
|
if "Authorization" not in conn.headers:
|
||||||
return
|
return
|
||||||
|
|
||||||
# XXX should we remove the auth header after reading, for security reasons?
|
# XXX should we remove the auth header after reading, for security reasons?
|
||||||
|
|
||||||
auth = request.headers["Authorization"]
|
auth = conn.headers["Authorization"]
|
||||||
try:
|
try:
|
||||||
scheme, credentials = auth.split()
|
scheme, credentials = auth.split()
|
||||||
except ValueError:
|
except ValueError as err:
|
||||||
raise AuthenticationError("Invalid auth credentials")
|
raise AuthenticationError("Invalid auth credentials") from err
|
||||||
|
|
||||||
roles = []
|
roles = []
|
||||||
|
|
||||||
|
|
@ -72,8 +73,8 @@ class BearerAuthBackend(AuthenticationBackend):
|
||||||
elif scheme.lower() == "basic":
|
elif scheme.lower() == "basic":
|
||||||
try:
|
try:
|
||||||
name, secret = b64decode(credentials).decode().split(":")
|
name, secret = b64decode(credentials).decode().split(":")
|
||||||
except:
|
except Exception as err:
|
||||||
raise AuthenticationError("Invalid auth credentials")
|
raise AuthenticationError("Invalid auth credentials") from err
|
||||||
user = AuthedUser(name, secret)
|
user = AuthedUser(name, secret)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -113,7 +114,7 @@ def as_int(
|
||||||
return max
|
return max
|
||||||
return x
|
return x
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
if default is None:
|
if default is None:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
@ -127,8 +128,8 @@ def as_ulid(s: str) -> ULID:
|
||||||
|
|
||||||
return ULID(s)
|
return ULID(s)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError as err:
|
||||||
raise HTTPException(422, "Not a valid ULID.")
|
raise HTTPException(422, "Not a valid ULID.") from err
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|
@ -146,8 +147,8 @@ async def json_from_body(request, keys: list[str] | None = None):
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
except JSONDecodeError:
|
except JSONDecodeError as err:
|
||||||
raise HTTPException(422, "Invalid JSON content.")
|
raise HTTPException(422, "Invalid JSON content.") from err
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
return data
|
return data
|
||||||
|
|
@ -155,7 +156,7 @@ async def json_from_body(request, keys: list[str] | None = None):
|
||||||
try:
|
try:
|
||||||
return [data[k] for k in keys]
|
return [data[k] for k in keys]
|
||||||
except KeyError as err:
|
except KeyError as err:
|
||||||
raise HTTPException(422, f"Missing data for key: {err.args[0]}")
|
raise HTTPException(422, f"Missing data for key: {err.args[0]}") from err
|
||||||
|
|
||||||
|
|
||||||
def is_admin(request):
|
def is_admin(request):
|
||||||
|
|
@ -508,8 +509,8 @@ async def modify_user(request):
|
||||||
if "secret" in data:
|
if "secret" in data:
|
||||||
try:
|
try:
|
||||||
secret = b64decode(data["secret"])
|
secret = b64decode(data["secret"])
|
||||||
except:
|
except Exception as err:
|
||||||
raise HTTPException(422, "Invalid secret.")
|
raise HTTPException(422, "Invalid secret.") from err
|
||||||
|
|
||||||
user.secret = phc_scrypt(secret)
|
user.secret = phc_scrypt(secret)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue