"""Bluesky repo storage base class and in-memory implementation.
Lightly based on:
https://github.com/bluesky-social/atproto/blob/main/packages/repo/src/storage/repo-storage.ts
"""
from collections import namedtuple
import copy
from enum import auto, Enum
import itertools
import logging
import dag_cbor
from multiformats import CID, multicodec, multihash
from . import mst as mst_mod
from .repo import Write
from .server import server
from . import util
from .util import dag_cbor_cid, DEACTIVATED, tid_to_int, TOMBSTONED, InactiveRepo
# NSID of the firehose subscription. Used in this module as the key passed to
# Sequences.allocate() when writing commits and other events.
SUBSCRIBE_REPOS_NSID = 'com.atproto.sync.subscribeRepos'
# Sync 1.1 commit size limits
# https://github.com/bluesky-social/proposals/blob/main/0006-sync-iteration%2FREADME.md#commit-size-limits
#
# Of these, Storage._commit enforces MAX_RECORD_SIZE_BYTES (per encoded record)
# and MAX_OPERATIONS_PER_COMMIT (per call). The other two limits are not
# referenced in this part of the file; presumably enforced elsewhere - TODO
# confirm.
MAX_RECORD_SIZE_BYTES = 1_000_000 # 1 MB
MAX_COMMIT_BLOCKS_BYTES = 2_000_000 # 2 MB
MAX_EVENT_SIZE_BYTES = 5_000_000 # 5 MB
MAX_OPERATIONS_PER_COMMIT = 200
# module-level logger, named after this module
logger = logging.getLogger(__name__)
[docs]
class Action(Enum):
    """Kind of write operation applied to a record path.

    Used in :meth:`Storage.commit`, both on incoming writes and on the
    :class:`CommitOp`\\ s recorded for each commit.

    TODO: switch to StrEnum once we can require Python 3.11.
    """
    # add a record at a new path
    CREATE = auto()
    # replace the record at an existing path
    UPDATE = auto()
    # remove the record at an existing path
    DELETE = auto()
# TODO: Should this be a subclass of Block?
# TODO: generalize to handle other events
#
# Result of one repo commit: the signed commit block itself, every block that
# belongs to the commit keyed by CID, and the CID of the previous commit.
# ``prev`` defaults to None.
CommitData = namedtuple('CommitData', [
    'commit', # Block
    'blocks', # dict of CID to Block
    'prev', # CID or None
], defaults=[None]) # prev
# One record-level operation inside a commit, as attached to the commit
# Block's ``ops``. ``cid`` and ``prev_cid`` both default to None.
CommitOp = namedtuple('CommitOp', [ # for subscribeRepos
    'action', # Action
    'path', # str
    'cid', # CID, or None for DELETE
    'prev_cid', # previous CID for UPDATE and DELETE operations, None for CREATE
], defaults=[None, None]) # cid, prev_cid
# commit record format is:
# https://atproto.com/specs/repository#commit-objects
#
# {
# 'version': 3,
# 'did': [repo],
# 'rev': [str, TID],
# 'data': [CID],
# 'prev': [CID or None],
# 'sig': [bytes],
# }
[docs]
class Block:
    r"""An ATProto block: a record, :class:`MST` entry, commit, or other event.

    Can start from either encoded bytes or decoded object, with or without
    :class:`CID`. Decodes, encodes, and generates :class:`CID` lazily, on
    demand, on attribute access.

    Events should have a fully-qualified ``$type`` field that's one of the
    ``message`` types in ``com.atproto.sync.subscribeRepos``, eg
    ``com.atproto.sync.subscribeRepos#tombstone``.

    Based on :class:`carbox.car.Block`.

    Equality and hashing are by :attr:`cid` only; ``seq``, ``ops``, ``time``
    and ``repo`` are ignored.

    Attributes:
      cid (CID): lazy-loaded (dynamic property)
      decoded (dict): decoded object (dynamic property)
      encoded (bytes): DAG-CBOR encoded data (dynamic property)
      seq (int): ``com.atproto.sync.subscribeRepos`` sequence number
      ops (list): :class:`CommitOp`\s if this is a commit, otherwise None
      time (datetime): when this block was first created
      repo (str): DID of a repo that includes this block. Occasionally, blocks
        may be included in more than one repo, so this may be *any* repo that
        includes it. In practice, it's often the first or last repo that
        included it.
    """
    def __init__(self, *, cid=None, decoded=None, encoded=None, seq=None,
                 ops=None, time=None, repo=None):
        """Constructor.

        At least one of ``encoded`` and ``decoded`` must be provided (and
        truthy); the other is derived lazily on first access.

        Args:
          cid (CID): optional; computed from ``encoded`` on demand if omitted
          decoded (dict): optional
          encoded (bytes): optional
          seq (int): optional sequence number
          ops (list of CommitOp): optional, for commit blocks
          time (datetime): optional; defaults to :func:`util.now`
          repo (str): optional repo DID
        """
        assert encoded or decoded
        self._cid = cid
        self._encoded = encoded
        self._decoded = decoded
        self.seq = seq
        self.ops = ops
        self.time = time or util.now()
        self.repo = repo

    def __str__(self):
        return f'<Block: {self.cid}>'

    @property
    def cid(self):
        """CID of this block: CIDv1, dag-cbor codec, sha2-256 of ``encoded``.

        Computed once and cached.
        """
        if self._cid is None:
            digest = multihash.digest(self.encoded, 'sha2-256')
            self._cid = CID('base58btc', 1, 'dag-cbor', digest)
        return self._cid

    @property
    def encoded(self):
        """DAG-CBOR bytes, encoded from ``decoded`` on first access if needed."""
        if self._encoded is None:
            self._encoded = dag_cbor.encode(self.decoded)
        return self._encoded

    @property
    def decoded(self):
        """Decoded object, decoded from ``encoded`` on first access if needed."""
        if self._decoded is None:
            self._decoded = dag_cbor.decode(self.encoded)
        return self._decoded

    def __eq__(self, other):
        """Compares by CID only.

        Returns ``NotImplemented`` for objects that have no ``cid`` attribute
        (eg ``None``), so that ``block == None`` and ``block in some_list``
        evaluate normally instead of raising :class:`AttributeError`.
        """
        try:
            other_cid = other.cid
        except AttributeError:
            return NotImplemented
        return self.cid == other_cid

    def __hash__(self):
        return hash(self.cid)
[docs]
class Sequences:
    """Abstract interface for handing out event stream sequence numbers.

    ...eg for the ``com.atproto.sync.subscribeRepos`` firehose. Subclasses
    provide the actual counter storage.

    Background: https://atproto.com/specs/event-stream#sequence-numbers
    """
    def __init__(self, **kwargs):
        # cooperative init so this class can be combined with other bases
        super().__init__(**kwargs)

    def allocate(self, nsid):
        """Generates and returns a new sequence number for the given NSID.

        Args:
          nsid (str): subscription XRPC method the sequence number belongs to

        Returns:
          int:
        """
        raise NotImplementedError

    def last(self, nsid):
        """Returns the highest sequence number allocated so far for an NSID.

        Args:
          nsid (str): subscription XRPC method the sequence number belongs to

        Returns:
          int or None: None if no sequence number has ever been allocated for
          this NSID
        """
        raise NotImplementedError
[docs]
class Storage:
    """Abstract base class for storing nodes: records, MST entries, commits, etc.

    Concrete subclasses should implement this on top of physical storage,
    eg database, filesystem, in memory.

    Attributes:
      sequences (Sequences): allocates sequence numbers for commits and events
    """
    def __init__(self, *, sequences, **kwargs):
        """Constructor.

        Args:
          sequences (Sequences): required, must be truthy
        """
        super().__init__(**kwargs)
        assert sequences
        self.sequences = sequences

    def create_repo(self, repo):
        """Stores a new repo's metadata in storage.

        Only stores the repo's handle and head commit :class:`CID`, not blocks!

        If the repo already exists in storage, this should update it instead of
        failing.

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def load_repo(self, did_or_handle):
        """Loads a repo from storage.

        Args:
          did_or_handle (str): the repo's DID or handle

        Returns:
          Repo, or None if the did or handle wasn't found:
        """
        raise NotImplementedError()

    def store_repo(self, repo):
        """Writes a repo to storage.

        Right now only writes some metadata:

        * handle
        * status
        * head

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def load_repos(self, after=None, limit=500):
        """Loads multiple repos from storage.

        Repos are returned in lexicographic order of their DIDs, ascending.
        Tombstoned repos are included.

        Args:
          after (str): optional DID to start at, *exclusive*
          limit (int): maximum number of repos to return

        Returns:
          sequence of Repo:
        """
        raise NotImplementedError()

    def deactivate_repo(self, repo):
        """Marks a repo as deactivated.

        * Calls :meth:`_set_repo_status` to mark the repo as deactivated in
          storage.
        * Stores a ``com.atproto.sync.subscribeRepos#account`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#account`` message.

        After this, any attempt to commit to this repo will raise
        :class:`InactiveRepo`.

        Args:
          repo (Repo)
        """
        self._set_repo_status(repo, DEACTIVATED)
        self.write_event(repo=repo, type='account', active=False,
                         status='deactivated')

    def activate_repo(self, repo):
        """Marks a repo as active.

        Only needed after deactivating.

        NOTE(review): this method does not itself check whether the repo is
        tombstoned; it unconditionally clears the status. If tombstoned repos
        must stay tombstoned, that has to be enforced by the caller or by the
        subclass's :meth:`_set_repo_status` - confirm.

        * Calls :meth:`_set_repo_status` to mark the repo as active in storage.
        * Stores a ``com.atproto.sync.subscribeRepos#account`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#account`` message.

        Args:
          repo (Repo)
        """
        self._set_repo_status(repo, None)
        self.write_event(repo=repo, type='account', active=True)

    def tombstone_repo(self, repo):
        """Marks a repo as tombstoned.

        * Calls :meth:`_set_repo_status` to mark the repo as tombstoned in
          storage.
        * Stores a ``com.atproto.sync.subscribeRepos#tombstone`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#tombstone`` message.

        After this, any attempt to commit to this repo will raise
        :class:`InactiveRepo`.

        Args:
          repo (Repo)
        """
        self._set_repo_status(repo, TOMBSTONED)
        self.write_event(repo=repo, type='tombstone')

    def _set_repo_status(self, repo, status):
        """Sets a repo's status in storage. Subclasses must implement this.

        Called by :meth:`deactivate_repo`, :meth:`activate_repo`, and
        :meth:`tombstone_repo`.

        Args:
          repo (Repo)
          status (str or None): ``DEACTIVATED``, ``TOMBSTONED``, or None for
            active
        """
        raise NotImplementedError()

    def _tombstone_repo(self, repo):
        """Marks a repo as tombstoned in storage.

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def read(self, cid):
        """Reads a node from storage.

        Args:
          cid (CID)

        Returns:
          Block, or None if not found:
        """
        raise NotImplementedError()

    def read_many(self, cids, require_all=True):
        """Batch read multiple nodes from storage.

        Args:
          cids (sequence of CID)
          require_all (bool): whether to assert that all cids are found

        Returns:
          dict: {:class:`CID`: :class:`Block` or None if not found}
        """
        raise NotImplementedError()

    def read_blocks_by_seq(self, start=0, repo=None):
        """Batch read blocks from storage by ``subscribeRepos`` sequence number.

        Args:
          start (int): optional ``subscribeRepos`` sequence number to start
            from, inclusive. Defaults to 0.
          repo (str): optional repo DID. If not provided, all repos are included.

        Returns:
          iterable or generator: all :class:`Block` s starting from ``start``,
          inclusive, in ascending ``seq`` order
        """
        raise NotImplementedError()

    def read_events_by_seq(self, start=0, repo=None):
        """Batch read commits and other events by ``subscribeRepos`` sequence number.

        Groups the blocks from :meth:`read_blocks_by_seq` by sequence number.
        Blocks whose ``$type`` starts with ``com.atproto.sync.subscribeRepos#``
        are yielded as-is (decoded). All other blocks that share a sequence
        number are collected into one :class:`CommitData`, which must include
        exactly one commit block.

        Args:
          start (int): optional ``subscribeRepos`` sequence number to start from,
            inclusive. Defaults to 0.
          repo (str): optional repo DID. If not provided, all repos are included.

        Returns:
          generator: generator of :class:`CommitData` for commits and dict
          messages for other events, starting from ``start``, inclusive, in
          ascending ``seq`` order
        """
        assert start >= 0

        # exact set of keys that identifies a commit object; loop invariant
        commit_fields = {'version', 'did', 'rev', 'prev', 'data', 'sig'}

        # state for the sequence number currently being collected
        seq = commit_block = blocks = None

        def make_commit():
            # a record block may have been stored under an earlier seq (blocks
            # are only written once), so load any that this commit's ops
            # reference but that didn't come back with this seq
            for op in commit_block.ops:
                if (op.action in (Action.CREATE, Action.UPDATE)
                        and op.cid not in blocks):
                    record = self.read(op.cid)
                    assert record
                    blocks[op.cid] = record

            # same for the MST root
            mst_root = commit_block.decoded['data']
            if mst_root not in blocks:
                blocks[mst_root] = self.read(mst_root)

            return CommitData(blocks=blocks, commit=commit_block,
                              prev=commit_block.decoded.get('prev'))

        for block in self.read_blocks_by_seq(start=start, repo=repo):
            assert block.seq
            if block.seq != seq:  # switching to a new commit's blocks
                if commit_block:
                    yield make_commit()
                else:
                    # we shouldn't have any dangling blocks that we don't serve
                    assert not blocks
                seq = block.seq
                blocks = {}  # maps CID to Block
                commit_block = None

            if block.decoded.get('$type', '').startswith(
                    'com.atproto.sync.subscribeRepos#'):  # non-commit message
                yield block.decoded
                continue

            blocks[block.cid] = block
            if block.decoded.keys() == commit_fields:
                commit_block = block

        # final commit
        if blocks:
            assert commit_block, f'seq {seq}'
            yield make_commit()

    def has(self, cid):
        """Checks if a given :class:`CID` is currently stored.

        Args:
          cid (CID)

        Returns:
          bool:
        """
        raise NotImplementedError()

    def write(self, repo_did, obj, seq=None):
        """Writes a node to storage.

        Args:
          repo_did (str):
          obj (dict): a record, commit, serialized :class:`MST` node, or
            ``subscribeRepos`` event/message
          seq (int or None): sequence number. If not provided, a new one will be
            allocated.

        Returns:
          Block:

        Raises:
          InactiveRepo: if the repo is not active. NOTE(review): whether this
            is enforced is up to the subclass; :meth:`write_event` calls this
            method right after a repo's status has been changed - confirm.
        """
        raise NotImplementedError()

    def write_event(self, repo, type, **kwargs):
        """Writes a ``subscribeRepos`` event to storage.

        Allocates a new sequence number, stores the event with :meth:`write`,
        and then, if :attr:`Repo.callback` is set, calls it with the decoded
        event as ``data``.

        Args:
          repo (Repo)
          type (str): ``account``, ``identity``, ``sync``, or ``tombstone``
          kwargs: included in the event, eg ``active``, ``status``

        Returns:
          Block:

        Raises:
          AssertionError: if ``type`` isn't one of the supported event types
          Exception: anything raised by :meth:`write` is propagated
        """
        assert type in ('account', 'identity', 'sync', 'tombstone'), type

        seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)
        block = self.write(repo.did, {
            '$type': f'com.atproto.sync.subscribeRepos#{type}',
            'seq': seq,
            'did': repo.did,
            'time': util.now().isoformat(),
            **kwargs,
        }, seq=seq)

        if repo.callback:
            repo.callback(data=block.decoded)

        return block

    def write_blocks(self, blocks):
        """Batch write blocks to storage.

        Does not allocate sequence numbers!

        Args:
          blocks (sequence of :class:`Block`)
        """
        raise NotImplementedError()

    def commit(self, repo, writes, repo_did=None):
        """Commits zero or more writes to storage.

        Allocates a new sequence number and uses it for all blocks in the commit.

        If :attr:`Repo.callback` is set, it's called with the resulting
        :class:`CommitData` as ``data`` on success. If anything fails after the
        sequence number was allocated, the callback is called with
        ``lost_seq=seq`` instead, and the exception is re-raised.

        Args:
          repo (Repo)
          writes (Write or sequence of Write)
          repo_did (str): optional, used if this is the repo's first commit

        Returns:
          CommitData:

        Raises:
          InactiveRepo: if the repo is not active
          ValueError: if the commit is invalid, eg the path for an update or delete
            doesn't currently exist
        """
        seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)

        try:
            commit_data = self._commit(repo, writes, seq, repo_did=repo_did)
            if repo.callback:
                repo.callback(data=commit_data)
        except BaseException:
            # the sequence number has already been consumed; report it as lost
            if repo.callback:
                repo.callback(lost_seq=seq)
            raise

        return commit_data

    def _commit(self, repo, writes, seq, repo_did=None):
        """Separate from :meth:`commit` so that subclasses can put it in a tx.

        Applies the writes to a local copy of the repo's MST reference and only
        assigns it back to the repo once all new blocks have been written, so a
        failure part way through (missing path, oversized record, validation,
        signing, or storage error) doesn't leave the repo object pointing at an
        MST that was never committed.

        Args:
          repo (Repo)
          writes (Write or sequence of Write)
          seq (int): sequence number to use for every block in this commit
          repo_did (str): optional, used if this is the repo's first commit

        Returns:
          CommitData:

        Raises:
          InactiveRepo: if the repo's status is set
          ValueError: if there are more than ``MAX_OPERATIONS_PER_COMMIT``
            writes, an update or delete path doesn't exist, or an encoded
            record is larger than ``MAX_RECORD_SIZE_BYTES``
        """
        assert seq

        if repo.status:
            raise InactiveRepo(repo.did, repo.status)

        orig_repo = repo

        if repo_did:
            # this is the initial empty commit for creating a new repo
            if repo.head:
                # this must be a transaction retry, and head was set by a previous
                # attempt, below
                assert repo.head.decoded.get('did') == repo_did, repo.head.decoded
                assert not repo.head.decoded.get('prev'), repo.head.decoded
                repo.head = None
            assert not repo.did
            prev = None
        else:
            # this is an existing repo with at least one commit
            assert repo.did
            repo_did = repo.did
            repo = self.load_repo(repo_did)
            prev = repo.head.cid  # for the new commit

        commit_blocks = {}  # maps CID to Block

        assert writes is not None
        if isinstance(writes, Write):
            writes = [writes]

        if len(writes) > MAX_OPERATIONS_PER_COMMIT:
            raise ValueError(f'Too many operations ({len(writes)}), max is {MAX_OPERATIONS_PER_COMMIT}')

        # working MST; only published to repo.mst after the blocks are stored
        mst = repo.mst
        ops = []

        for write in copy.copy(writes):
            assert isinstance(write, Write), type(write)
            path = f'{write.collection}/{write.rkey}'

            # sync v1.1: for UPDATE and DELETE, load the previous record's CID
            # https://github.com/bluesky-social/proposals/tree/main/0006-sync-iteration#commit-validation-mst-operation-inversion
            prev_cid = None
            if write.action in (Action.UPDATE, Action.DELETE):
                if not (prev_cid := mst.get(path)):
                    raise ValueError(f"{path} doesn't exist in repo")

            if write.action == Action.DELETE:
                logger.debug('deleting from MST')
                mst = mst.delete(path)
                logger.debug(' done')
                ops.append(CommitOp(action=Action.DELETE, path=path,
                                    cid=None, prev_cid=prev_cid))
                continue

            # raises ValidationError if it doesn't validate
            assert write.record is not None
            server.validate(write.record.get('$type'), 'record', write.record)

            block = Block(decoded=write.record, repo=repo_did, seq=seq)
            if len(block.encoded) > MAX_RECORD_SIZE_BYTES:
                raise ValueError(f'Record {path} size {len(block.encoded)} bytes exceeds max {MAX_RECORD_SIZE_BYTES}')

            commit_blocks[block.cid] = block
            op = CommitOp(action=write.action, path=path,
                          cid=block.cid, prev_cid=prev_cid)

            if write.action == Action.CREATE:
                logger.debug('adding to MST')
                mst = mst.add(path, block.cid)
                logger.debug(' done')
                ops.append(op)
            else:
                assert write.action == Action.UPDATE
                orig_pointer = mst.get_pointer()
                logger.debug('updating MST')
                mst = mst.update(path, block.cid)
                logger.debug(' done')
                if mst.get_pointer() != orig_pointer:
                    # no-op updates are invalid in ATProto, so only include this
                    # update operation if it changes the the record and MST.
                    # https://github.com/snarfed/arroba/issues/52#issuecomment-2825755142
                    ops.append(op)

        logger.debug('loading unstored MST blocks')
        root, unstored_blocks = mst.get_unstored_blocks()
        logger.debug(' done')
        for block in unstored_blocks.values():
            block.repo = repo_did
            block.seq = seq
        commit_blocks.update(unstored_blocks)

        # construct commit
        commit = util.sign({
            'did': repo_did,
            'version': 3,
            # reuse subscribeRepos sequence number as rev
            # https://github.com/bluesky-social/atproto/discussions/1607
            'rev': util.int_to_tid(seq, clock_id=0),
            'prev': prev,
            'data': root,
        }, repo.signing_key)
        commit_block = Block(decoded=commit, repo=repo_did, seq=seq, ops=ops)
        commit_blocks[commit_block.cid] = commit_block

        commit_data = CommitData(commit=commit_block, prev=prev, blocks=commit_blocks)

        # only add new blocks so we don't wipe out any existing blocks' sequence
        # numbers. (occasionally we see existing blocks recur, eg MST nodes.)
        logger.debug(f'writing {len(commit_data.blocks)} new blocks')
        self.write_blocks(commit_data.blocks.values())
        logger.debug(' done')

        # everything is stored; now it's safe to publish the new MST
        repo.mst = mst

        # update repo head
        if repo.did:
            repo.head = commit_data.commit
            logger.info(f'Updating {repo.did} head {repo.head.cid}')
            self.store_repo(repo)
            logger.debug(' done')

        # the caller's repo object may differ from the one we loaded above
        orig_repo.mst = repo.mst
        orig_repo.head = commit_block

        return commit_data
[docs]
class MemorySequences(Sequences):
    """Sequence numbers kept in a plain in-process dict.

    Attributes:
      sequences (dict): {str NSID: int next sequence number}
    """
    def __init__(self):
        self.sequences = {}

    def allocate(self, nsid):
        """Returns the next sequence number for ``nsid``, starting at 1."""
        assert nsid
        allocated = self.sequences.setdefault(nsid, 1)
        logger.info(f'Allocated seq {allocated}')
        # the dict stores the *next* number to hand out
        self.sequences[nsid] = self.sequences[nsid] + 1
        return allocated

    def last(self, nsid):
        """Returns the most recently allocated number for ``nsid``, or None."""
        assert nsid
        upcoming = self.sequences.get(nsid)
        if not upcoming:
            return None
        return upcoming - 1
[docs]
class MemoryStorage(Storage):
    """In memory storage implementation.

    Attributes:
      repos (dict mapping str DID to :class:`Repo`)
      blocks (dict): {:class:`CID`: :class:`Block`}
    """
    repos = None
    blocks = None

    def __init__(self, *, sequences=None):
        """Constructor.

        Args:
          sequences (Sequences): optional; defaults to a :class:`MemorySequences`
        """
        super().__init__(sequences=sequences or MemorySequences())
        self.blocks = {}
        self.repos = {}

    def create_repo(self, repo):
        """Registers ``repo`` under its DID, replacing any existing entry."""
        self.repos[repo.did] = repo

    def load_repo(self, did_or_handle):
        """Returns the first repo whose DID or handle matches, or None."""
        assert did_or_handle
        for repo in self.repos.values():
            if did_or_handle in (repo.did, repo.handle):
                return repo
        return None

    def store_repo(self, repo):
        """Stores ``repo`` under its DID, replacing any existing entry."""
        self.repos[repo.did] = repo

    def load_repos(self, after=None, limit=500):
        """Returns up to ``limit`` repos in ascending DID order.

        Args:
          after (str): optional DID to start at, *exclusive*
          limit (int): maximum number of repos to return

        Returns:
          list of Repo:
        """
        it = iter(sorted(self.repos.values(), key=lambda repo: repo.did))
        if after:
            # skip everything up to and including ``after``
            it = itertools.dropwhile(lambda repo: repo.did <= after, it)
        return list(itertools.islice(it, limit))

    def _set_repo_status(self, repo, status):
        """Sets the status directly on the in-memory repo object."""
        repo.status = status

    def read(self, cid):
        """Returns the stored :class:`Block` for ``cid``, or None."""
        return self.blocks.get(cid)

    def read_many(self, cids, require_all=True):
        """Batch read multiple blocks.

        Args:
          cids (sequence of CID): may contain duplicates
          require_all (bool): whether to assert that all cids are found

        Returns:
          dict: {:class:`CID`: :class:`Block` or None if not found}

        Raises:
          AssertionError: if ``require_all`` is set and any CID isn't stored
        """
        found = {cid: self.blocks.get(cid) for cid in cids}
        if require_all:
            # the dict has an entry for every requested CID, found or not, so
            # check the values rather than comparing lengths
            missing = [cid for cid, block in found.items() if block is None]
            assert not missing, f'{len(missing)} blocks not found: {missing}'
        return found

    def read_blocks_by_seq(self, start=0, repo=None):
        """Returns stored blocks with ``seq >= start``, in ascending seq order.

        Blocks without a sequence number (``seq`` is None, eg stored directly
        via :meth:`write_blocks`) are skipped, since they can't be ordered.

        Args:
          start (int): sequence number to start from, inclusive
          repo (str): optional repo DID to filter by

        Returns:
          list of Block:
        """
        assert start >= 0
        return sorted((b for b in self.blocks.values()
                       if b.seq is not None and b.seq >= start
                       and (not repo or b.repo == repo)),
                      key=lambda b: b.seq)

    def has(self, cid):
        """Returns True if a block with this CID is stored."""
        return cid in self.blocks

    def write(self, repo_did, obj, seq=None):
        """Stores ``obj`` as a new block, allocating a seq if none is given.

        If a block with the same CID is already stored, the stored block is
        left untouched, but the newly constructed block is still returned.
        """
        if seq is None:
            seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)
        block = Block(decoded=obj, seq=seq, repo=repo_did)
        self.blocks.setdefault(block.cid, block)
        return block

    def write_blocks(self, blocks):
        """Stores each block unless one with the same CID already exists."""
        for block in blocks:
            self.blocks.setdefault(block.cid, block)