"""Google Cloud Datastore implementation of repo storage."""
from datetime import timezone
from functools import wraps
import json
import logging
import mimetypes
import requests
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
import dag_cbor
import dag_json
from google.cloud import ndb
from google.cloud.ndb.context import get_context
from google.cloud.ndb.exceptions import ContextError
from multiformats import CID, multicodec, multihash
from .mst import MST
from .repo import Repo
from . import storage
from .storage import Action, Block, Storage, SUBSCRIBE_REPOS_NSID
from .util import dag_cbor_cid, tid_to_int, TOMBSTONED, TombstonedRepo
logger = logging.getLogger(__name__)
class WriteOnce:
    """Mix-in for :class:`ndb.Property` classes that locks a value once set.

    While the stored value is ``None``, any value may be assigned. After
    that, only a value equal to the stored one is accepted; anything else
    raises :class:`ndb.ReadonlyPropertyError`.
    """

    def _set_value(self, entity, value):
        """Sets ``value`` on ``entity`` unless it would change an existing value.

        Args:
          entity: the entity this property belongs to
          value: the new value

        Returns:
          whatever the next ``_set_value`` in the MRO returns

        Raises:
          ndb.ReadonlyPropertyError: if a different, non-``None`` value is
            already set
        """
        current = self._get_value(entity)
        if current is not None:
            if value != current:
                raise ndb.ReadonlyPropertyError(f"{self._name} can't be changed")

        return super()._set_value(entity, value)
class JsonProperty(ndb.TextProperty):
    """JSON property built on :class:`ndb.TextProperty`, not :class:`ndb.BlobProperty`.

    Fork of ndb's own ``JsonProperty``. Subclassing the text property makes
    values show up as normal, human-readable, serialized JSON in the web
    console.

    https://github.com/googleapis/python-ndb/issues/874#issuecomment-1442753255

    Duplicated in oauth-dropins/webutil:
    https://github.com/snarfed/webutil/blob/main/models.py
    """

    def _validate(self, value):
        """Only accepts dicts.

        Raises:
          TypeError: if ``value`` isn't a :class:`dict`
        """
        if isinstance(value, dict):
            return

        raise TypeError('JSON property must be a dict')

    def _to_base_type(self, value):
        """Serializes ``value`` to compact, ASCII-only JSON, returned as bytes."""
        return json.dumps(value, separators=(',', ':'),
                          ensure_ascii=True).encode('ascii')

    def _from_base_type(self, value):
        """Parses stored JSON, given as either :class:`str` or ASCII bytes."""
        text = value if isinstance(value, str) else value.decode('ascii')
        return json.loads(text)
class ComputedJsonProperty(JsonProperty, ndb.ComputedProperty):
    """:class:`ndb.ComputedProperty` for JSON values, stored as strings.

    ...instead of like :class:`ndb.StructuredProperty`, with "entity" type,
    which bloats them unnecessarily in the datastore.
    """

    def __init__(self, *args, **kwargs):
        """Constructor. Always forces ``indexed=False``, overriding the caller's."""
        super().__init__(*args, **{**kwargs, 'indexed': False})
class WriteOnceBlobProperty(WriteOnce, ndb.BlobProperty):
    """:class:`ndb.BlobProperty` that can't be changed once it's set.

    All behavior comes from the :class:`WriteOnce` mix-in.
    """
class CommitOp(ndb.Model):
    """One repo operation - create, update, or delete - included in a commit.

    Used in a :class:`StructuredProperty` inside :class:`AtpBlock`; not stored
    directly in the datastore.

    https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty

    Attributes:
      * action (str): ``create``, ``update``, or ``delete``
      * path (str): path the operation applies to
      * cid (str): CID string; unset for deletes
    """
    action = ndb.StringProperty(
        required=True,
        choices=('create', 'update', 'delete'),
    )
    path = ndb.StringProperty(required=True)
    # unset for deletes
    cid = ndb.StringProperty()
class AtpRepo(ndb.Model):
    r"""An ATProto repo.

    Key name is DID. Only stores the repo's metadata. Blocks are stored in
    :class:`AtpBlock`\s.

    Attributes:
      * handles (str): repeated, optional
      * head (str): CID, required
      * signing_key_pem (bytes): PEM-encoded private key, required
      * rotation_key_pem (bytes): PEM-encoded private key, optional
      * status (str): unset, or :const:`TOMBSTONED`
      * created (datetime): set automatically when the entity is first stored
      * updated (datetime): set automatically every time the entity is stored
    """
    handles = ndb.StringProperty(repeated=True)
    head = ndb.StringProperty(required=True)

    # TODO: add password hash?

    # these are both secp256k1 private keys, PEM-encoded bytes
    # https://atproto.com/specs/cryptography
    signing_key_pem = ndb.BlobProperty(required=True)
    # TODO: rename this recovery_key_pem?
    # https://discord.com/channels/1097580399187738645/1098725036917002302/1153447354003894372
    rotation_key_pem = ndb.BlobProperty()

    status = ndb.StringProperty(choices=(TOMBSTONED,))

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @property
    def signing_key(self):
        """Private key loaded from :attr:`signing_key_pem`.

        Returns:
          ec.EllipticCurvePrivateKey
        """
        pem = self.signing_key_pem
        return serialization.load_pem_private_key(pem, password=None)

    @property
    def rotation_key(self):
        """Private key loaded from :attr:`rotation_key_pem`, if it's set.

        Returns:
          ec.EllipticCurvePrivateKey or None
        """
        if not self.rotation_key_pem:
            return None

        return serialization.load_pem_private_key(self.rotation_key_pem,
                                                  password=None)
class AtpBlock(ndb.Model):
    """A data record, MST node, repo commit, or other event.

    Key name is the DAG-CBOR base32 CID of the data.

    Events should have a fully-qualified ``$type`` field that's one of the
    ``message`` types in ``com.atproto.sync.subscribeRepos``, eg
    ``com.atproto.sync.subscribeRepos#tombstone``.

    Properties:
      * repo (ndb.Key): the :class:`AtpRepo` this block belongs to
      * encoded (bytes): DAG-CBOR encoded value, write-once
      * decoded (dict): DAG-JSON value, computed from ``encoded``; only used
        for human debugging
      * seq (int): sequence number for the subscribeRepos event stream
      * ops (CommitOp): repeated, operations included in this block
      * created (datetime): set automatically when the entity is first stored
    """
    repo = ndb.KeyProperty(AtpRepo, required=True)
    encoded = WriteOnceBlobProperty(required=True)
    seq = ndb.IntegerProperty(required=True)
    ops = ndb.StructuredProperty(CommitOp, repeated=True)

    created = ndb.DateTimeProperty(auto_now_add=True)

    @ComputedJsonProperty
    def decoded(self):
        """DAG-CBOR ``encoded`` bytes converted to DAG-JSON, parsed as plain JSON."""
        value = dag_cbor.decode(self.encoded)
        as_dag_json = dag_json.encode(value)
        return json.loads(as_dag_json)

    @property
    def cid(self):
        """This block's :class:`CID`, decoded from the key id."""
        return CID.decode(self.key.id())

    @staticmethod
    def create(*, repo_did, data, seq):
        """Writes a new AtpBlock to the datastore.

        If the block already exists in the datastore, leave it untouched.
        Notably, leave its sequence number as is, since it will be lower than
        this current sequence number.

        Args:
          repo_did (str):
          data (dict): value
          seq (int): must be positive

        Returns:
          :class:`AtpBlock`
        """
        assert seq > 0

        encoded = dag_cbor.encode(data)
        cid = CID('base58btc', 1, 'dag-cbor',
                  multihash.digest(encoded, 'sha2-256'))

        block = AtpBlock.get_or_insert(
            cid.encode('base32'),
            repo=ndb.Key(AtpRepo, repo_did),
            encoded=encoded,
            seq=seq,
        )
        # a pre-existing block keeps its own, earlier sequence number
        assert block.seq <= seq
        return block

    def to_block(self):
        """Converts to :class:`Block`.

        Returns:
          Block
        """
        ops = []
        for op in self.ops:
            ops.append(storage.CommitOp(
                action=Action[op.action.upper()],
                path=op.path,
                cid=CID.decode(op.cid) if op.cid else None,
            ))

        return Block(cid=self.cid, encoded=self.encoded, seq=self.seq,
                     ops=ops, time=self.created)

    @classmethod
    def from_block(cls, *, repo_did, block):
        """Converts a :class:`Block` to an :class:`AtpBlock`.

        Doesn't store the result in the datastore.

        Args:
          repo_did (str)
          block (Block)

        Returns:
          AtpBlock
        """
        ops = []
        for op in block.ops or []:
            ops.append(CommitOp(
                action=op.action.name.lower(),
                path=op.path,
                cid=op.cid.encode('base32') if op.cid else None,
            ))

        # convert to UTC, then drop the tzinfo to get a naive datetime
        naive_utc = block.time.astimezone(timezone.utc).replace(tzinfo=None)

        return AtpBlock(
            id=block.cid.encode('base32'),
            encoded=block.encoded,
            repo=ndb.Key(AtpRepo, repo_did),
            seq=block.seq,
            ops=ops,
            created=naive_utc,
        )
class AtpSequence(ndb.Model):
    """A sequence number for a given event stream NSID.

    Sequence numbers are monotonically increasing, without gaps (which ATProto
    doesn't require), starting at 1. Background:
    https://atproto.com/specs/event-stream#sequence-numbers

    Key name is XRPC method NSID.

    At first, I considered using datastore allocated ids for sequence numbers,
    but they're not guaranteed to be monotonically increasing, so I switched to
    this.

    Attributes:
      * next (int): the next sequence number that will be handed out
      * created (datetime): set automatically when the entity is first stored
      * updated (datetime): set automatically every time the entity is stored
    """
    next = ndb.IntegerProperty(required=True)

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def allocate(cls, nsid):
        """Returns the next sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID. Runs in a transaction.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, next sequence number for this NSID
        """
        counter = AtpSequence.get_or_insert(nsid, next=1)
        allocated = counter.next
        counter.next = allocated + 1
        counter.put()
        return allocated

    @classmethod
    def last(cls, nsid):
        """Returns the last sequence number for a given NSID.

        Creates a new :class:`AtpSequence` entity if one doesn't already exist
        for the given NSID; in that case, the result is 0.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, last sequence number for this NSID
        """
        return AtpSequence.get_or_insert(nsid, next=1).next - 1
class AtpRemoteBlob(ndb.Model):
    """A blob available at a public HTTP URL that we don't store ourselves.

    Key ID is the URL.

    * https://atproto.com/specs/data-model#blob-type
    * https://atproto.com/specs/xrpc#blob-upload-and-download

    TODO:
    * follow redirects, use final URL as key id
    * abstract this in :class:`Storage`

    Attributes:
      * cid (str): base32 CID of the blob's contents
      * size (int): length of the blob's contents
      * mime_type (str): defaults to ``application/octet-stream``
      * created (datetime): set automatically when the entity is first stored
      * updated (datetime): set automatically every time the entity is stored
    """
    cid = ndb.StringProperty(required=True)
    size = ndb.IntegerProperty(required=True)
    mime_type = ndb.StringProperty(required=True,
                                   default='application/octet-stream')

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    @ndb.transactional()
    def get_or_create(cls, *, url=None, get_fn=requests.get):
        """Returns a new or existing :class:`AtpRemoteBlob` for a given URL.

        If there isn't an existing :class:`AtpRemoteBlob`, fetches the URL over
        the network and creates a new one for it. The MIME type comes from the
        response's ``Content-Type`` header, falling back to a guess based on
        the URL, and finally to the property's default.

        Args:
          url (str): required
          get_fn (callable): for making HTTP GET requests

        Returns:
          AtpRemoteBlob: existing or newly created :class:`AtpRemoteBlob`

        Raises:
          requests.RequestException: if the HTTP request to fetch the blob failed
        """
        assert url

        found = cls.get_by_id(url)
        if found:
            return found

        resp = get_fn(url)
        resp.raise_for_status()

        mime_type = (resp.headers.get('Content-Type')
                     or mimetypes.guess_type(url)[0])

        content = resp.content
        cid = CID('base58btc', 1, 'raw',
                  multihash.digest(content, 'sha2-256')).encode('base32')

        logger.info(f'Creating new AtpRemoteBlob for {url} CID {cid}')

        props = {'cid': cid, 'size': len(content)}
        if mime_type:
            # otherwise leave it unset so that the property default applies
            props['mime_type'] = mime_type

        blob = cls(id=url, **props)
        blob.put()
        return blob

    def as_object(self):
        """Returns an ATProto ``blob`` object for this blob.

        https://atproto.com/specs/data-model#blob-type

        Returns:
          dict: with ``$type``, ``ref`` (:class:`CID`), ``mimeType``, and
          ``size`` keys
        """
        ref = CID.decode(self.cid)
        return {
            '$type': 'blob',
            'ref': ref,
            'mimeType': self.mime_type,
            'size': self.size,
        }
class DatastoreStorage(Storage):
    """Google Cloud Datastore implementation of :class:`Storage`.

    Sequence numbers in :class:`AtpBlock` are allocated per commit; all blocks
    written by a given :meth:`apply_commit` call get the same sequence number,
    derived from the commit's ``rev``. :meth:`write` uses the sequence number
    it's given, or allocates one from the :class:`AtpSequence` counter for
    :const:`SUBSCRIBE_REPOS_NSID` if it isn't given one.

    See :class:`Storage` for method details.
    """
    ndb_client = None

    def __init__(self, *, ndb_client=None):
        """Constructor.

        Args:
          ndb_client (google.cloud.ndb.Client): used to create a new ndb
            context when a method is called without one already active, eg
            in :meth:`read_blocks_by_seq`; it's used in the `subscribeRepos`
            event subscription, so lexrpc calls it on a different thread, so
            it needs its own ndb client context.
        """
        super().__init__()
        self.ndb_client = ndb_client

    def ndb_context(fn):
        """Decorator that runs a method inside an ndb context.

        Reuses the current context if there is one, otherwise creates a new
        one from ``self.ndb_client``.
        """
        @wraps(fn)
        def decorated(self, *args, **kwargs):
            context = get_context(raise_context_error=False)

            with (context.use() if context else self.ndb_client.context()):
                ret = fn(self, *args, **kwargs)

            return ret

        return decorated

    @ndb_context
    def create_repo(self, repo, *, signing_key, rotation_key=None):
        """Stores a new :class:`AtpRepo` for a repo.

        Args:
          repo (Repo): must have ``did`` and ``head`` set
          signing_key (ec.EllipticCurvePrivateKey)
          rotation_key (ec.EllipticCurvePrivateKey): optional
        """
        assert repo.did
        assert repo.head

        def to_pem(key):
            # unencrypted PKCS8, PEM-encoded
            return key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption(),
            )

        atp_repo = AtpRepo(
            id=repo.did,
            handles=[repo.handle] if repo.handle else [],
            head=repo.head.cid.encode('base32'),
            signing_key_pem=to_pem(signing_key),
            rotation_key_pem=to_pem(rotation_key) if rotation_key else None,
        )
        atp_repo.put()
        logger.info(f'Stored repo {atp_repo}')

    @ndb_context
    def load_repo(self, did_or_handle):
        """Loads a repo by DID or handle, and sets ``self.head`` to its head.

        Tries ``did_or_handle`` as an :class:`AtpRepo` key id first, then
        falls back to querying on ``handles``.

        Args:
          did_or_handle (str)

        Returns:
          Repo: or None if no matching repo was found

        Raises:
          TombstonedRepo: if the repo is tombstoned
        """
        assert did_or_handle

        atp_repo = (AtpRepo.get_by_id(did_or_handle)
                    or AtpRepo.query(AtpRepo.handles == did_or_handle).get())

        if not atp_repo:
            logger.info(f"Couldn't find repo for {did_or_handle}")
            return None
        elif atp_repo.status == TOMBSTONED:
            raise TombstonedRepo(f'{atp_repo.key} is tombstoned')

        logger.info(f'Loading repo {atp_repo.key}')
        self.head = CID.decode(atp_repo.head)
        handle = atp_repo.handles[0] if atp_repo.handles else None
        return Repo.load(self, cid=self.head, handle=handle,
                         signing_key=atp_repo.signing_key,
                         rotation_key=atp_repo.rotation_key)

    @ndb_context
    def load_repos(self, after=None, limit=500):
        """Loads multiple repos, reading their head blocks in one batch.

        Args:
          after (str): optional; if set, only repos whose key is greater than
            the :class:`AtpRepo` key with this id are returned
          limit (int): maximum number of repos to fetch

        Returns:
          list of Repo
        """
        query = AtpRepo.query()
        if after:
            query = query.filter(AtpRepo.key > AtpRepo(id=after).key)

        # duplicates parts of Repo.load but batches reading blocks from storage
        atp_repos = query.fetch(limit=limit)
        cids = [CID.decode(r.head) for r in atp_repos]
        blocks = self.read_many(cids)  # dict mapping CID to block
        heads = [blocks[cid] for cid in cids]

        # MST.load doesn't read from storage
        msts = [MST.load(storage=self, cid=block.decoded['data'])
                for block in heads]

        return [Repo(storage=self, mst=mst, head=head, status=atp_repo.status,
                     handle=atp_repo.handles[0] if atp_repo.handles else None,
                     signing_key=atp_repo.signing_key,
                     rotation_key=atp_repo.rotation_key)
                for atp_repo, head, mst in zip(atp_repos, heads, msts)]

    @ndb_context
    def _tombstone_repo(self, repo):
        """Sets a repo's stored status to :const:`TOMBSTONED`, in a transaction.

        Args:
          repo (Repo)
        """
        @ndb.transactional()
        def update():
            atp_repo = AtpRepo.get_by_id(repo.did)
            atp_repo.status = TOMBSTONED
            atp_repo.put()

        update()

    @ndb_context
    def read(self, cid):
        """Reads a single block.

        Args:
          cid (CID)

        Returns:
          Block: or None if the block isn't stored
        """
        block = AtpBlock.get_by_id(cid.encode('base32'))
        if block:
            return block.to_block()

    @ndb_context
    def read_many(self, cids):
        """Reads multiple blocks in one batch.

        Args:
          cids (iterable of CID)

        Returns:
          dict: maps each :class:`CID` to its :class:`Block`, or to None if
          the block isn't stored
        """
        # materialize first: cids is iterated twice below, so a one-shot
        # iterator would otherwise silently produce an empty result
        cids = list(cids)
        keys = [ndb.Key(AtpBlock, cid.encode('base32')) for cid in cids]
        return {cid: block.to_block() if block else None
                for cid, block in zip(cids, ndb.get_multi(keys))}

    # can't use @ndb_context because this is a generator, not a normal function
    def read_blocks_by_seq(self, start=0):
        """Generates blocks in ascending sequence number order.

        Stops quietly, after logging a warning, if the ndb context is lost
        while iterating.

        Args:
          start (int): minimum sequence number, inclusive; must be >= 0

        Yields:
          Block
        """
        assert start >= 0

        # lexrpc event subscription handlers like subscribeRepos call this on
        # a different thread, so if we're there, we need to create a new ndb
        # context
        context = get_context(raise_context_error=False)

        with (context.use() if context else self.ndb_client.context()):
            try:
                query = AtpBlock.query(AtpBlock.seq >= start).order(AtpBlock.seq)
                for atp_block in query:
                    yield atp_block.to_block()
            except ContextError as e:
                logger.warning(f'lost ndb context! client may have disconnected? "{e}"')

    @ndb_context
    def has(self, cid):
        """Returns True if a block with the given CID is stored, False otherwise.

        Args:
          cid (CID)
        """
        return self.read(cid) is not None

    @ndb_context
    def write(self, repo_did, obj, seq=None):
        """Stores a value as a new block, if it isn't already stored.

        Args:
          repo_did (str)
          obj (dict): value
          seq (int): optional; if None, a new sequence number is allocated
            for :const:`SUBSCRIBE_REPOS_NSID`

        Returns:
          CID: of the block
        """
        if seq is None:
            seq = self.allocate_seq(SUBSCRIBE_REPOS_NSID)

        return AtpBlock.create(repo_did=repo_did, data=obj, seq=seq).cid

    @ndb_context
    @ndb.transactional()
    def apply_commit(self, commit_data):
        """Stores a commit's blocks and updates the repo's head, in a transaction.

        The sequence number is derived from the commit's ``rev``. It's set on
        every :class:`Block` in ``commit_data.blocks``, and on every block
        that's newly stored. ``self.head`` is set to the commit's CID. If
        there's no stored :class:`AtpRepo` for the commit's ``did``, the
        stored head isn't updated.

        Args:
          commit_data: has ``commit``, the commit :class:`Block`, and
            ``blocks``, a dict whose values are the :class:`Block`\ s to store
        """
        commit = commit_data.commit.decoded
        seq = tid_to_int(commit['rev'])
        assert seq

        for block in commit_data.blocks.values():
            template = AtpBlock.from_block(repo_did=commit['did'], block=block)
            # get_or_insert so we don't wipe out any existing blocks' sequence
            # numbers. (occasionally we see existing blocks recur, eg MST nodes.)
            AtpBlock.get_or_insert(
                template.key.id(), repo=template.repo, encoded=block.encoded,
                seq=seq, ops=template.ops)
            block.seq = seq

        self.head = commit_data.commit.cid

        repo = AtpRepo.get_by_id(commit['did'])
        if repo:
            logger.info(f'Updating {repo.key}')
            repo.head = self.head.encode('base32')
            repo.put()

    @ndb_context
    def allocate_seq(self, nsid):
        """Allocates and returns the next sequence number for an NSID.

        Args:
          nsid (str): required

        Returns:
          int
        """
        assert nsid
        return AtpSequence.allocate(nsid)

    @ndb_context
    def last_seq(self, nsid):
        """Returns the last allocated sequence number for an NSID.

        Args:
          nsid (str): required

        Returns:
          int
        """
        assert nsid
        return AtpSequence.last(nsid)