Source code for arroba.datastore_storage

"""Google Cloud Datastore implementation of repo storage."""
from datetime import timezone
from functools import wraps
import json
import logging
import mimetypes
import requests

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
import dag_cbor
import dag_json
from google.cloud import ndb
from google.cloud.ndb.context import get_context
from google.cloud.ndb.exceptions import ContextError
from multiformats import CID, multicodec, multihash

from .mst import MST
from .repo import Repo
from . import storage
from .storage import Action, Block, Storage, SUBSCRIBE_REPOS_NSID
from .util import dag_cbor_cid, tid_to_int, TOMBSTONED, TombstonedRepo

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class WriteOnce:
    """Mix-in for :class:`ndb.Property` subclasses that freezes a value once set.

    A value may be assigned while the property is still unset (``None``), and
    re-assigning an equal value is allowed. Assigning a different value raises
    :class:`ndb.ReadonlyPropertyError`.
    """

    def _set_value(self, entity, value):
        """Sets the value on ``entity`` unless that would change an existing one.

        Args:
          entity: the entity this property belongs to
          value: the new value

        Raises:
          ndb.ReadonlyPropertyError: if a different value is already set
        """
        current = self._get_value(entity)
        is_overwrite = current is not None and value != current
        if is_overwrite:
            raise ndb.ReadonlyPropertyError(f"{self._name} can't be changed")

        # delegate to the next class in the MRO, ie the real ndb property
        return super()._set_value(entity, value)
class JsonProperty(ndb.TextProperty):
    """JSON property backed by :class:`ndb.TextProperty`.

    Fork of ndb's own ``JsonProperty``, which subclasses
    :class:`ndb.BlobProperty`. Subclassing :class:`ndb.TextProperty` instead
    makes values show up as normal, human-readable, serialized JSON in the web
    console.

    https://github.com/googleapis/python-ndb/issues/874#issuecomment-1442753255

    Duplicated in oauth-dropins/webutil:
    https://github.com/snarfed/webutil/blob/main/models.py
    """

    def _validate(self, value):
        """Rejects anything that isn't a dict.

        Raises:
          TypeError: if ``value`` is not a :class:`dict`
        """
        if isinstance(value, dict):
            return

        raise TypeError('JSON property must be a dict')

    def _to_base_type(self, value):
        """Serializes ``value`` to compact, ASCII-only JSON bytes."""
        return json.dumps(
            value, separators=(',', ':'), ensure_ascii=True,
        ).encode('ascii')

    def _from_base_type(self, value):
        """Parses a stored JSON value, given as either str or ASCII bytes."""
        text = value if isinstance(value, str) else value.decode('ascii')
        return json.loads(text)
class ComputedJsonProperty(JsonProperty, ndb.ComputedProperty):
    """Custom :class:`ndb.ComputedProperty` for JSON values, stored as strings.

    ...instead of like :class:`ndb.StructuredProperty`, with "entity" type,
    which bloats them unnecessarily in the datastore.
    """

    def __init__(self, *args, **kwargs):
        """Constructor. Always forces ``indexed=False``, overriding any caller value."""
        unindexed = {**kwargs, 'indexed': False}
        super().__init__(*args, **unindexed)
class WriteOnceBlobProperty(WriteOnce, ndb.BlobProperty):
    """:class:`ndb.BlobProperty` combined with the :class:`WriteOnce` mix-in."""
class CommitOp(ndb.Model):
    """A single repo operation - create, update, or delete - included in a commit.

    Used in a :class:`StructuredProperty` inside :class:`AtpBlock`; not stored
    directly in the datastore.

    https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty

    Attributes:
      * action (str): one of ``create``, ``update``, ``delete``; required
      * path (str): required
      * cid (str): optional, unset for deletes
    """
    # which kind of operation this is
    action = ndb.StringProperty(required=True, choices=('create', 'update', 'delete'))
    path = ndb.StringProperty(required=True)
    cid = ndb.StringProperty()  # unset for deletes
class AtpRepo(ndb.Model):
    r"""An ATProto repo.

    Key name is DID. Only stores the repo's metadata. Blocks are stored in
    :class:`AtpBlock`\s.

    Attributes:
      * handles (str): repeated, optional
      * head (str): CID
      * signing_key_pem (bytes): PEM-encoded private key, required
      * rotation_key_pem (bytes): PEM-encoded private key, optional
      * status (str): unset, or :const:`TOMBSTONED`
      * created (datetime): set automatically on first write
      * updated (datetime): set automatically on every write
    """
    handles = ndb.StringProperty(repeated=True)
    head = ndb.StringProperty(required=True)
    # TODO: add password hash?

    # these are both secp256k1 private keys, PEM-encoded bytes
    # https://atproto.com/specs/cryptography
    signing_key_pem = ndb.BlobProperty(required=True)
    # TODO: rename this recovery_key_pem?
    # https://discord.com/channels/1097580399187738645/1098725036917002302/1153447354003894372
    rotation_key_pem = ndb.BlobProperty()

    status = ndb.StringProperty(choices=(TOMBSTONED,))

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @property
    def signing_key(self):
        """(ec.EllipticCurvePrivateKey) Signing key, parsed from :attr:`signing_key_pem`."""
        pem = self.signing_key_pem
        return serialization.load_pem_private_key(pem, password=None)

    @property
    def rotation_key(self):
        """(ec.EllipticCurvePrivateKey or None) Rotation key, parsed from
        :attr:`rotation_key_pem`, or None if that's unset."""
        pem = self.rotation_key_pem
        if not pem:
            return None

        return serialization.load_pem_private_key(pem, password=None)
class AtpBlock(ndb.Model):
    """A data record, MST node, repo commit, or other event.

    Key name is the DAG-CBOR base32 CID of the data.

    Events should have a fully-qualified ``$type`` field that's one of the
    ``message`` types in ``com.atproto.sync.subscribeRepos``, eg
    ``com.atproto.sync.subscribeRepos#tombstone``.

    Properties:
      * repo (ndb.Key): the :class:`AtpRepo` this block belongs to
      * encoded (bytes): DAG-CBOR encoded value, write-once
      * decoded (dict): DAG-JSON value computed from ``encoded``, only used
        for human debugging
      * seq (int): sequence number for the subscribeRepos event stream
      * ops (CommitOp): repeated, the operations in this commit, if any
      * created (datetime): set automatically on first write
    """
    repo = ndb.KeyProperty(AtpRepo, required=True)
    encoded = WriteOnceBlobProperty(required=True)
    seq = ndb.IntegerProperty(required=True)
    ops = ndb.StructuredProperty(CommitOp, repeated=True)

    created = ndb.DateTimeProperty(auto_now_add=True)

    @ComputedJsonProperty
    def decoded(self):
        """DAG-CBOR ``encoded`` bytes, re-expressed as a plain JSON dict."""
        native = dag_cbor.decode(self.encoded)
        as_dag_json = dag_json.encode(native)
        return json.loads(as_dag_json)

    @property
    def cid(self):
        """(CID) This block's CID, decoded from its key name."""
        key_name = self.key.id()
        return CID.decode(key_name)

    @staticmethod
    def create(*, repo_did, data, seq):
        """Writes a new AtpBlock to the datastore.

        If the block already exists in the datastore, leave it untouched.
        Notably, leave its sequence number as is, since it will be lower than
        this current sequence number.

        Args:
          repo_did (str):
          data (dict): value
          seq (int): must be positive

        Returns:
          :class:`AtpBlock`
        """
        assert seq > 0

        payload = dag_cbor.encode(data)
        cid = CID('base58btc', 1, 'dag-cbor',
                  multihash.digest(payload, 'sha2-256'))

        stored = AtpBlock.get_or_insert(
            cid.encode('base32'),
            repo=ndb.Key(AtpRepo, repo_did),
            encoded=payload,
            seq=seq,
        )
        # an existing block keeps its earlier, lower sequence number
        assert stored.seq <= seq
        return stored

    def to_block(self):
        """Converts to :class:`Block`.

        Returns:
          Block
        """
        ops = []
        for op in self.ops:
            ops.append(storage.CommitOp(
                action=Action[op.action.upper()],
                path=op.path,
                cid=CID.decode(op.cid) if op.cid else None,
            ))

        return Block(cid=self.cid, encoded=self.encoded, seq=self.seq,
                     ops=ops, time=self.created)

    @classmethod
    def from_block(cls, *, repo_did, block):
        """Converts a :class:`Block` to an :class:`AtpBlock`.

        Doesn't write anything to the datastore.

        Args:
          repo_did (str)
          block (Block)

        Returns:
          AtpBlock
        """
        ops = []
        for op in block.ops or []:
            op_cid = op.cid.encode('base32') if op.cid else None
            ops.append(CommitOp(action=op.action.name.lower(), path=op.path,
                                cid=op_cid))

        # convert to UTC, then drop the tzinfo to store a naive datetime
        naive_utc = block.time.astimezone(timezone.utc).replace(tzinfo=None)

        return AtpBlock(
            id=block.cid.encode('base32'),
            encoded=block.encoded,
            repo=ndb.Key(AtpRepo, repo_did),
            seq=block.seq,
            ops=ops,
            created=naive_utc,
        )
class AtpSequence(ndb.Model):
    """A sequence number for a given event stream NSID.

    Sequence numbers are monotonically increasing, without gaps (which ATProto
    doesn't require), starting at 1. Background:
    https://atproto.com/specs/event-stream#sequence-numbers

    Key name is XRPC method NSID.

    At first, I considered using datastore allocated ids for sequence numbers,
    but they're not guaranteed to be monotonically increasing, so I switched to
    this.

    Attributes:
      * next (int): the next sequence number to hand out
      * created (datetime): set automatically on first write
      * updated (datetime): set automatically on every write
    """
    next = ndb.IntegerProperty(required=True)

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def allocate(cls, nsid):
        """Returns the next sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID. Runs in a transaction.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, next sequence number for this NSID
        """
        counter = AtpSequence.get_or_insert(nsid, next=1)
        allocated = counter.next
        counter.next = allocated + 1
        counter.put()
        return allocated

    @classmethod
    def last(cls, nsid):
        """Returns the last sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, last sequence number for this NSID
        """
        return AtpSequence.get_or_insert(nsid, next=1).next - 1
class AtpRemoteBlob(ndb.Model):
    """A blob available at a public HTTP URL that we don't store ourselves.

    Key ID is the URL.

    * https://atproto.com/specs/data-model#blob-type
    * https://atproto.com/specs/xrpc#blob-upload-and-download

    TODO:
    * follow redirects, use final URL as key id
    * abstract this in :class:`Storage`

    Attributes:
      * cid (str): base32 CID of the blob's contents
      * size (int): length of the blob's contents, in bytes
      * mime_type (str): defaults to ``application/octet-stream``
      * created (datetime): set automatically on first write
      * updated (datetime): set automatically on every write
    """
    cid = ndb.StringProperty(required=True)
    size = ndb.IntegerProperty(required=True)
    mime_type = ndb.StringProperty(required=True,
                                   default='application/octet-stream')

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def get_or_create(cls, *, url=None, get_fn=requests.get):
        """Returns a new or existing :class:`AtpRemoteBlob` for a given URL.

        If there isn't an existing :class:`AtpRemoteBlob`, fetches the URL over
        the network and creates a new one for it.

        Args:
          url (str)
          get_fn (callable): for making HTTP GET requests

        Returns:
          AtpRemoteBlob: existing or newly created :class:`AtpRemoteBlob`

        Raises:
          requests.RequestException: if the HTTP request to fetch the blob failed
        """
        assert url

        found = cls.get_by_id(url)
        if found:
            return found

        resp = get_fn(url)
        resp.raise_for_status()

        # prefer the response's Content-Type; fall back to guessing from the URL
        mime_type = (resp.headers.get('Content-Type')
                     or mimetypes.guess_type(url)[0])

        content = resp.content
        cid = CID('base58btc', 1, 'raw',
                  multihash.digest(content, 'sha2-256')).encode('base32')

        logger.info(f'Creating new AtpRemoteBlob for {url} CID {cid}')

        props = {'id': url, 'cid': cid, 'size': len(content)}
        if mime_type:
            # otherwise leave it unset so that the property's default applies
            props['mime_type'] = mime_type

        created = cls(**props)
        created.put()
        return created

    def as_object(self):
        """Returns an ATProto ``blob`` object for this blob.

        https://atproto.com/specs/data-model#blob-type

        Returns:
          dict: with ``$type``, ``ref``, ``mimeType``, and ``size`` keys
        """
        ref = CID.decode(self.cid)
        return {
            '$type': 'blob',
            'ref': ref,
            'mimeType': self.mime_type,
            'size': self.size,
        }
class DatastoreStorage(Storage):
    """Google Cloud Datastore implementation of :class:`Storage`.

    Sequence numbers in :class:`AtpBlock` are allocated per commit; all blocks
    in a given commit will have the same sequence number. They're currently
    sequential counters, starting at 1, stored in an :class:`AtpSequence`
    entity.

    See :class:`Storage` for method details.
    """
    ndb_client = None

    def __init__(self, *, ndb_client=None):
        """Constructor.

        Args:
          ndb_client (google.cloud.ndb.Client): used in
            :meth:`read_blocks_by_seq`; it's used in the `subscribeRepos` event
            subscription, so lexrpc calls it on a different thread, so it needs
            its own ndb client context.
        """
        super().__init__()
        self.ndb_client = ndb_client

    def ndb_context(fn):
        """Decorator that runs a method inside an ndb context.

        Reuses the current ndb context if there is one, otherwise opens a new
        one from ``self.ndb_client``.
        """
        @wraps(fn)
        def decorated(self, *args, **kwargs):
            context = get_context(raise_context_error=False)

            with context.use() if context else self.ndb_client.context():
                ret = fn(self, *args, **kwargs)

            return ret

        return decorated

    @ndb_context
    def create_repo(self, repo, *, signing_key, rotation_key=None):
        """Stores a new :class:`AtpRepo` for a repo.

        Args:
          repo (Repo): must have ``did`` and ``head`` set
          signing_key (ec.EllipticCurvePrivateKey)
          rotation_key (ec.EllipticCurvePrivateKey): optional
        """
        assert repo.did
        assert repo.head

        def to_pem(key):
            # unencrypted PKCS8 PEM bytes
            return key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

        handles = [repo.handle] if repo.handle else []
        signing_key_pem = to_pem(signing_key)
        rotation_key_pem = to_pem(rotation_key) if rotation_key else None

        atp_repo = AtpRepo(id=repo.did, handles=handles,
                           head=repo.head.cid.encode('base32'),
                           signing_key_pem=signing_key_pem,
                           rotation_key_pem=rotation_key_pem)
        atp_repo.put()
        logger.info(f'Stored repo {atp_repo}')

    @ndb_context
    def load_repo(self, did_or_handle):
        """Loads a repo by DID or handle. Also sets ``self.head`` to its head CID.

        Args:
          did_or_handle (str)

        Returns:
          Repo: or None if no repo was found

        Raises:
          TombstonedRepo: if the repo is tombstoned
        """
        assert did_or_handle

        # try the key (DID) first, then fall back to a handle query
        atp_repo = (AtpRepo.get_by_id(did_or_handle)
                    or AtpRepo.query(AtpRepo.handles == did_or_handle).get())

        if not atp_repo:
            logger.info(f"Couldn't find repo for {did_or_handle}")
            return None
        elif atp_repo.status == TOMBSTONED:
            raise TombstonedRepo(f'{atp_repo.key} is tombstoned')

        logger.info(f'Loading repo {atp_repo.key}')
        self.head = CID.decode(atp_repo.head)
        handle = atp_repo.handles[0] if atp_repo.handles else None

        return Repo.load(self, cid=self.head, handle=handle,
                         signing_key=atp_repo.signing_key,
                         rotation_key=atp_repo.rotation_key)

    @ndb_context
    def load_repos(self, after=None, limit=500):
        """Loads multiple repos, ordered by key.

        Args:
          after (str): optional; only return repos whose key is greater than
            this id
          limit (int): maximum number of repos to return

        Returns:
          list of Repo
        """
        query = AtpRepo.query()
        if after:
            query = query.filter(AtpRepo.key > AtpRepo(id=after).key)

        # duplicates parts of Repo.load but batches reading blocks from storage
        atp_repos = query.fetch(limit=limit)
        cids = [CID.decode(r.head) for r in atp_repos]
        blocks = self.read_many(cids)  # dict mapping CID to block
        heads = [blocks[cid] for cid in cids]

        # MST.load doesn't read from storage
        msts = [MST.load(storage=self, cid=block.decoded['data'])
                for block in heads]

        return [Repo(storage=self, mst=mst, head=head, status=atp_repo.status,
                     handle=atp_repo.handles[0] if atp_repo.handles else None,
                     signing_key=atp_repo.signing_key,
                     rotation_key=atp_repo.rotation_key)
                for atp_repo, head, mst in zip(atp_repos, heads, msts)]

    @ndb_context
    def _tombstone_repo(self, repo):
        """Marks a repo's :class:`AtpRepo` as tombstoned, in a transaction.

        Args:
          repo (Repo)
        """
        @ndb.transactional()
        def update():
            atp_repo = AtpRepo.get_by_id(repo.did)
            atp_repo.status = TOMBSTONED
            atp_repo.put()

        update()

    @ndb_context
    def read(self, cid):
        """Reads a single block.

        Args:
          cid (CID)

        Returns:
          Block: or None if it isn't stored
        """
        block = AtpBlock.get_by_id(cid.encode('base32'))
        if block:
            return block.to_block()

    @ndb_context
    def read_many(self, cids):
        """Batch reads multiple blocks.

        Args:
          cids (sequence of CID)

        Returns:
          dict: maps each CID to its :class:`Block`, or to None if it isn't
          stored
        """
        keys = [ndb.Key(AtpBlock, cid.encode('base32')) for cid in cids]
        got = list(zip(cids, ndb.get_multi(keys)))
        return {cid: block.to_block() if block else None
                for cid, block in got}

    # can't use @ndb_context because this is a generator, not a normal function
    def read_blocks_by_seq(self, start=0):
        """Yields stored blocks with sequence number >= ``start``, in order.

        Stops quietly, after logging a warning, if the ndb context is lost
        while iterating.

        Args:
          start (int): minimum sequence number, inclusive; must be non-negative

        Yields:
          Block
        """
        assert start >= 0

        context = get_context(raise_context_error=False)

        with context.use() if context else self.ndb_client.context():
            # lexrpc event subscription handlers like subscribeRepos call this
            # on a different thread, so if we're there, we need to create a new
            # ndb context
            try:
                for atp_block in AtpBlock.query(AtpBlock.seq >= start)\
                                         .order(AtpBlock.seq):
                    yield atp_block.to_block()
            except ContextError as e:
                logger.warning(f'lost ndb context! client may have disconnected? "{e}"')
                return

    @ndb_context
    def has(self, cid):
        """Returns True if a block with this CID is stored, False otherwise."""
        return self.read(cid) is not None

    @ndb_context
    def write(self, repo_did, obj, seq=None):
        """Writes a single object as a block.

        Args:
          repo_did (str)
          obj (dict): value to store
          seq (int): optional; if None, a new sequence number is allocated for
            :const:`SUBSCRIBE_REPOS_NSID`

        Returns:
          CID: of the stored block
        """
        if seq is None:
            seq = self.allocate_seq(SUBSCRIBE_REPOS_NSID)
        return AtpBlock.create(repo_did=repo_did, data=obj, seq=seq).cid

    @ndb_context
    @ndb.transactional()
    def apply_commit(self, commit_data):
        """Stores a commit's blocks and updates the repo's head, in a transaction.

        Uses the commit's ``rev`` TID, converted to an integer, as the sequence
        number. Sets it on every block in ``commit_data.blocks``, and sets
        ``self.head`` to the commit's CID.

        Args:
          commit_data (CommitData)
        """
        commit = commit_data.commit.decoded
        seq = tid_to_int(commit['rev'])
        assert seq

        repo_did = commit['did']
        for block in commit_data.blocks.values():
            template = AtpBlock.from_block(repo_did=repo_did, block=block)
            # get_or_insert so we don't wipe out any existing blocks' sequence
            # numbers. (occasionally we see existing blocks recur, eg MST nodes.)
            AtpBlock.get_or_insert(
                template.key.id(), repo=template.repo, encoded=block.encoded,
                seq=seq, ops=template.ops)
            block.seq = seq

        self.head = commit_data.commit.cid

        head_encoded = self.head.encode('base32')
        repo = AtpRepo.get_by_id(repo_did)
        if repo:
            logger.info(f'Updating {repo.key}')
            repo.head = head_encoded
            repo.put()

    @ndb_context
    def allocate_seq(self, nsid):
        """Allocates and returns the next sequence number for ``nsid``.

        Args:
          nsid (str)

        Returns:
          int
        """
        assert nsid
        return AtpSequence.allocate(nsid)

    @ndb_context
    def last_seq(self, nsid):
        """Returns the last allocated sequence number for ``nsid``.

        Args:
          nsid (str)

        Returns:
          int
        """
        assert nsid
        return AtpSequence.last(nsid)