Source code for arroba.datastore_storage

"""Google Cloud Datastore implementation of repo storage."""
from datetime import timezone
from functools import wraps
from io import BytesIO
import json
import logging
import mimetypes
import requests

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
import dag_cbor
import dag_json
from google.cloud import ndb
from google.cloud.ndb import context
from google.cloud.ndb.exceptions import ContextError
from google.cloud.ndb.key import _MAX_KEYPART_BYTES
from lexrpc import ValidationError
from multiformats import CID, multicodec, multihash
from PIL import Image, ImageFile
from pymediainfo import MediaInfo

from .mst import MST
from .repo import Repo
from .server import server
from . import storage
from .storage import Action, Block, Storage, SUBSCRIBE_REPOS_NSID
from .util import (
    dag_cbor_cid,
    tid_to_int,
    DEACTIVATED,
    DELETED,
    TOMBSTONED,
    InactiveRepo,
)

logger = logging.getLogger(__name__)

# Allow bad .ico files with truncated transparency masks
# https://github.com/python-pillow/Pillow/issues/6507#issuecomment-2199724849
ImageFile.LOAD_TRUNCATED_IMAGES = True


[docs] class WriteOnce: """:class:`ndb.Property` mix-in, prevents changing it once it's set.""" def _set_value(self, entity, value): existing = self._get_value(entity) if existing is not None and value != existing: raise ndb.ReadonlyPropertyError(f"{self._name} can't be changed") return super()._set_value(entity, value)
[docs] class JsonProperty(ndb.TextProperty): """Fork of ndb's that subclasses :class:`ndb.TextProperty` instead of :class:`ndb.BlobProperty`. This makes values show up as normal, human-readable, serialized JSON in the web console. https://github.com/googleapis/python-ndb/issues/874#issuecomment-1442753255 Duplicated in oauth-dropins/webutil: https://github.com/snarfed/webutil/blob/main/models.py """ def _validate(self, value): if not isinstance(value, dict): raise TypeError('JSON property must be a dict') def _to_base_type(self, value): as_str = json.dumps(value, separators=(',', ':'), ensure_ascii=True) return as_str.encode('ascii') def _from_base_type(self, value): if not isinstance(value, str): value = value.decode('ascii') return json.loads(value)
[docs] class ComputedJsonProperty(JsonProperty, ndb.ComputedProperty): """Custom :class:`ndb.ComputedProperty` for JSON values that stores them as strings. ...instead of like :class:`ndb.StructuredProperty`, with "entity" type, which bloats them unnecessarily in the datastore. """ def __init__(self, *args, **kwargs): kwargs['indexed'] = False super().__init__(*args, **kwargs)
class WriteOnceBlobProperty(WriteOnce, ndb.BlobProperty): pass
[docs] class CommitOp(ndb.Model): """Repo operations - creates, updates, deletes - included in a commit. Used in a :class:`StructuredProperty` inside :class:`AtpBlock`; not stored directly in the datastore. https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty """ action = ndb.StringProperty(required=True, choices=('create', 'update', 'delete')) path = ndb.StringProperty(required=True) cid = ndb.StringProperty() # unset for deletes prev_cid = ndb.StringProperty() # unset for creates
[docs] class AtpRepo(ndb.Model): r"""An ATProto repo. Key name is DID. Only stores the repo's metadata. Blocks are stored in :class:`AtpBlock`\s. Attributes: * handles (str): repeated, optional * head (str): CID * signing_key (str) * rotation_key (str) * status (str) """ handles = ndb.StringProperty(repeated=True) head = ndb.StringProperty(required=True) # TODO: add password hash? # these are both secp256k1 private keys, PEM-encoded bytes # https://atproto.com/specs/cryptography signing_key_pem = ndb.BlobProperty(required=True) # TODO: rename this recovery_key_pem? # https://discord.com/channels/1097580399187738645/1098725036917002302/1153447354003894372 rotation_key_pem = ndb.BlobProperty() status = ndb.StringProperty(choices=(DEACTIVATED, DELETED, TOMBSTONED)) created = ndb.DateTimeProperty(auto_now_add=True) updated = ndb.DateTimeProperty(auto_now=True) @property def signing_key(self): """(ec.EllipticCurvePrivateKey)""" return serialization.load_pem_private_key(self.signing_key_pem, password=None) @property def rotation_key(self): """(ec.EllipticCurvePrivateKey` or None)""" if self.rotation_key_pem: return serialization.load_pem_private_key(self.rotation_key_pem, password=None)
[docs] class AtpBlock(ndb.Model): """A data record, MST node, repo commit, or other event. Key name is the DAG-CBOR base32 CID of the data. Events should have a fully-qualified ``$type`` field that's one of the ``message`` types in ``com.atproto.sync.subscribeRepos``, eg ``com.atproto.sync.subscribeRepos#tombstone``. Properties: * repo (google.cloud.ndb.Key): DID of the first repo that included this block * encoded (bytes): DAG-CBOR encoded value * data (dict): DAG-JSON value, only used for human debugging * seq (int): sequence number for the subscribeRepos event stream """ repo = ndb.KeyProperty(AtpRepo, required=True) encoded = WriteOnceBlobProperty(required=True) seq = ndb.IntegerProperty(required=True) ops = ndb.StructuredProperty(CommitOp, repeated=True) created = ndb.DateTimeProperty(auto_now_add=True) @property def decoded(self): return json.loads(dag_json.encode(dag_cbor.decode(self.encoded))) @property def cid(self): return CID.decode(self.key.id())
[docs] @staticmethod def create(*, repo_did, data, seq): """Writes a new AtpBlock to the datastore. If the block already exists in the datastore, leave it untouched. Notably, leave its sequence number as is, since it will be lower than this current sequence number. Args: repo_did (str): data (dict): value seq (int): Returns: :class:`AtpBlock` """ assert seq > 0 encoded = dag_cbor.encode(data) digest = multihash.digest(encoded, 'sha2-256') cid = CID('base58btc', 1, 'dag-cbor', digest) repo_key = ndb.Key(AtpRepo, repo_did) atp_block = AtpBlock.get_or_insert(cid.encode('base32'), repo=repo_key, encoded=encoded, seq=seq) assert atp_block.seq <= seq return atp_block
[docs] def to_block(self): """Converts to :class:`Block`. Returns: Block """ ops = [storage.CommitOp(action=Action[op.action.upper()], path=op.path, cid=CID.decode(op.cid) if op.cid else None, prev_cid=CID.decode(op.prev_cid) if op.prev_cid else None) for op in self.ops] return Block(cid=self.cid, encoded=self.encoded, seq=self.seq, ops=ops, time=self.created, repo=self.repo)
[docs] @classmethod def from_block(cls, *, repo_did, block): """Converts a :class:`Block` to an :class:`AtpBlock`. Args: repo_did (str) block (Block) Returns: AtpBlock """ ops = [CommitOp(action=op.action.name.lower(), path=op.path, cid=op.cid.encode('base32') if op.cid else None, prev_cid=op.prev_cid.encode('base32') if op.prev_cid else None) for op in (block.ops or [])] created = block.time.astimezone(timezone.utc).replace(tzinfo=None) return AtpBlock(id=block.cid.encode('base32'), encoded=block.encoded, repo=ndb.Key(AtpRepo, repo_did), seq=block.seq, ops=ops, created=created)
[docs] class AtpSequence(ndb.Model): """A sequence number for a given event stream NSID. Sequence numbers are monotonically increasing, without gaps (which ATProto doesn't require), starting at 1. Background: https://atproto.com/specs/event-stream#sequence-numbers Key name is XRPC method NSID. At first, I considered using datastore allocated ids for sequence numbers, but they're not guaranteed to be monotonically increasing, so I switched to this. """ next = ndb.IntegerProperty(required=True) created = ndb.DateTimeProperty(auto_now_add=True) updated = ndb.DateTimeProperty(auto_now=True)
[docs] @classmethod @ndb.transactional(retries=10) def allocate(cls, nsid): """Returns the next sequence number for a given NSID. Creates a new :class:`AtpSequence` entity if one doesn't already exist for the given NSID. Args: nsid (str): the subscription XRPC method for this sequence number Returns: integer, next sequence number for this NSID """ seq = AtpSequence.get_or_insert(nsid, next=1) ret = seq.next seq.next += 1 seq.put() return ret
[docs] @classmethod def last(cls, nsid): """Returns the last sequence number for a given NSID. Creates a new :class:`AtpSequence` entity if one doesn't already exist for the given NSID. Args: nsid (str): the subscription XRPC method for this sequence number Returns: integer, last sequence number for this NSID """ seq = AtpSequence.get_or_insert(nsid, next=1) return seq.next - 1
[docs] class AtpRemoteBlob(ndb.Model): """A blob available at a public HTTP URL that we don't store ourselves. Key ID is the URL. * https://atproto.com/specs/data-model#blob-type * https://atproto.com/specs/xrpc#blob-upload-and-download TODO: * follow redirects, use final URL as key id * abstract this in :class:`Storage` """ cid = ndb.StringProperty(required=True) size = ndb.IntegerProperty(required=True) mime_type = ndb.StringProperty(required=True, default='application/octet-stream') # only populated if mime_type is image/* or video/* # used in images.aspectRatio in app.bsky.embed.images # and aspectRatio in app.bsky.embed.video width = ndb.IntegerProperty() height = ndb.IntegerProperty() # only populated if mime_type is video/* # used to enforce maximum duration duration = ndb.IntegerProperty() created = ndb.DateTimeProperty(auto_now_add=True) updated = ndb.DateTimeProperty(auto_now=True)
[docs] @classmethod def get_or_create(cls, *, url=None, get_fn=requests.get, max_size=None, accept_types=None, name=''): """Returns a new or existing :class:`AtpRemoteBlob` for a given URL. If there isn't an existing :class:`AtpRemoteBlob`, fetches the URL over the network and creates a new one for it. Args: url (str) get_fn (callable): for making HTTP GET requests max_size (int, optional): the ``maxSize`` parameter for this blob field in its lexicon, if any accept_types (sequence of str, optional): the ``accept`` parameter for this blob field in its lexicon, if any. The set of allowed MIME types. name (str, optional): blob field name in lexicon Returns: AtpRemoteBlob: existing or newly created blob Raises: requests.RequestException: if the HTTP request to fetch the blob failed lexrpc.ValidationError: if the blob is over ``max_size``, its type is not in ``accept_types`` or it is a video with a duration above the 3m limit """ def validate_size(size): if max_size and size > max_size: raise ValidationError(f'{url} Content-Length {size} is over {name} blob maxSize {max_size}') def validate_duration(duration): # enforce 3m maximum video duration # https://bsky.app/profile/bsky.app/post/3lk26lxn6sk2u max_duration = 3 * 60_000 # milliseconds if duration and duration > max_duration: raise ValidationError(f'{url} duration {duration / 1000} is over {max_duration / 1000}s') assert url url_key = url if len(url_key) > _MAX_KEYPART_BYTES: # TODO: handle Unicode chars. naive approach is to UTF-8 encode, # truncate, then decode, but that might cut mid character. easier to just # hope/assume the URL is already URL-encoded. url_key = url[:_MAX_KEYPART_BYTES] logger.warning(f'Truncating URL to {_MAX_KEYPART_BYTES} chars: {url_key}') blob = cls.get_by_id(url_key) if blob: validate_size(blob.size) validate_duration(blob.duration) server.validate_mime_type(blob.mime_type, accept_types, name=url) return blob resp = get_fn(url, stream=True) resp.raise_for_status() mime_type = resp.headers.get('Content-Type') if not mime_type: mime_type, _ = mimetypes.guess_type(url) length = resp.headers.get('Content-Length') logger.info(f'Got {resp.status_code} {mime_type} {length} bytes {resp.url}') # check type server.validate_mime_type(mime_type, accept_types, name=url) # check size try: length = int(length) except (TypeError, ValueError): length = None # read body and check length manually below if length: validate_size(length) # now ready to fetch body digest = multihash.digest(resp.content, 'sha2-256') cid = CID('base58btc', 1, 'raw', digest).encode('base32') # note that if the initial URL redirects, we still store it in the # AtpRemoteBlob, not the final resolved URL after redirects. logger.info(f'Creating new AtpRemoteBlob for {url} CID {cid}') blob = cls(id=url_key, cid=cid, size=len(resp.content)) if mime_type: blob.mime_type = mime_type blob.generate_metadata(resp.content) blob.put() # re-validate size in case the server didn't give us Content-Length. # do this after storing blob so that we don't re-download it next time. validate_size(len(resp.content)) validate_duration(blob.duration) return blob
[docs] def as_object(self): """Returns an ATProto ``blob`` object for this blob. https://atproto.com/specs/data-model#blob-type Returns: dict: with ``$type: blob`` and ``ref``, ``mimeType``, and ``size`` fields """ return { '$type': 'blob', 'ref': CID.decode(self.cid), 'mimeType': self.mime_type, 'size': self.size, }
[docs] def generate_metadata(self, content): """Extracts and stores metadata from an image or video. Uses ``self.mime_type`` to determine whether/how to parse the content. Args: content (bytes) """ if not self.mime_type: return if self.mime_type.startswith('image/'): try: with Image.open(BytesIO(content)) as image: self.width, self.height = image.size except (OSError, RuntimeError, Image.DecompressionBombError) as e: logger.info(e) elif self.mime_type.startswith('video/'): try: media_info = MediaInfo.parse(BytesIO(content)) if len(media_info.video_tracks) == 1: track = media_info.video_tracks[0] self.width = track.width self.height = track.height self.duration = track.duration except (OSError, RuntimeError) as e: logger.info(e)
[docs] class DatastoreStorage(Storage): """Google Cloud Datastore implementation of :class:`Storage`. Sequence numbers in :class:`AtpBlock` are allocated per commit; all blocks in a given commit will have the same sequence number. They're currently sequential counters, starting at 1, stored in an :class:`AtpSequence` entity. See :class:`Storage` for method details. """ ndb_client = None ndb_context_kwargs = None def __init__(self, *, ndb_client=None, ndb_context_kwargs=None): """Constructor. Args: ndb_client (google.cloud.ndb.Client): used when there isn't already an ndb context active ndb_context_kwargs (dict): optional, used when creating a new ndb context """ super().__init__() self.ndb_client = ndb_client self.ndb_context_kwargs = ndb_context_kwargs or {} def ndb_context(fn): @wraps(fn) def decorated(self, *args, **kwargs): ctx = context.get_context(raise_context_error=False) with ctx.use() if ctx else self.ndb_client.context(**self.ndb_context_kwargs): ret = fn(self, *args, **kwargs) return ret return decorated @ndb_context def create_repo(self, repo): assert repo.did assert repo.head handles = [repo.handle] if repo.handle else [] signing_key_pem = repo.signing_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) rotation_key_pem = None if repo.rotation_key: rotation_key_pem = repo.rotation_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) atp_repo = AtpRepo(id=repo.did, handles=handles, head=repo.head.cid.encode('base32'), signing_key_pem=signing_key_pem, rotation_key_pem=rotation_key_pem, status=repo.status) atp_repo.put() logger.info(f'Stored repo {atp_repo}') @ndb_context def load_repo(self, did_or_handle): assert did_or_handle atp_repo = (AtpRepo.get_by_id(did_or_handle) or AtpRepo.query(AtpRepo.handles == did_or_handle).get()) if not atp_repo: logger.info(f"Couldn't find repo for {did_or_handle}") return None logger.info(f'Loading repo {atp_repo.key}') self.head = CID.decode(atp_repo.head) handle = atp_repo.handles[0] if atp_repo.handles else None return Repo.load(self, cid=self.head, handle=handle, status=atp_repo.status, signing_key=atp_repo.signing_key, rotation_key=atp_repo.rotation_key) @ndb_context def load_repos(self, after=None, limit=500): query = AtpRepo.query() if after: query = query.filter(AtpRepo.key > AtpRepo(id=after).key) # duplicates parts of Repo.load but batches reading blocks from storage atp_repos = query.fetch(limit=limit) cids = [CID.decode(r.head) for r in atp_repos] blocks = self.read_many(cids) # dict mapping CID to block heads = [blocks[cid] for cid in cids] # MST.load doesn't read from storage msts = [MST.load(storage=self, cid=block.decoded['data']) for block in heads] return [Repo(storage=self, mst=mst, head=head, status=atp_repo.status, handle=atp_repo.handles[0] if atp_repo.handles else None, signing_key=atp_repo.signing_key, rotation_key=atp_repo.rotation_key) for atp_repo, head, mst in zip(atp_repos, heads, msts)] @ndb_context def _set_repo_status(self, repo, status): assert status in (DEACTIVATED, DELETED, TOMBSTONED, None) repo.status = status # in memory only @ndb.transactional() def update(): atp_repo = AtpRepo.get_by_id(repo.did) atp_repo.status = status atp_repo.put() update() def store_repo(self, repo): @ndb.transactional() def store(): atp_repo = AtpRepo.get_by_id(repo.did) atp_repo.populate( handles=[repo.handle] if repo.handle else [], status=repo.status, ) atp_repo.put() logger.info(f'Stored repo {atp_repo}') store() @ndb_context def read(self, cid): block = AtpBlock.get_by_id(cid.encode('base32')) if block: return block.to_block() @ndb_context def read_many(self, cids): keys = [ndb.Key(AtpBlock, cid.encode('base32')) for cid in cids] got = list(zip(cids, ndb.get_multi(keys))) return {cid: block.to_block() if block else None for cid, block in got} # can't use @ndb_context because this is a generator, not a normal function def read_blocks_by_seq(self, start=0, repo=None): assert start >= 0 cur_seq = start cur_seq_cids = [] while True: # lexrpc event subscription handlers like subscribeRepos call this # on a different thread, so if we're there, we need to create a new # ndb context ctx = context.get_context(raise_context_error=False) with (ctx.use() if ctx else self.ndb_client.context(**self.ndb_context_kwargs)): try: query = AtpBlock.query(AtpBlock.seq >= cur_seq).order(AtpBlock.seq) if repo: query = query.filter(AtpBlock.repo == AtpRepo(id=repo).key) # unproven hypothesis: need strong consistency to make sure we # get all blocks for a given seq, including commit # https://console.cloud.google.com/errors/detail/CO2g4eLG_tOkZg;service=atproto-hub;time=P1D;refresh=true;locations=global?project=bridgy-federated for atp_block in query.iter(read_consistency=ndb.STRONG): if atp_block.seq != cur_seq: cur_seq = atp_block.seq cur_seq_cids = [] if atp_block.key.id() not in cur_seq_cids: cur_seq_cids.append(atp_block.key.id()) yield atp_block.to_block() # finished cleanly break except ContextError as e: logging.warning(f'lost ndb context! re-querying at {cur_seq}. {e}') # continue loop, restart query # Context.use() resets this to the previous context when it exits, # but that context is bad now, so make sure we get a new one at the # top of the loop context._state.context = None @ndb_context def has(self, cid): return self.read(cid) is not None @ndb_context def write(self, repo_did, obj, seq=None): if seq is None: seq = self.allocate_seq(SUBSCRIBE_REPOS_NSID) return AtpBlock.create(repo_did=repo_did, data=obj, seq=seq).to_block() @ndb_context def write_blocks(self, blocks): ndb.put_multi(AtpBlock.from_block(repo_did=b.repo, block=b) for b in blocks) @ndb_context # retry aggressively because repo writes can be bursty and cause high # contention. (ndb does exponential backoff.) # https://console.cloud.google.com/errors/detail/CKbL5KSX98uZHw;time=P1D;locations=global?project=bridgy-federated @ndb.transactional(retries=10) def apply_commit(self, commit_data): commit = commit_data.commit.decoded if repo := AtpRepo.get_by_id(commit['did']): if repo.status: raise InactiveRepo(repo.key.id(), repo.status) seq = tid_to_int(commit_data.commit.decoded['rev']) assert seq for block in commit_data.blocks.values(): template = AtpBlock.from_block( repo_did=commit_data.commit.decoded['did'], block=block) # get_or_insert so we don't wipe out any existing blocks' sequence # numbers. (occasionally we see existing blocks recur, eg MST nodes.) atp_block = AtpBlock.get_or_insert( template.key.id(), repo=template.repo, encoded=block.encoded, seq=seq, ops=template.ops) block.seq = seq self.head = commit_data.commit.cid head_encoded = self.head.encode('base32') if repo: logger.info(f'Updating {repo.key}') repo.head = head_encoded repo.put() @ndb_context def allocate_seq(self, nsid): assert nsid return AtpSequence.allocate(nsid) @ndb_context def last_seq(self, nsid): assert nsid return AtpSequence.last(nsid)