From 22ea553f48ea9d398f0f6bf24eb46db289214c20 Mon Sep 17 00:00:00 2001 From: ducklet Date: Sun, 26 Nov 2023 18:28:17 +0100 Subject: [PATCH] improve typing --- unwind/db.py | 12 ++++++------ unwind/imdb.py | 27 ++++++++++++++++++++------- unwind/models.py | 6 +++--- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/unwind/db.py b/unwind/db.py index 7a0169d..278c0c0 100644 --- a/unwind/db.py +++ b/unwind/db.py @@ -3,7 +3,7 @@ import contextlib import logging import threading from pathlib import Path -from typing import Any, Iterable, Literal, Type, TypeVar +from typing import Any, AsyncGenerator, Iterable, Literal, Type, TypeVar import sqlalchemy as sa from databases import Database @@ -472,12 +472,12 @@ async def find_ratings( .limit(limit_rows) ) async with locked_connection() as conn: - rows = conn.iterate(query) - movie_ids = [r.movie_id async for r in rows] + rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore + movie_ids = [r.movie_id async for r in rating_rows] if include_unrated and len(movie_ids) < limit_rows: query = ( - sa.select(movies.c.id.label("movie_id")) + sa.select(movies.c.id) .distinct() .where(movies.c.id.not_in(movie_ids), *conditions) .order_by( @@ -488,8 +488,8 @@ async def find_ratings( .limit(limit_rows - len(movie_ids)) ) async with locked_connection() as conn: - rows = conn.iterate(query) - movie_ids += [r.movie_id async for r in rows] + movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore + movie_ids += [r.id async for r in movie_rows] return await ratings_for_movie_ids(ids=movie_ids) diff --git a/unwind/imdb.py b/unwind/imdb.py index 477ec64..6858fc7 100644 --- a/unwind/imdb.py +++ b/unwind/imdb.py @@ -4,6 +4,8 @@ from collections import namedtuple from datetime import datetime from urllib.parse import urljoin +import bs4 + from . import db from .models import Movie, Rating, User from .request import asession, asoup_from_url, cache_path @@ -43,7 +45,7 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True): try: async for rating, is_updated in load_ratings(user.imdb_id): - assert rating.user.id == user.id + assert rating.user is not None and rating.user.id == user.id if stop_on_dupe and not is_updated: break @@ -154,13 +156,18 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]: soup = await asoup_from_url(url) - meta = soup.find("meta", property="pageId") - headline = soup.h1 - assert meta is not None and headline is not None + if (meta := soup.find("meta", property="pageId")) is None: + raise RuntimeError("No pageId found.") + assert isinstance(meta, bs4.Tag) imdb_id = meta["content"] + assert isinstance(imdb_id, str) user = await db.get(User, imdb_id=imdb_id) or User( imdb_id=imdb_id, name="", secret="" ) + + if (headline := soup.h1) is None: + raise RuntimeError("No headline found.") + assert isinstance(headline.string, str) if match := find_name(headline.string): user.name = match["name"] @@ -184,9 +191,15 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]: ratings.append(rating) - footer = soup.find("div", "footer") - assert footer is not None - next_url = urljoin(url, footer.find(string=re.compile(r"Next")).parent["href"]) + next_url = None + if (footer := soup.find("div", "footer")) is None: + raise RuntimeError("No footer found.") + assert isinstance(footer, bs4.Tag) + if (next_link := footer.find("a", string="Next")) is not None: + assert isinstance(next_link, bs4.Tag) + next_href = next_link["href"] + assert isinstance(next_href, str) + next_url = urljoin(url, next_href) return (ratings, next_url if url != next_url else None) diff --git a/unwind/models.py b/unwind/models.py index 609614a..ff961fc 100644 --- a/unwind/models.py +++ b/unwind/models.py @@ -40,7 +40,7 @@ metadata = mapper_registry.metadata def annotations(tp: Type) -> tuple | None: - return tp.__metadata__ if hasattr(tp, "__metadata__") else None + return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore def fields(class_or_instance): @@ -125,7 +125,7 @@ def asplain( if filter_fields is not None and f.name not in filter_fields: continue - target = f.type + target: Any = f.type # XXX this doesn't properly support any kind of nested types if (otype := optional_type(f.type)) is not None: target = otype @@ -169,7 +169,7 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T: dd: JSONObject = {} for f in fields(cls): - target = f.type + target: Any = f.type otype = optional_type(f.type) is_opt = otype is not None if is_opt: