# Async additions to unwind/request.py (`asession`, `adownload`), written out
# as plain Python from the diff.
#
# NOTE(review): `httpx`, `log`, `email.utils`, `Path`,
# `_last_modified_from_file` and `_last_modified_from_response` are used below
# but are not defined in the visible part of the diff; presumably the rest of
# the module provides them.

import logging
import os
import tempfile
from collections import deque
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
from functools import wraps
from hashlib import md5

# Lazily created clients; `None` while no session is open.
_shared_asession = None
_shared_session = None

_ASession_T = httpx.AsyncClient
_Session_T = httpx.Client
_Response_T = httpx.Response


@asynccontextmanager
async def asession():
    """Open (or join) the shared async request session.

    Usage: ``async with asession() as s: ...``.

    The outermost ``async with`` creates the client, sets the default
    user-agent header and closes the client again on exit.  Any nested use
    while the client is open yields that very same client and leaves closing
    to the outermost owner.

    Opening the session before making a request allows you to set headers
    or otherwise configure the client that the request functions will use.

    NOTE(review): a task that joined the session must finish before the
    owning ``async with`` exits, because the owner closes the client
    unconditionally -- confirm callers respect this.
    """
    global _shared_asession

    # Join the already-open session; the outermost caller owns its lifetime.
    if _shared_asession is not None:
        yield _shared_asession
        return

    _shared_asession = _ASession_T()
    _shared_asession.headers[
        "user-agent"
    ] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
    try:
        async with _shared_asession:
            yield _shared_asession
    finally:
        # Forget the client even if opening or closing it failed, so the next
        # caller creates a fresh one.
        _shared_asession = None


async def _astream_to_path(resp, to_path, *, replace_existing, chunk_callback):
    """Stream the body of `resp` into `to_path` via a temp file next to it.

    The temp file is created in the destination directory, filled chunk by
    chunk and then moved over `to_path`.  On any failure (including task
    cancellation) the temp file is removed before the error propagates.
    """
    # Create intermediate directories if necessary.
    download_dir = to_path.parent
    download_dir.mkdir(parents=True, exist_ok=True)

    tempfd, tempfile_path = tempfile.mkstemp(
        dir=download_dir, prefix=f".download-{to_path.name}."
    )
    temp_path = Path(tempfile_path)
    one_mb = 2**20
    chunk_size = 8 * one_mb
    try:
        log.debug("💾 Writing to temp file %s ...", tempfile_path)
        # A buffered file object guarantees the whole chunk is written; a raw
        # `os.write` may write fewer bytes than requested.
        with os.fdopen(tempfd, "wb") as temp_file:
            async for chunk in resp.aiter_bytes(chunk_size):
                temp_file.write(chunk)
                if chunk_callback:
                    try:
                        chunk_callback(chunk)
                    except Exception:
                        # Callbacks are best-effort: log and keep downloading.
                        log.exception("🐛 Error in chunk callback.")

        # Move downloaded file to destination.
        if to_path.exists():
            log.debug("💾 Replacing existing file: %s", to_path)
        else:
            log.debug("💾 Move to destination: %s", to_path)
        if replace_existing:
            temp_path.replace(to_path)
        else:
            temp_path.rename(to_path)
    except BaseException:
        # Do not leave partial downloads behind.  After a successful move the
        # temp path no longer exists, hence `missing_ok`.
        temp_path.unlink(missing_ok=True)
        raise


async def adownload(
    url: str,
    *,
    to_path: Path | str | None = None,
    replace_existing: bool | None = None,
    only_if_newer: bool = False,
    timeout: float | None = None,
    chunk_callback=None,
    response_callback=None,
) -> bytes | None:
    """Download a file using the shared async session.

    If `to_path` is `None` return the remote content, otherwise write the
    content to the given file path and return `None`.

    Existing (non-empty) files will not be overwritten unless
    `replace_existing` is set; it defaults to the value of `only_if_newer`.
    Setting `only_if_newer` sends an `If-Modified-Since` header and also
    compares the response's modification time with the local file's; the
    download is skipped (returning `None`) unless the remote file is newer.

    `response_callback(resp)` is called once with the response before the
    body is read, `chunk_callback(chunk)` once per chunk written to disk.
    Exceptions raised by either callback are logged and otherwise ignored.

    `timeout` is passed through to the request as-is.
    NOTE(review): httpx appears to treat an explicit `None` as "no timeout"
    rather than "client default" -- confirm this is intended.

    Raises `FileExistsError` if the destination exists and may not be
    replaced, and whatever `raise_for_status` raises for error responses.
    """
    if replace_existing is None:
        replace_existing = only_if_newer

    file_exists = False
    if to_path is not None:
        to_path = Path(to_path)

        # An empty file counts as missing.
        file_exists = to_path.exists() and to_path.stat().st_size > 0
        if file_exists and not replace_existing:
            # NOTE(review): 23 is ENFILE on Linux, EEXIST would be 17; the
            # value is kept unchanged for callers that may inspect it.
            raise FileExistsError(23, "Would replace existing file", str(to_path))

    async with asession() as s:
        headers = {}
        file_lastmod = None
        if file_exists and only_if_newer:
            file_lastmod = _last_modified_from_file(to_path)
            headers["if-modified-since"] = email.utils.formatdate(
                file_lastmod, usegmt=True
            )

        req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)

        log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
        resp = await s.send(req, follow_redirects=True, stream=True)

        # The response is streamed, so it must be closed explicitly.
        try:
            if response_callback is not None:
                try:
                    response_callback(resp)
                except Exception:
                    # Callbacks are best-effort: log and carry on.
                    log.exception("🐛 Error in response callback.")

            log.debug(
                "☕️ %s -> status: %s; headers: %a",
                req.url,
                resp.status_code,
                dict(resp.headers),
            )

            if resp.status_code == httpx.codes.NOT_MODIFIED:
                log.debug(
                    "✋ Remote file has not changed, skipping download: %s -> %a",
                    req.url,
                    to_path,
                )
                return None

            resp.raise_for_status()

            if to_path is None:
                # Download the response stream to allow `resp.content` access.
                await resp.aread()
                return resp.content

            resp_lastmod = _last_modified_from_response(resp)

            # Check Last-Modified in case the server ignored If-Modified-Since.
            # XXX also check Content-Length?
            if (
                file_lastmod is not None
                and resp_lastmod is not None
                and resp_lastmod <= file_lastmod
            ):
                log.debug("✋ Local file is newer, skipping download: %a", req.url)
                return None

            await _astream_to_path(
                resp,
                to_path,
                replace_existing=replace_existing,
                chunk_callback=chunk_callback,
            )

            # Fix file attributes.
            if resp_lastmod is not None:
                log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
                os.utime(to_path, (resp_lastmod, resp_lastmod))

            return None
        finally:
            await resp.aclose()