add async file download function

This commit is contained in:
ducklet 2023-02-04 01:02:16 +01:00
parent 758706baa2
commit 0563d49dbc

View file

@ -4,7 +4,7 @@ import logging
import os import os
import tempfile import tempfile
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from hashlib import md5 from hashlib import md5
@ -24,8 +24,10 @@ if config.debug and config.cachedir:
config.cachedir.mkdir(exist_ok=True) config.cachedir.mkdir(exist_ok=True)
_shared_asession = None
_shared_session = None _shared_session = None
_ASession_T = httpx.AsyncClient
_Session_T = httpx.Client _Session_T = httpx.Client
_Response_T = httpx.Response _Response_T = httpx.Response
@ -58,6 +60,32 @@ def _Session() -> _Session_T:
return s return s
@asynccontextmanager
async def asession():
    """Yield the shared async request session.

    All async request functions use the same session, which gives them
    cookie persistence and connection pooling.  Entering this context
    before issuing requests lets the caller adjust headers or the retry
    behavior for everything that runs inside it.

    If a shared session is already open it is yielded unchanged and left
    open on exit.  Otherwise a new session is created, published as the
    shared session, and closed and unpublished again when this context
    exits.
    """
    global _shared_asession

    if not _shared_asession:
        # No session is open yet: this caller creates and owns it.
        _shared_asession = client = _ASession_T()
        client.headers["user-agent"] = "Mozilla/5.0 Gecko/20100101 unwind/20230203"
        try:
            async with client:
                yield client
        finally:
            # The owner is leaving; forget the (now closed) session.
            _shared_asession = None
    else:
        # Piggyback on the session somebody else opened; they will close it.
        yield _shared_asession
def _throttle( def _throttle(
times: int, per_seconds: float, jitter: Callable[[], float] | None = None times: int, per_seconds: float, jitter: Callable[[], float] | None = None
) -> Callable[[Callable], Callable]: ) -> Callable[[Callable], Callable]:
@ -298,3 +326,127 @@ def download(
# Fix file attributes. # Fix file attributes.
if resp_lastmod is not None: if resp_lastmod is not None:
os.utime(file_path, (resp_lastmod, resp_lastmod)) os.utime(file_path, (resp_lastmod, resp_lastmod))
async def adownload(
url: str,
*,
to_path: Path | str | None = None,
replace_existing: bool | None = None,
only_if_newer: bool = False,
timeout: float | None = None,
chunk_callback=None,
response_callback=None,
) -> bytes | None:
"""Download a file.
If `to_path` is `None` return the remote content, otherwise write the
content to the given file path.
Existing files will not be overwritten unless `replace_existing` is set.
Setting `only_if_newer` will check if the remote file is newer than the
local file, otherwise the download will be aborted.
"""
if replace_existing is None:
replace_existing = only_if_newer
file_exists = None
if to_path is not None:
to_path = Path(to_path)
file_exists = to_path.exists() and to_path.stat().st_size
if file_exists and not replace_existing:
raise FileExistsError(23, "Would replace existing file", str(to_path))
async with asession() as s:
headers = {}
if file_exists and only_if_newer:
assert to_path
file_lastmod = _last_modified_from_file(to_path)
headers["if-modified-since"] = email.utils.formatdate(
file_lastmod, usegmt=True
)
req = s.build_request(method="GET", url=url, headers=headers, timeout=timeout)
log.debug("⚡️ Loading %s (%a) ...", req.url, dict(req.headers))
resp = await s.send(req, follow_redirects=True, stream=True)
try:
if response_callback is not None:
try:
response_callback(resp)
except:
log.exception("🐛 Error in response callback.")
log.debug(
"☕️ %s -> status: %s; headers: %a",
req.url,
resp.status_code,
dict(resp.headers),
)
if resp.status_code == httpx.codes.NOT_MODIFIED:
log.debug(
"✋ Remote file has not changed, skipping download: %s -> %a",
req.url,
to_path,
)
return
resp.raise_for_status()
if to_path is None:
await resp.aread() # Download the response stream to allow `resp.content` access.
return resp.content
resp_lastmod = _last_modified_from_response(resp)
# Check Last-Modified in case the server ignored If-Modified-Since.
# XXX also check Content-Length?
if file_exists and only_if_newer and resp_lastmod is not None:
assert file_lastmod
if resp_lastmod <= file_lastmod:
log.debug("✋ Local file is newer, skipping download: %a", req.url)
return
# Create intermediate directories if necessary.
download_dir = to_path.parent
download_dir.mkdir(parents=True, exist_ok=True)
# Write content to temp file.
tempdir = download_dir
tempfd, tempfile_path = tempfile.mkstemp(
dir=tempdir, prefix=f".download-{to_path.name}."
)
one_mb = 2**20
chunk_size = 8 * one_mb
try:
log.debug("💾 Writing to temp file %s ...", tempfile_path)
async for chunk in resp.aiter_bytes(chunk_size):
os.write(tempfd, chunk)
if chunk_callback:
try:
chunk_callback(chunk)
except:
log.exception("🐛 Error in chunk callback.")
finally:
os.close(tempfd)
# Move downloaded file to destination.
if to_path.exists():
log.debug("💾 Replacing existing file: %s", to_path)
else:
log.debug("💾 Move to destination: %s", to_path)
if replace_existing:
Path(tempfile_path).replace(to_path)
else:
Path(tempfile_path).rename(to_path)
# Fix file attributes.
if resp_lastmod is not None:
log.debug("💾 Adjusting file timestamp: %s (%s)", to_path, resp_lastmod)
os.utime(to_path, (resp_lastmod, resp_lastmod))
finally:
await resp.aclose()