improve typing

This commit is contained in:
ducklet 2023-11-26 18:28:17 +01:00
parent 6d0c61fceb
commit 22ea553f48
3 changed files with 29 additions and 16 deletions

View file

@@ -3,7 +3,7 @@ import contextlib
import logging import logging
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, Literal, Type, TypeVar from typing import Any, AsyncGenerator, Iterable, Literal, Type, TypeVar
import sqlalchemy as sa import sqlalchemy as sa
from databases import Database from databases import Database
@@ -472,12 +472,12 @@ async def find_ratings(
.limit(limit_rows) .limit(limit_rows)
) )
async with locked_connection() as conn: async with locked_connection() as conn:
rows = conn.iterate(query) rating_rows: AsyncGenerator[Rating, None] = conn.iterate(query) # type: ignore
movie_ids = [r.movie_id async for r in rows] movie_ids = [r.movie_id async for r in rating_rows]
if include_unrated and len(movie_ids) < limit_rows: if include_unrated and len(movie_ids) < limit_rows:
query = ( query = (
sa.select(movies.c.id.label("movie_id")) sa.select(movies.c.id)
.distinct() .distinct()
.where(movies.c.id.not_in(movie_ids), *conditions) .where(movies.c.id.not_in(movie_ids), *conditions)
.order_by( .order_by(
@@ -488,8 +488,8 @@ async def find_ratings(
.limit(limit_rows - len(movie_ids)) .limit(limit_rows - len(movie_ids))
) )
async with locked_connection() as conn: async with locked_connection() as conn:
rows = conn.iterate(query) movie_rows: AsyncGenerator[Movie, None] = conn.iterate(query) # type: ignore
movie_ids += [r.movie_id async for r in rows] movie_ids += [r.id async for r in movie_rows]
return await ratings_for_movie_ids(ids=movie_ids) return await ratings_for_movie_ids(ids=movie_ids)

View file

@@ -4,6 +4,8 @@ from collections import namedtuple
from datetime import datetime from datetime import datetime
from urllib.parse import urljoin from urllib.parse import urljoin
import bs4
from . import db from . import db
from .models import Movie, Rating, User from .models import Movie, Rating, User
from .request import asession, asoup_from_url, cache_path from .request import asession, asoup_from_url, cache_path
@@ -43,7 +45,7 @@ async def refresh_user_ratings_from_imdb(stop_on_dupe: bool = True):
try: try:
async for rating, is_updated in load_ratings(user.imdb_id): async for rating, is_updated in load_ratings(user.imdb_id):
assert rating.user.id == user.id assert rating.user is not None and rating.user.id == user.id
if stop_on_dupe and not is_updated: if stop_on_dupe and not is_updated:
break break
@@ -154,13 +156,18 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
soup = await asoup_from_url(url) soup = await asoup_from_url(url)
meta = soup.find("meta", property="pageId") if (meta := soup.find("meta", property="pageId")) is None:
headline = soup.h1 raise RuntimeError("No pageId found.")
assert meta is not None and headline is not None assert isinstance(meta, bs4.Tag)
imdb_id = meta["content"] imdb_id = meta["content"]
assert isinstance(imdb_id, str)
user = await db.get(User, imdb_id=imdb_id) or User( user = await db.get(User, imdb_id=imdb_id) or User(
imdb_id=imdb_id, name="", secret="" imdb_id=imdb_id, name="", secret=""
) )
if (headline := soup.h1) is None:
raise RuntimeError("No headline found.")
assert isinstance(headline.string, str)
if match := find_name(headline.string): if match := find_name(headline.string):
user.name = match["name"] user.name = match["name"]
@@ -184,9 +191,15 @@ async def parse_page(url: str) -> tuple[list[Rating], str | None]:
ratings.append(rating) ratings.append(rating)
footer = soup.find("div", "footer") next_url = None
assert footer is not None if (footer := soup.find("div", "footer")) is None:
next_url = urljoin(url, footer.find(string=re.compile(r"Next")).parent["href"]) raise RuntimeError("No footer found.")
assert isinstance(footer, bs4.Tag)
if (next_link := footer.find("a", string="Next")) is not None:
assert isinstance(next_link, bs4.Tag)
next_href = next_link["href"]
assert isinstance(next_href, str)
next_url = urljoin(url, next_href)
return (ratings, next_url if url != next_url else None) return (ratings, next_url if url != next_url else None)

View file

@@ -40,7 +40,7 @@ metadata = mapper_registry.metadata
def annotations(tp: Type) -> tuple | None: def annotations(tp: Type) -> tuple | None:
return tp.__metadata__ if hasattr(tp, "__metadata__") else None return tp.__metadata__ if hasattr(tp, "__metadata__") else None # type: ignore
def fields(class_or_instance): def fields(class_or_instance):
@@ -125,7 +125,7 @@ def asplain(
if filter_fields is not None and f.name not in filter_fields: if filter_fields is not None and f.name not in filter_fields:
continue continue
target = f.type target: Any = f.type
# XXX this doesn't properly support any kind of nested types # XXX this doesn't properly support any kind of nested types
if (otype := optional_type(f.type)) is not None: if (otype := optional_type(f.type)) is not None:
target = otype target = otype
@@ -169,7 +169,7 @@ def fromplain(cls: Type[T], d: Mapping, *, serialized: bool = False) -> T:
dd: JSONObject = {} dd: JSONObject = {}
for f in fields(cls): for f in fields(cls):
target = f.type target: Any = f.type
otype = optional_type(f.type) otype = optional_type(f.type)
is_opt = otype is not None is_opt = otype is not None
if is_opt: if is_opt: