fix some mypy lint
This commit is contained in:
parent
9d7d80d3a5
commit
7c7a1fcde2
6 changed files with 20 additions and 6 deletions
|
|
@ -17,7 +17,7 @@ class Store:
|
||||||
if path:
|
if path:
|
||||||
self.dbpath = path
|
self.dbpath = path
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
return self.connection
|
return
|
||||||
log.debug("Connecting to %s", self.dbpath)
|
log.debug("Connecting to %s", self.dbpath)
|
||||||
self.connection = sqlite3.connect(
|
self.connection = sqlite3.connect(
|
||||||
self.dbpath, isolation_level=None
|
self.dbpath, isolation_level=None
|
||||||
|
|
@ -30,6 +30,7 @@ class Store:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def init(self) -> None:
|
def init(self) -> None:
|
||||||
|
assert self.connection is not None
|
||||||
conn = self.connection
|
conn = self.connection
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -57,6 +58,7 @@ class Store:
|
||||||
|
|
||||||
def sync_feeds(self, feeds: Dict[str, Feed]) -> None:
|
def sync_feeds(self, feeds: Dict[str, Feed]) -> None:
|
||||||
"""Write the current state of feeds to store, and load existing info back."""
|
"""Write the current state of feeds to store, and load existing info back."""
|
||||||
|
assert self.connection is not None
|
||||||
conn = self.connection
|
conn = self.connection
|
||||||
conn.executemany(
|
conn.executemany(
|
||||||
"""
|
"""
|
||||||
|
|
@ -117,6 +119,7 @@ class Store:
|
||||||
select id, content, date, link, title from post
|
select id, content, date, link, title from post
|
||||||
where feed_id=? and id in ({qs})
|
where feed_id=? and id in ({qs})
|
||||||
"""
|
"""
|
||||||
|
assert self.connection is not None
|
||||||
conn = self.connection
|
conn = self.connection
|
||||||
posts = [Post(*row) for row in conn.execute(sql, (feed_id, *post_ids))]
|
posts = [Post(*row) for row in conn.execute(sql, (feed_id, *post_ids))]
|
||||||
for post in posts:
|
for post in posts:
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ class Bot:
|
||||||
self.message_handlers.append(callback)
|
self.message_handlers.append(callback)
|
||||||
|
|
||||||
def on_command(self, command: Union[str, Container[str]], callback: MessageHandler):
|
def on_command(self, command: Union[str, Container[str]], callback: MessageHandler):
|
||||||
commands = command = {command} if type(command) is str else command
|
commands = {command} if type(command) is str else command
|
||||||
|
|
||||||
async def guard(message):
|
async def guard(message):
|
||||||
if message.command not in commands:
|
if message.command not in commands:
|
||||||
|
|
@ -161,7 +161,7 @@ class Bot:
|
||||||
for h, t in tasks.items():
|
for h, t in tasks.items():
|
||||||
assert t.done()
|
assert t.done()
|
||||||
try:
|
try:
|
||||||
err = t.exception()
|
err = t.exception() # type: ignore
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
log.error("Message handler took too long to finished: %s", h)
|
log.error("Message handler took too long to finished: %s", h)
|
||||||
if err is not None:
|
if err is not None:
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import fields
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from html import escape as html_escape
|
from html import escape as html_escape
|
||||||
from html.parser import HTMLParser
|
from html.parser import HTMLParser
|
||||||
|
|
@ -311,7 +311,7 @@ class ElementParser(HTMLParser):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def escape_all(dc: dataclass, escape: Callable[[str], str] = html_escape) -> None:
|
def escape_all(dc, escape: Callable[[str], str] = html_escape) -> None:
|
||||||
"""Patch a dataclass to escape all strings."""
|
"""Patch a dataclass to escape all strings."""
|
||||||
for f in fields(dc):
|
for f in fields(dc):
|
||||||
if f.type is str:
|
if f.type is str:
|
||||||
|
|
|
||||||
5
mypy.ini
Normal file
5
mypy.ini
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
[mypy]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
pretty = True
|
||||||
|
python_version = 3.8
|
||||||
|
platform = linux
|
||||||
|
|
@ -15,6 +15,7 @@ class Store:
|
||||||
self.dbpath = path
|
self.dbpath = path
|
||||||
if self.connection is not None:
|
if self.connection is not None:
|
||||||
return self.connection
|
return self.connection
|
||||||
|
assert self.dbpath is not None
|
||||||
self.connection = sqlite3.connect(
|
self.connection = sqlite3.connect(
|
||||||
self.dbpath, isolation_level=None
|
self.dbpath, isolation_level=None
|
||||||
) # auto commit
|
) # auto commit
|
||||||
|
|
@ -45,6 +46,7 @@ class Store:
|
||||||
)
|
)
|
||||||
|
|
||||||
def add(self, posts: Iterable[Post]):
|
def add(self, posts: Iterable[Post]):
|
||||||
|
assert self.connection is not None
|
||||||
sql = f"""
|
sql = f"""
|
||||||
insert into post(content, source, date)
|
insert into post(content, source, date)
|
||||||
values (?, ?, ?)
|
values (?, ?, ?)
|
||||||
|
|
@ -59,6 +61,7 @@ class Store:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _select(self, condition="", params=[]) -> Iterable[Post]:
|
def _select(self, condition="", params=[]) -> Iterable[Post]:
|
||||||
|
assert self.connection is not None
|
||||||
sql = f"select id, content, date, source from post {condition}"
|
sql = f"select id, content, date, source from post {condition}"
|
||||||
for row in self.connection.execute(sql, params):
|
for row in self.connection.execute(sql, params):
|
||||||
id, content, date, source = row
|
id, content, date, source = row
|
||||||
|
|
@ -71,6 +74,7 @@ class Store:
|
||||||
cond = "where id in (select id from post order by random() limit 1)"
|
cond = "where id in (select id from post order by random() limit 1)"
|
||||||
for post in self._select(cond):
|
for post in self._select(cond):
|
||||||
return post
|
return post
|
||||||
|
return None
|
||||||
|
|
||||||
def search(self, term, skip: int = 0) -> Iterable[Post]:
|
def search(self, term, skip: int = 0) -> Iterable[Post]:
|
||||||
cond = "where content like ? order by date desc limit -1 offset ?"
|
cond = "where content like ? order by date desc limit -1 offset ?"
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ if [ "$1" = '--fix' ]; then
|
||||||
black .
|
black .
|
||||||
isort --profile black .
|
isort --profile black .
|
||||||
autoflake --in-place --recursive .
|
autoflake --in-place --recursive .
|
||||||
|
mypy . || : # ignore
|
||||||
)
|
)
|
||||||
exit
|
exit
|
||||||
fi
|
fi
|
||||||
|
|
@ -14,6 +15,7 @@ error=0
|
||||||
|
|
||||||
(set -x; black --check .) || error=$?
|
(set -x; black --check .) || error=$?
|
||||||
(set -x; isort --profile black --check .) || error=$?
|
(set -x; isort --profile black --check .) || error=$?
|
||||||
(set -x; autoflake --check --recursive .) || error=$?
|
(set -x; autoflake --check --recursive . | uniq) || error=$?
|
||||||
|
(set -x; mypy .) || : # ignore
|
||||||
|
|
||||||
exit "$error"
|
exit "$error"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue