Source code for arroba.storage

"""Bluesky repo storage base class and in-memory implementation.

Lightly based on:
https://github.com/bluesky-social/atproto/blob/main/packages/repo/src/storage/repo-storage.ts
"""
from collections import namedtuple
import copy
from enum import auto, Enum
import itertools
import logging

import dag_cbor
from multiformats import CID, multicodec, multihash

from . import mst as mst_mod
from .repo import Write
from .server import server
from . import util
from .util import dag_cbor_cid, DEACTIVATED, tid_to_int, TOMBSTONED, InactiveRepo

# NSID of the repo event stream (firehose). Sequence numbers allocated for this
# NSID are used for commits and other events, and commits reuse them as revs.
SUBSCRIBE_REPOS_NSID = 'com.atproto.sync.subscribeRepos'

# Sync 1.1 commit size limits
# https://github.com/bluesky-social/proposals/tree/main/0006-sync-iteration#commit-size-limits
#
# Storage._commit checks MAX_RECORD_SIZE_BYTES (per record) and
# MAX_OPERATIONS_PER_COMMIT. NOTE(review): MAX_COMMIT_BLOCKS_BYTES and
# MAX_EVENT_SIZE_BYTES aren't referenced in this module's visible code; confirm
# whether they're enforced elsewhere.
MAX_RECORD_SIZE_BYTES = 1_000_000  # 1 MB
MAX_COMMIT_BLOCKS_BYTES = 2_000_000  # 2 MB
MAX_EVENT_SIZE_BYTES = 5_000_000  # 5 MB
MAX_OPERATIONS_PER_COMMIT = 200

# module-level logger
logger = logging.getLogger(__name__)


class Action(Enum):
    """The kind of a single write within a commit.

    Used in :meth:`Storage.commit`.

    TODO: switch to StrEnum once we can require Python 3.11.
    """
    # explicit values, identical to what auto() assigns for a plain Enum
    CREATE = 1
    UPDATE = 2
    DELETE = 3
# TODO: Should this be a subclass of Block?
# TODO: generalize to handle other events
#
# Result of a commit. Fields:
#   commit: Block
#   blocks: dict of CID to Block
#   prev:   CID or None (defaults to None)
CommitData = namedtuple('CommitData', 'commit blocks prev', defaults=(None,))

# A single operation in a commit, for subscribeRepos. Fields:
#   action:   Action
#   path:     str
#   cid:      CID, or None for DELETE (defaults to None)
#   prev_cid: previous CID for UPDATE and DELETE operations, None for CREATE
#             (defaults to None)
CommitOp = namedtuple('CommitOp', 'action path cid prev_cid',
                      defaults=(None, None))

# commit record format is:
# https://atproto.com/specs/repository#commit-objects
#
# {
#   'version': 3,
#   'did': [repo],
#   'rev': [str, TID],
#   'data': [CID],
#   'prev': [CID or None],
#   'sig': [bytes],
# }
class Block:
    r"""An ATProto block: a record, :class:`MST` entry, commit, or other event.

    Can start from either encoded bytes or decoded object, with or without
    :class:`CID`. Decodes, encodes, and generates :class:`CID` lazily, on
    demand, on attribute access.

    Events should have a fully-qualified ``$type`` field that's one of the
    ``message`` types in ``com.atproto.sync.subscribeRepos``, eg
    ``com.atproto.sync.subscribeRepos#tombstone``.

    Based on :class:`carbox.car.Block`.

    Attributes:
      cid (CID): lazy-loaded (dynamic property)
      decoded (dict): decoded object (dynamic property)
      encoded (bytes): DAG-CBOR encoded data (dynamic property)
      seq (int): ``com.atproto.sync.subscribeRepos`` sequence number
      ops (list): :class:`CommitOp`\s if this is a commit, otherwise None
      time (datetime): when this block was first created
      repo (str): DID of a repo that includes this block. Occasionally, blocks
        may be included in more than one repo, so this may be *any* repo that
        includes it. In practice, it's often the first or last repo that
        included it.
    """

    def __init__(self, *, cid=None, decoded=None, encoded=None, seq=None,
                 ops=None, time=None, repo=None):
        """Constructor.

        At least one of ``encoded`` or ``decoded`` is required.

        Args:
          cid (CID): optional; computed from ``encoded`` on demand if omitted
          decoded (dict): optional
          encoded (bytes): optional
          seq (int): optional ``subscribeRepos`` sequence number
          ops (list of CommitOp): optional, for commit blocks
          time (datetime): optional; defaults to now
          repo (str): optional repo DID
        """
        assert encoded or decoded
        self._cid = cid
        self._encoded = encoded
        self._decoded = decoded
        self.seq = seq
        self.ops = ops
        self.time = time if time is not None else util.now()
        self.repo = repo

    def __str__(self):
        return f'<Block: {self.cid}>'

    @property
    def cid(self):
        """CID: computed on first access from the SHA-256 of :attr:`encoded`."""
        if self._cid is None:
            digest = multihash.digest(self.encoded, 'sha2-256')
            self._cid = CID('base58btc', 1, 'dag-cbor', digest)
        return self._cid

    @property
    def encoded(self):
        """bytes: DAG-CBOR encoding of :attr:`decoded`, computed on demand."""
        if self._encoded is None:
            self._encoded = dag_cbor.encode(self.decoded)
        return self._encoded

    @property
    def decoded(self):
        """dict: DAG-CBOR decoding of :attr:`encoded`, computed on demand."""
        if self._decoded is None:
            self._decoded = dag_cbor.decode(self.encoded)
        return self._decoded

    def __eq__(self, other):
        """Compares by CID only.

        Returns ``NotImplemented`` for non-:class:`Block` objects, so comparing
        against None or other types is False instead of raising
        AttributeError.
        """
        if not isinstance(other, Block):
            return NotImplemented
        return self.cid == other.cid

    def __hash__(self):
        # consistent with __eq__: equal blocks have equal CIDs
        return hash(self.cid)
class Sequences:
    """Abstract interface for allocating event stream sequence numbers.

    For example, the sequence numbers for the
    ``com.atproto.sync.subscribeRepos`` firehose.

    Background: https://atproto.com/specs/event-stream#sequence-numbers
    """

    def __init__(self, **kwargs):
        # cooperative init, so this can be combined with other base classes
        super(Sequences, self).__init__(**kwargs)

    def allocate(self, nsid):
        """Allocates a new sequence number for an NSID and returns it.

        Args:
          nsid (str): the subscription XRPC method the number is for

        Returns:
          int:
        """
        raise NotImplementedError

    def last(self, nsid):
        """Returns the highest sequence number allocated so far for an NSID.

        Returns None if no sequence number has been allocated for it yet.

        Args:
          nsid (str): the subscription XRPC method the number is for

        Returns:
          int or None:
        """
        raise NotImplementedError
class Storage:
    """Abstract base class for storing nodes: records, MST entries, commits, etc.

    Concrete subclasses should implement this on top of physical storage,
    eg database, filesystem, in memory.

    Attributes:
      sequences (Sequences): allocates sequence numbers
    """

    def __init__(self, *, sequences, **kwargs):
        """Constructor.

        Args:
          sequences (Sequences)
        """
        super().__init__(**kwargs)
        assert sequences
        self.sequences = sequences

    def create_repo(self, repo):
        """Stores a new repo's metadata in storage.

        Only stores the repo's handle and head commit :class:`CID`, not blocks!

        If the repo already exists in storage, this should update it instead of
        failing.

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def load_repo(self, did_or_handle):
        """Loads a repo from storage.

        Args:
          did_or_handle (str): the repo's DID or handle

        Returns:
          Repo, or None if the did or handle wasn't found:
        """
        raise NotImplementedError()

    def store_repo(self, repo):
        """Writes a repo to storage.

        Right now only writes some metadata:

        * handle
        * status
        * head

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def load_repos(self, after=None, limit=500):
        """Loads multiple repos from storage.

        Repos are returned in lexicographic order of their DIDs, ascending.
        Tombstoned repos are included.

        Args:
          after (str): optional DID to start at, *exclusive*
          limit (int): maximum number of repos to return

        Returns:
          sequence of Repo:
        """
        raise NotImplementedError()

    def deactivate_repo(self, repo):
        """Marks a repo as deactivated.

        * Calls :meth:`_set_repo_status` to mark the repo as deactivated in
          storage.
        * Stores a ``com.atproto.sync.subscribeRepos#account`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#account`` message.

        After this, any attempt to commit to this repo will raise
        :class:`InactiveRepo`.

        Args:
          repo (Repo)
        """
        self._set_repo_status(repo, DEACTIVATED)
        self.write_event(repo=repo, type='account', active=False,
                         status='deactivated')

    def activate_repo(self, repo):
        """Marks a repo as active. Only needed after deactivating.

        Does nothing if the repo is tombstoned.

        * Calls :meth:`_set_repo_status` to mark the repo as active in storage.
        * Stores a ``com.atproto.sync.subscribeRepos#account`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#account`` message.

        Args:
          repo (Repo)
        """
        if repo.status == TOMBSTONED:
            # tombstones are permanent; don't reactivate or emit an event
            return

        self._set_repo_status(repo, None)
        self.write_event(repo=repo, type='account', active=True)

    def tombstone_repo(self, repo):
        """Marks a repo as tombstoned.

        * Calls :meth:`_set_repo_status` to mark the repo as tombstoned in
          storage.
        * Stores a ``com.atproto.sync.subscribeRepos#tombstone`` block with its
          own sequence number.
        * If :attr:`Repo.callback` is populated, calls it with the
          ``com.atproto.sync.subscribeRepos#tombstone`` message.

        After this, any attempt to commit to this repo will raise
        :class:`InactiveRepo`.

        Args:
          repo (Repo)
        """
        self._set_repo_status(repo, TOMBSTONED)
        self.write_event(repo=repo, type='tombstone')

    def _set_repo_status(self, repo, status):
        """Sets a repo's status in storage.

        Used by :meth:`activate_repo`, :meth:`deactivate_repo`, and
        :meth:`tombstone_repo`.

        Args:
          repo (Repo)
          status (str or None): ``DEACTIVATED``, ``TOMBSTONED``, or None for
            active
        """
        raise NotImplementedError()

    def _tombstone_repo(self, repo):
        """Marks a repo as tombstoned in storage.

        Not called by this class; :meth:`tombstone_repo` uses
        :meth:`_set_repo_status` instead. Kept for compatibility.

        Args:
          repo (Repo)
        """
        raise NotImplementedError()

    def read(self, cid):
        """Reads a node from storage.

        Args:
          cid (CID)

        Returns:
          Block, or None if not found:
        """
        raise NotImplementedError()

    def read_many(self, cids, require_all=True):
        """Batch read multiple nodes from storage.

        Args:
          cids (sequence of CID)
          require_all (bool): whether to assert that all cids are found

        Returns:
          dict: {:class:`CID`: :class:`Block` or None if not found}
        """
        raise NotImplementedError()

    def read_blocks_by_seq(self, start=0, repo=None):
        """Batch read blocks from storage by ``subscribeRepos`` sequence number.

        Args:
          start (int): optional ``subscribeRepos`` sequence number to start
            from, inclusive. Defaults to 0.
          repo (str): optional repo DID. If not provided, all repos are
            included.

        Returns:
          iterable or generator: all :class:`Block` s starting from ``start``,
          inclusive, in ascending ``seq`` order
        """
        raise NotImplementedError()

    def read_events_by_seq(self, start=0, repo=None):
        """Batch read commits and other events by ``subscribeRepos`` sequence number.

        Args:
          start (int): optional ``subscribeRepos`` sequence number to start
            from, inclusive. Defaults to 0.
          repo (str): optional repo DID. If not provided, all repos are
            included.

        Returns:
          generator: generator of :class:`CommitData` for commits and dict
          messages for other events, starting from ``start``, inclusive, in
          ascending ``seq`` order
        """
        assert start >= 0

        # a block whose keys are exactly these is the commit for its seq
        commit_fields = {'version', 'did', 'rev', 'prev', 'data', 'sig'}
        seq = commit_block = blocks = None

        def make_commit():
            # writes only add blocks that aren't already stored, so a commit's
            # records and MST root may live under an earlier seq. load them.
            for op in commit_block.ops:
                if (op.action in (Action.CREATE, Action.UPDATE)
                        and op.cid not in blocks):
                    record = self.read(op.cid)
                    assert record
                    blocks[op.cid] = record

            mst_root = commit_block.decoded['data']
            if mst_root not in blocks:
                blocks[mst_root] = self.read(mst_root)

            return CommitData(blocks=blocks, commit=commit_block,
                              prev=commit_block.decoded.get('prev'))

        for block in self.read_blocks_by_seq(start=start, repo=repo):
            assert block.seq
            if block.seq != seq:  # switching to a new commit's blocks
                if commit_block:
                    yield make_commit()
                else:
                    # we shouldn't have any dangling blocks that we don't serve
                    assert not blocks
                seq = block.seq
                blocks = {}  # maps CID to Block
                commit_block = None

            if block.decoded.get('$type', '').startswith(
                    'com.atproto.sync.subscribeRepos#'):
                # non-commit message
                yield block.decoded
                continue

            blocks[block.cid] = block
            if block.decoded.keys() == commit_fields:
                commit_block = block

        # final commit
        if blocks:
            assert commit_block, f'seq {seq}'
            yield make_commit()

    def has(self, cid):
        """Checks if a given :class:`CID` is currently stored.

        Args:
          cid (CID)

        Returns:
          bool:
        """
        raise NotImplementedError()

    def write(self, repo_did, obj, seq=None):
        """Writes a node to storage.

        Args:
          repo_did (str):
          obj (dict): a record, commit, serialized :class:`MST` node, or
            `subscribeRepos` event/message
          seq (int or None): sequence number. If not provided, a new one will
            be allocated.

        Returns:
          Block:
        """
        raise NotImplementedError()

    def write_event(self, repo, type, **kwargs):
        """Writes a ``subscribeRepos`` event to storage.

        Allocates a new sequence number for the event. If
        :attr:`Repo.callback` is populated, calls it with the event message.

        Args:
          repo (Repo)
          type (str): ``account``, ``identity``, ``sync``, or ``tombstone``
          kwargs: included in the event, eg ``active``, ``status``

        Returns:
          Block:
        """
        assert type in ('account', 'identity', 'sync', 'tombstone'), type

        seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)
        block = self.write(repo.did, {
            '$type': f'com.atproto.sync.subscribeRepos#{type}',
            'seq': seq,
            'did': repo.did,
            'time': util.now().isoformat(),
            **kwargs,
        }, seq=seq)

        if repo.callback:
            repo.callback(data=block.decoded)

        return block

    def write_blocks(self, blocks):
        """Batch write blocks to storage. Does not allocate sequence numbers!

        Args:
          blocks (sequence of :class:`Block`)
        """
        raise NotImplementedError()

    def commit(self, repo, writes, repo_did=None):
        """Commits zero or more writes to storage.

        Allocates a new sequence number and uses it for all blocks in the
        commit. If :attr:`Repo.callback` is populated, calls it with
        ``data=`` the :class:`CommitData` on success, or with ``lost_seq=`` the
        allocated sequence number if the commit fails.

        Args:
          repo (Repo)
          writes (Write or iterable of Write)
          repo_did (str): optional, used if this is the repo's first commit

        Returns:
          CommitData:

        Raises:
          InactiveRepo: if the repo is not active
          ValueError: if the commit is invalid, eg the path for an update or
            delete doesn't currently exist
        """
        seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)

        try:
            commit_data = self._commit(repo, writes, seq, repo_did=repo_did)
        except BaseException:
            if repo.callback:
                repo.callback(lost_seq=seq)
            raise

        # outside the try: if the callback itself fails, the commit was still
        # stored, so its seq isn't lost
        if repo.callback:
            repo.callback(data=commit_data)

        return commit_data

    def _commit(self, repo, writes, seq, repo_did=None):
        """Does the work of :meth:`commit`.

        Separate from :meth:`commit` so that subclasses can put it in a tx.

        Args:
          repo (Repo)
          writes (Write or iterable of Write)
          seq (int): sequence number already allocated for this commit
          repo_did (str): optional, used if this is the repo's first commit

        Returns:
          CommitData:

        Raises:
          InactiveRepo: if the repo has a status, eg deactivated or tombstoned
          ValueError: if the commit is invalid
        """
        assert seq
        if repo.status:
            raise InactiveRepo(repo.did, repo.status)

        orig_repo = repo
        if repo_did:
            # this is the initial empty commit for creating a new repo
            if repo.head:
                # this must be a transaction retry, and head was set by a
                # previous attempt, below
                assert repo.head.decoded.get('did') == repo_did, repo.head.decoded
                assert not repo.head.decoded.get('prev'), repo.head.decoded
                repo.head = None
            assert not repo.did
            prev = None
        else:
            # this is an existing repo with at least one commit. reload it from
            # storage and build on its stored head.
            assert repo.did
            repo_did = repo.did
            repo = self.load_repo(repo_did)
            prev = repo.head.cid

        # for the new commit
        commit_blocks = {}  # maps CID to Block

        assert writes is not None
        # normalize to a list so that any iterable works with len() below
        writes = [writes] if isinstance(writes, Write) else list(writes)
        if len(writes) > MAX_OPERATIONS_PER_COMMIT:
            raise ValueError(f'Too many operations ({len(writes)}), max is {MAX_OPERATIONS_PER_COMMIT}')

        ops = []
        for write in writes:
            assert isinstance(write, Write), type(write)
            path = f'{write.collection}/{write.rkey}'

            # sync v1.1: for UPDATE and DELETE, load the previous record's CID
            # https://github.com/bluesky-social/proposals/tree/main/0006-sync-iteration#commit-validation-mst-operation-inversion
            prev_cid = None
            if write.action in (Action.UPDATE, Action.DELETE):
                if not (prev_cid := repo.mst.get(path)):
                    raise ValueError(f"{path} doesn't exist in repo")

            if write.action == Action.DELETE:
                logger.debug('deleting from MST')
                repo.mst = repo.mst.delete(path)
                logger.debug(' done')
                ops.append(CommitOp(action=Action.DELETE, path=path, cid=None,
                                    prev_cid=prev_cid))
                continue

            # raises ValidationError if it doesn't validate
            assert write.record is not None
            server.validate(write.record.get('$type'), 'record', write.record)

            block = Block(decoded=write.record, repo=repo_did, seq=seq)
            if len(block.encoded) > MAX_RECORD_SIZE_BYTES:
                raise ValueError(f'Record {path} size {len(block.encoded)} bytes exceeds max {MAX_RECORD_SIZE_BYTES}')
            commit_blocks[block.cid] = block

            op = CommitOp(action=write.action, path=path, cid=block.cid,
                          prev_cid=prev_cid)
            if write.action == Action.CREATE:
                logger.debug('adding to MST')
                repo.mst = repo.mst.add(path, block.cid)
                logger.debug(' done')
                ops.append(op)
            else:
                assert write.action == Action.UPDATE
                orig_pointer = repo.mst.get_pointer()
                logger.debug('updating MST')
                repo.mst = repo.mst.update(path, block.cid)
                logger.debug(' done')
                if repo.mst.get_pointer() != orig_pointer:
                    # no-op updates are invalid in ATProto, so only include this
                    # update operation if it changes the the record and MST.
                    # https://github.com/snarfed/arroba/issues/52#issuecomment-2825755142
                    ops.append(op)

        logger.debug('loading unstored MST blocks')
        root, unstored_blocks = repo.mst.get_unstored_blocks()
        logger.debug(' done')
        for block in unstored_blocks.values():
            block.repo = repo_did
            block.seq = seq
        commit_blocks.update(unstored_blocks)

        # construct commit
        commit = util.sign({
            'did': repo_did,
            'version': 3,
            # reuse subscribeRepos sequence number as rev
            # https://github.com/bluesky-social/atproto/discussions/1607
            'rev': util.int_to_tid(seq, clock_id=0),
            'prev': prev,
            'data': root,
        }, repo.signing_key)
        commit_block = Block(decoded=commit, repo=repo_did, seq=seq, ops=ops)
        commit_blocks[commit_block.cid] = commit_block

        commit_data = CommitData(commit=commit_block, prev=prev,
                                 blocks=commit_blocks)

        # only add new blocks so we don't wipe out any existing blocks' sequence
        # numbers. (occasionally we see existing blocks recur, eg MST nodes.)
        logger.debug(f'writing {len(commit_data.blocks)} new blocks')
        self.write_blocks(commit_data.blocks.values())
        logger.debug(' done')

        # update repo head
        if repo.did:
            repo.head = commit_data.commit
            logger.info(f'Updating {repo.did} head {repo.head.cid}')
            self.store_repo(repo)
            logger.debug(' done')

        orig_repo.mst = repo.mst
        orig_repo.head = commit_block
        return commit_data
class MemorySequences(Sequences):
    """In memory sequence numbers.

    Attributes:
      sequences (dict): {str NSID: int next sequence number}
    """

    def __init__(self):
        # keep the cooperative init chain that Sequences.__init__ supports
        super().__init__()
        self.sequences = {}

    def allocate(self, nsid):
        """Returns the next sequence number for ``nsid``, starting at 1."""
        assert nsid
        seq = self.sequences.setdefault(nsid, 1)
        logger.info(f'Allocated seq {seq}')
        self.sequences[nsid] = seq + 1
        return seq

    def last(self, nsid):
        """Returns the last allocated sequence number for ``nsid``, or None."""
        assert nsid
        next_seq = self.sequences.get(nsid)
        return next_seq - 1 if next_seq else None
class MemoryStorage(Storage):
    """In memory storage implementation.

    Attributes:
      repos (dict mapping str DID to :class:`Repo`)
      blocks (dict): {:class:`CID`: :class:`Block`}
    """
    repos = None
    blocks = None

    def __init__(self, *, sequences=None):
        """
        Args:
          sequences (Sequences): optional; defaults to a :class:`MemorySequences`
        """
        super().__init__(sequences=sequences or MemorySequences())
        self.blocks = {}
        self.repos = {}

    def create_repo(self, repo):
        self.repos[repo.did] = repo

    def load_repo(self, did_or_handle):
        # linear scan, since repos are only keyed by DID
        assert did_or_handle
        for repo in self.repos.values():
            if did_or_handle in (repo.did, repo.handle):
                return repo
        return None

    def store_repo(self, repo):
        self.repos[repo.did] = repo

    def load_repos(self, after=None, limit=500):
        it = iter(sorted(self.repos.values(), key=lambda repo: repo.did))
        if after:
            it = itertools.dropwhile(lambda repo: repo.did <= after, it)
        return list(itertools.islice(it, limit))

    def _set_repo_status(self, repo, status):
        # repos are stored as the Repo objects themselves, so updating the
        # object updates storage
        repo.status = status

    def read(self, cid):
        return self.blocks.get(cid)

    def read_many(self, cids, require_all=True):
        found = {cid: self.blocks.get(cid) for cid in cids}
        if require_all:
            # missing CIDs map to None, so check values, not the dict's length
            missing = [cid for cid, block in found.items() if block is None]
            assert not missing, f'missing blocks: {missing}'
        return found

    def read_blocks_by_seq(self, start=0, repo=None):
        assert start >= 0
        # blocks without a sequence number can't be ordered, so skip them
        return sorted((b for b in self.blocks.values()
                       if b.seq is not None and b.seq >= start
                       and (not repo or b.repo == repo)),
                      key=lambda b: b.seq)

    def has(self, cid):
        return cid in self.blocks

    def write(self, repo_did, obj, seq=None):
        if seq is None:
            seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)
        block = Block(decoded=obj, seq=seq, repo=repo_did)
        self.blocks.setdefault(block.cid, block)
        return block

    def write_blocks(self, blocks):
        # setdefault: don't overwrite existing blocks (and their seqs)
        for block in blocks:
            self.blocks.setdefault(block.cid, block)