"""Google Cloud Datastore implementation of repo storage."""
import json
import logging
import requests

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
import dag_cbor
import dag_json
from google.cloud import ndb
from multiformats import CID, multicodec, multihash

from .repo import Repo
from . import storage
from .storage import Action, Block, Storage, SUBSCRIBE_REPOS_NSID
from .util import dag_cbor_cid, tid_to_int

logger = logging.getLogger(__name__)


class WriteOnce:
    """:class:`ndb.Property` mix-in that prevents changing a value once it's set."""
    def _set_value(self, entity, value):
        existing = self._get_value(entity)
        if existing is not None and value != existing:
            raise ndb.ReadonlyPropertyError(f"{self._name} can't be changed")

        return super()._set_value(entity, value)


class JsonProperty(ndb.TextProperty):
    """Fork of ndb's ``JsonProperty`` that subclasses :class:`ndb.TextProperty`
    instead of :class:`ndb.BlobProperty`.

    This makes values show up as normal, human-readable, serialized JSON in the
    web console.
    https://github.com/googleapis/python-ndb/issues/874#issuecomment-1442753255

    Duplicated in oauth-dropins/webutil:
    https://github.com/snarfed/webutil/blob/main/models.py
    """
    def _validate(self, value):
        if not isinstance(value, dict):
            raise TypeError('JSON property must be a dict')

    def _to_base_type(self, value):
        as_str = json.dumps(value, separators=(',', ':'), ensure_ascii=True)
        return as_str.encode('ascii')

    def _from_base_type(self, value):
        if not isinstance(value, str):
            value = value.decode('ascii')
        return json.loads(value)


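# Hypothetical usage sketch (not part of this module): a model with a
# JsonProperty round-trips a dict through compact, human-readable JSON text.
# ``Prefs`` and the ndb client/context setup are assumptions for illustration.
#
#   class Prefs(ndb.Model):
#       settings = JsonProperty()
#
#   with ndb.Client().context():
#       Prefs(id='me', settings={'theme': 'dark'}).put()
#       assert Prefs.get_by_id('me').settings == {'theme': 'dark'}

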
class ComputedJsonProperty(JsonProperty, ndb.ComputedProperty):
    """Custom :class:`ndb.ComputedProperty` for JSON values that stores them as
    strings.

    ...instead of like :class:`ndb.StructuredProperty`, with "entity" type,
    which bloats them unnecessarily in the datastore.
    """
    def __init__(self, *args, **kwargs):
        kwargs['indexed'] = False
        super().__init__(*args, **kwargs)


class WriteOnceBlobProperty(WriteOnce, ndb.BlobProperty):
    pass


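# Sketch of the write-once behavior (``Snapshot`` is a hypothetical model, an
# assumption for illustration): once the property has a value, assigning a
# *different* value raises; re-assigning the same value is a no-op.
#
#   class Snapshot(ndb.Model):
#       data = WriteOnceBlobProperty()
#
#   snap = Snapshot(data=b'v1')
#   snap.data = b'v1'  # fine: same value
#   snap.data = b'v2'  # raises ndb.ReadonlyPropertyError

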
class CommitOp(ndb.Model):
    """Repo operations - creates, updates, deletes - included in a commit.

    Used in a :class:`StructuredProperty` inside :class:`AtpBlock`; not stored
    directly in the datastore.

    https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty
    """
    action = ndb.StringProperty(required=True,
                                choices=['create', 'update', 'delete'])
    path = ndb.StringProperty(required=True)
    cid = ndb.StringProperty()  # unset for deletes


class AtpRepo(ndb.Model):
    r"""An ATProto repo.

    Key name is the repo DID. Only stores the repo's metadata; blocks are
    stored in :class:`AtpBlock`\s.

    Attributes:
    * handles (str): repeated, optional
    * head (str): CID
    * signing_key (str)
    * rotation_key (str)
    """
    handles = ndb.StringProperty(repeated=True)
    head = ndb.StringProperty(required=True)
    # TODO: add password hash?

    # these are both secp256k1 private keys, PEM-encoded bytes
    # https://atproto.com/specs/cryptography
    signing_key_pem = ndb.BlobProperty(required=True)
    # TODO: rename this recovery_key_pem?
    # https://discord.com/channels/1097580399187738645/1098725036917002302/1153447354003894372
    rotation_key_pem = ndb.BlobProperty()

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @property
    def signing_key(self):
        """(ec.EllipticCurvePrivateKey)"""
        return serialization.load_pem_private_key(self.signing_key_pem,
                                                  password=None)

    @property
    def rotation_key(self):
        """(ec.EllipticCurvePrivateKey or None)"""
        if self.rotation_key_pem:
            return serialization.load_pem_private_key(self.rotation_key_pem,
                                                      password=None)


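# Sketch (an assumption, not prescribed by this module): generating a
# secp256k1 signing key compatible with signing_key_pem, using the
# ``cryptography`` primitives already imported above, then round-tripping it
# through the same PEM loader the properties use.
#
#   key = ec.generate_private_key(ec.SECP256K1())
#   pem = key.private_bytes(
#       encoding=serialization.Encoding.PEM,
#       format=serialization.PrivateFormat.PKCS8,
#       encryption_algorithm=serialization.NoEncryption(),
#   )
#   assert serialization.load_pem_private_key(pem, password=None)

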
class AtpBlock(ndb.Model):
    """A data record, MST node, or commit.

    Key name is the DAG-CBOR base32 CID of the data.

    Properties:
    * encoded (bytes): DAG-CBOR encoded value
    * decoded (dict): DAG-JSON value, only used for human debugging
    * seq (int): sequence number for the subscribeRepos event stream
    """
    repo = ndb.KeyProperty(AtpRepo, required=True)
    encoded = WriteOnceBlobProperty(required=True)
    seq = ndb.IntegerProperty(required=True)
    ops = ndb.StructuredProperty(CommitOp, repeated=True)

    created = ndb.DateTimeProperty(auto_now_add=True)

    @ComputedJsonProperty
    def decoded(self):
        return json.loads(dag_json.encode(dag_cbor.decode(self.encoded)))

    @property
    def cid(self):
        return CID.decode(self.key.id())

    @staticmethod
    def create(*, repo_did, data, seq):
        """Writes a new AtpBlock to the datastore.

        If the block already exists in the datastore, leaves it untouched.
        Notably, leaves its sequence number as is, since it will be lower than
        the current sequence number.

        Args:
          repo_did (str)
          data (dict): value
          seq (int)

        Returns:
          :class:`AtpBlock`
        """
        assert seq > 0
        encoded = dag_cbor.encode(data)
        digest = multihash.digest(encoded, 'sha2-256')
        cid = CID('base58btc', 1, 'dag-cbor', digest)

        repo_key = ndb.Key(AtpRepo, repo_did)
        atp_block = AtpBlock.get_or_insert(cid.encode('base32'), repo=repo_key,
                                           encoded=encoded, seq=seq)
        assert atp_block.seq <= seq
        return atp_block

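    # Sketch of the key-name derivation used above (illustrative only): the
    # same dict always DAG-CBOR-encodes to the same bytes, so its sha2-256
    # CID, and therefore its datastore key name, is deterministic.
    #
    #   encoded = dag_cbor.encode({'foo': 'bar'})
    #   digest = multihash.digest(encoded, 'sha2-256')
    #   key_name = CID('base58btc', 1, 'dag-cbor', digest).encode('base32')
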
    def to_block(self):
        """Converts to :class:`Block`.

        Returns:
          Block
        """
        ops = [storage.CommitOp(action=Action[op.action.upper()], path=op.path,
                                cid=CID.decode(op.cid) if op.cid else None)
               for op in self.ops]
        return Block(cid=self.cid, encoded=self.encoded, seq=self.seq, ops=ops)

    @classmethod
    def from_block(cls, *, repo_did, block):
        """Converts a :class:`Block` to an :class:`AtpBlock`.

        Args:
          repo_did (str)
          block (Block)

        Returns:
          AtpBlock
        """
        ops = [CommitOp(action=op.action.name.lower(), path=op.path,
                        cid=op.cid.encode('base32') if op.cid else None)
               for op in (block.ops or [])]
        return AtpBlock(id=block.cid.encode('base32'), encoded=block.encoded,
                        repo=ndb.Key(AtpRepo, repo_did), seq=block.seq, ops=ops)


class AtpSequence(ndb.Model):
    """A sequence number for a given event stream NSID.

    Sequence numbers are monotonically increasing, without gaps (which ATProto
    doesn't require), starting at 1. Background:
    https://atproto.com/specs/event-stream#sequence-numbers

    Key name is the XRPC method NSID.

    At first, I considered using datastore allocated ids for sequence numbers,
    but they're not guaranteed to be monotonically increasing, so I switched to
    this.
    """
    next = ndb.IntegerProperty(required=True)

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def allocate(cls, nsid):
        """Returns the next sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, next sequence number for this NSID
        """
        seq = AtpSequence.get_or_insert(nsid, next=1)
        ret = seq.next
        seq.next += 1
        seq.put()
        return ret

    @classmethod
    def last(cls, nsid):
        """Returns the last sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, last sequence number for this NSID
        """
        seq = AtpSequence.get_or_insert(nsid, next=1)
        return seq.next - 1


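# Hypothetical usage sketch, assuming an active ndb context: allocate() hands
# out 1, 2, 3, ... per NSID, and last() peeks without consuming a number.
#
#   assert AtpSequence.allocate(SUBSCRIBE_REPOS_NSID) == 1
#   assert AtpSequence.allocate(SUBSCRIBE_REPOS_NSID) == 2
#   assert AtpSequence.last(SUBSCRIBE_REPOS_NSID) == 2

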
class AtpRemoteBlob(ndb.Model):
    """A blob available at a public HTTP URL that we don't store ourselves.

    Key ID is the URL.

    * https://atproto.com/specs/data-model#blob-type
    * https://atproto.com/specs/xrpc#blob-upload-and-download

    TODO:
    * follow redirects, use the final URL as the key id
    * abstract this in :class:`Storage`
    """
    cid = ndb.StringProperty(required=True)
    size = ndb.IntegerProperty(required=True)
    mime_type = ndb.StringProperty(required=True,
                                   default='application/octet-stream')

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def get_or_create(cls, *, url=None, get_fn=requests.get):
        """Returns a new or existing :class:`AtpRemoteBlob` for a given URL.

        If there isn't an existing :class:`AtpRemoteBlob`, fetches the URL over
        the network and creates a new one for it.

        Args:
          url (str)
          get_fn (callable): for making HTTP GET requests

        Returns:
          AtpRemoteBlob: existing or newly created :class:`AtpRemoteBlob`

        Raises:
          requests.RequestException: if the HTTP request to fetch the blob
            failed
        """
        assert url
        existing = cls.get_by_id(url)
        if existing:
            return existing

        resp = get_fn(url)
        resp.raise_for_status()

        digest = multihash.digest(resp.content, 'sha2-256')
        cid = CID('base58btc', 1, 'raw', digest).encode('base32')
        logger.info(f'Creating new AtpRemoteBlob for {url} CID {cid}')
        blob = cls(id=url, cid=cid, mime_type=resp.headers.get('Content-Type'),
                   size=len(resp.content))
        blob.put()
        return blob

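    # Sketch of how get_fn can be swapped out, eg in tests. ``FakeResponse``
    # and the URL are assumptions for illustration, not part of this module:
    #
    #   class FakeResponse:
    #       content = b'hello'
    #       headers = {'Content-Type': 'text/plain'}
    #       def raise_for_status(self):
    #           pass
    #
    #   blob = AtpRemoteBlob.get_or_create(url='http://inst.example/blob',
    #                                      get_fn=lambda url: FakeResponse())
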
    def as_object(self):
        """Returns an ATProto ``blob`` object for this blob.

        https://atproto.com/specs/data-model#blob-type

        Returns:
          dict
        """
        return {
            '$type': 'blob',
            'ref': CID.decode(self.cid),
            'mimeType': self.mime_type,
            'size': self.size,
        }


class DatastoreStorage(Storage):
    """Google Cloud Datastore implementation of :class:`Storage`.

    Sequence numbers in :class:`AtpBlock` are allocated per commit; all blocks
    in a given commit will have the same sequence number. They're currently
    sequential counters, starting at 1, stored in an :class:`AtpSequence`
    entity.

    See :class:`Storage` for method details.
    """
    def create_repo(self, repo, *, signing_key, rotation_key=None):
        assert repo.did
        assert repo.head

        handles = [repo.handle] if repo.handle else []

        signing_key_pem = signing_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        rotation_key_pem = None
        if rotation_key:
            rotation_key_pem = rotation_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

        atp_repo = AtpRepo(id=repo.did, handles=handles,
                           head=repo.head.cid.encode('base32'),
                           signing_key_pem=signing_key_pem,
                           rotation_key_pem=rotation_key_pem)
        atp_repo.put()
        logger.info(f'Stored repo {atp_repo}')

    def load_repo(self, did_or_handle):
        assert did_or_handle
        atp_repo = (AtpRepo.get_by_id(did_or_handle)
                    or AtpRepo.query(AtpRepo.handles == did_or_handle).get())

        if not atp_repo:
            logger.info(f"Couldn't find repo for {did_or_handle}")
            return None

        logger.info(f'Loading repo {atp_repo.key}')
        self.head = CID.decode(atp_repo.head)
        handle = atp_repo.handles[0] if atp_repo.handles else None

        return Repo.load(self, cid=self.head, handle=handle,
                         signing_key=atp_repo.signing_key,
                         rotation_key=atp_repo.rotation_key)

    def read(self, cid):
        block = AtpBlock.get_by_id(cid.encode('base32'))
        if block:
            return block.to_block()

    def read_many(self, cids):
        keys = [ndb.Key(AtpBlock, cid.encode('base32')) for cid in cids]
        got = list(zip(cids, ndb.get_multi(keys)))
        return {cid: block.to_block() if block else None
                for cid, block in got}

    def read_blocks_by_seq(self, start=0):
        assert start >= 0
        query = AtpBlock.query(AtpBlock.seq >= start).order(AtpBlock.seq)
        for atp_block in query:
            yield atp_block.to_block()

    def has(self, cid):
        return self.read(cid) is not None

    def write(self, repo_did, obj):
        seq = self.allocate_seq(SUBSCRIBE_REPOS_NSID)
        return AtpBlock.create(repo_did=repo_did, data=obj, seq=seq).cid

    @ndb.transactional()
    def apply_commit(self, commit_data):
        seq = tid_to_int(commit_data.commit.decoded['rev'])
        assert seq

        for block in commit_data.blocks.values():
            template = AtpBlock.from_block(
                repo_did=commit_data.commit.decoded['did'], block=block)
            # use get_or_insert so we don't wipe out any existing blocks'
            # sequence numbers. (occasionally we see existing blocks recur,
            # eg MST nodes.)
            atp_block = AtpBlock.get_or_insert(
                template.key.id(), repo=template.repo, encoded=block.encoded,
                seq=seq, ops=template.ops)
            block.seq = seq

        self.head = commit_data.commit.cid
        commit = commit_data.commit.decoded

        head_encoded = self.head.encode('base32')
        repo = AtpRepo.get_by_id(commit['did'])
        if repo:
            logger.info(f'Updating {repo.key}')
            repo.head = head_encoded
            repo.put()

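    # Note on the seq derivation above: a commit's ``rev`` is a TID, a
    # base32-sortable encoding of an integer timestamp, so tid_to_int yields a
    # value that increases with each commit. Sketch (the TID literal is just
    # an example of the format, not a meaningful value):
    #
    #   tid_to_int('3jzfcijpj2z2a')  # => the TID's underlying integer
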
    def allocate_seq(self, nsid):
        assert nsid
        return AtpSequence.allocate(nsid)

    def last_seq(self, nsid):
        assert nsid
        return AtpSequence.last(nsid)


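# End-to-end usage sketch (hedged; assumes a configured Google Cloud project,
# an ndb context, and a repo previously stored via create_repo. The DID and
# handle literals are assumptions for illustration):
#
#   storage = DatastoreStorage()
#   with ndb.Client().context():
#       repo = storage.load_repo('did:plc:example')    # by DID...
#       repo = storage.load_repo('alice.example.com')  # ...or by handle
#       last = storage.last_seq(SUBSCRIBE_REPOS_NSID)
#       for block in storage.read_blocks_by_seq(start=last):
#           ...  # emit subscribeRepos events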