123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
import logging
|
|
import sqlite3
|
|
from datetime import datetime, timezone
|
|
from typing import *
|
|
|
|
from .models import Feed, Post
|
|
|
|
# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Store:
    """SQLite-backed persistence for feeds and their posts.

    The connection is opened by :meth:`connect` and kept on
    ``self.connection``; the other methods use that attribute directly, so
    ``connect()`` must have been called first.
    """

    # Maximum number of ids bound into one ``IN (...)`` list. SQLite builds
    # older than 3.32.0 default SQLITE_MAX_VARIABLE_NUMBER to 999, and one
    # extra placeholder is used for feed_id, so stay comfortably below that.
    _MAX_SQL_VARS = 900

    def __init__(self, dbpath: Optional[str] = None):
        """Create an unconnected store.

        Args:
            dbpath: Path of the SQLite database file; can also be supplied
                later through ``connect(path)``.
        """
        self.dbpath = dbpath
        self.connection: Optional[sqlite3.Connection] = None

    def connect(self, path: Optional[str] = None) -> sqlite3.Connection:
        """Open the database connection if needed and make sure the schema exists.

        Args:
            path: Optional database path; a non-empty value replaces
                ``self.dbpath``.

        Returns:
            The open connection, whether it was just opened or already open.
        """
        if path:
            self.dbpath = path
        if self.connection is not None:
            # NOTE: an already-open connection is reused as-is, even if *path*
            # just changed self.dbpath; call disconnect() first to switch files.
            return self.connection
        log.debug("Connecting to %s", self.dbpath)
        # isolation_level=None puts sqlite3 in autocommit mode: each statement
        # is committed as soon as it executes, so no explicit commit() is needed.
        self.connection = sqlite3.connect(self.dbpath, isolation_level=None)
        self.init()
        return self.connection

    def disconnect(self) -> None:
        """Close the connection, if any, and forget it.

        Clearing ``self.connection`` lets a later ``connect()`` open a fresh
        connection instead of returning the closed one.
        """
        conn, self.connection = self.connection, None
        if conn is not None:
            conn.close()

    def init(self) -> None:
        """Create the ``feed`` and ``post`` tables if they do not exist yet."""
        conn = self.connection
        conn.execute(
            """
            create table if not exists feed (
                id text primary key not null,
                url text unique not null,
                active integer not null,
                etag text,
                modified text
            )
            """
        )
        # NOTE(review): SQLite only enforces foreign keys (including the
        # "on delete cascade" below) when "PRAGMA foreign_keys = ON" is issued
        # on the connection, which nothing here does; confirm whether the
        # cascade is meant to take effect.
        conn.execute(
            """
            create table if not exists post (
                id text primary key not null,
                feed_id text not null references feed(id) on delete cascade,
                content text,
                date text,
                link text,
                title text
            )
            """
        )

    def sync_feeds(self, feeds: Dict[str, Feed]) -> None:
        """Write the current state of feeds to store, and load existing info back.

        *feeds* is keyed by feed id and is modified in place:

        * every feed is upserted (url, active, etag, modified) and its posts
          are inserted if not already stored;
        * active feeds that exist only in the store are added as
          ``Feed(id, url)``;
        * feeds the store marks inactive are removed (inactive feeds known
          only to the store are not added);
        * post ids stored for active feeds but missing from a feed's
          ``post_ids`` are appended to its ``posts`` as bare ``Post(id)``.
        """
        conn = self.connection
        self._save_feeds(conn, feeds)
        self._save_posts(conn, feeds)
        self._reconcile_feeds(conn, feeds)
        self._load_post_ids(conn, feeds)

    @staticmethod
    def _save_feeds(conn: sqlite3.Connection, feeds: Dict[str, Feed]) -> None:
        """Upsert each feed's current url, active flag, etag and modified value."""
        # "excluded" refers to the row that failed to insert, so insert and
        # update store the same values; the update path previously differed
        # from the insert path (which forced active=1 and dropped etag/modified).
        conn.executemany(
            """
            insert into feed(id, url, active, etag, modified)
            values(?, ?, ?, ?, ?)
            on conflict(id) do update set
                url=excluded.url,
                active=excluded.active,
                etag=excluded.etag,
                modified=excluded.modified
            """,
            (
                (f.id, f.url, 1 if f.active else 0, f.etag, f.modified)
                for f in feeds.values()
            ),
        )

    @staticmethod
    def _save_posts(conn: sqlite3.Connection, feeds: Dict[str, Feed]) -> None:
        """Insert every feed's posts; rows that already exist are left untouched."""
        conn.executemany(
            """
            insert into post(id, feed_id, content, date, link, title)
            values(?, ?, ?, ?, ?, ?)
            on conflict do nothing
            """,
            (
                (p.id, f.id, p.content, p.date, p.link, p.title)
                for f in feeds.values()
                for p in f.posts
            ),
        )

    @staticmethod
    def _reconcile_feeds(conn: sqlite3.Connection, feeds: Dict[str, Feed]) -> None:
        """Add store-only active feeds, apply stored URLs, drop inactive feeds."""
        for feed_id, url, active in conn.execute("select id, url, active from feed"):
            if not active:
                # Check "active" first so an inactive feed that the caller does
                # not know about is never re-added (and later re-activated).
                if feed_id in feeds:
                    log.warning("Feed is marked inactive: %s", feed_id)
                    del feeds[feed_id]
                continue
            feed = feeds.get(feed_id)
            if feed is None:
                feeds[feed_id] = Feed(feed_id, url)
            elif feed.url != url:
                log.warning("Feed URL changed: %s: %s", feed_id, url)
                feed.url = url

    @staticmethod
    def _load_post_ids(conn: sqlite3.Connection, feeds: Dict[str, Feed]) -> None:
        """Append a bare ``Post(id)`` for each stored post a feed does not yet list."""
        # Snapshot each feed's post_ids once; the sets are updated as we go so
        # duplicate rows cannot append the same post twice.
        known = {f.id: f.post_ids for f in feeds.values()}
        sql = """
            select post.id, feed_id from post
            join feed on feed.id=feed_id
            where feed.active=1
        """
        for post_id, feed_id in conn.execute(sql):
            ids = known[feed_id]
            if post_id not in ids:
                ids.add(post_id)
                feeds[feed_id].posts.append(Post(post_id))

    def posts(self, feed_id: str, post_ids: Iterable[str]) -> Sequence[Post]:
        """Load the stored posts of *feed_id* whose ids appear in *post_ids*.

        Ids that are not stored are skipped, duplicates in *post_ids* are
        ignored, and the result order is whatever SQLite returns. Non-null
        ``date`` values are parsed with ``datetime.fromisoformat``.

        Returns:
            A list of ``Post`` objects; empty when *post_ids* is empty.
        """
        # dict.fromkeys de-duplicates while keeping order, so a repeated id
        # cannot land in two chunks and come back twice.
        ids = list(dict.fromkeys(post_ids))
        conn = self.connection
        found: List[Post] = []
        step = self._MAX_SQL_VARS
        # Query in chunks so large id lists never exceed SQLite's limit on
        # bound parameters per statement.
        for start in range(0, len(ids), step):
            chunk = ids[start:start + step]
            placeholders = ",".join("?" * len(chunk))
            sql = (
                "select id, content, date, link, title from post "
                f"where feed_id=? and id in ({placeholders})"
            )
            found.extend(Post(*row) for row in conn.execute(sql, (feed_id, *chunk)))
        for post in found:
            if post.date is not None:
                post.date = datetime.fromisoformat(post.date)
        return found
|