"""Google Cloud Datastore implementation of repo storage."""
from datetime import timedelta, timezone
from functools import wraps
from io import BytesIO
import itertools
import json
import logging
import mimetypes
import os
import requests
import threading
import time
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec
from google.api_core import gapic_v1
import dag_cbor
import dag_json
import libipld
from google.cloud import ndb
from google.cloud.ndb import context
from google.cloud.ndb.exceptions import ContextError
from google.cloud.ndb.key import _MAX_KEYPART_BYTES
from google.cloud.datastore_v1.types import datastore as ds_pb2
from google.cloud.datastore_v1.types import entity as entity_pb2
from lexrpc import ValidationError
from multiformats import CID, multicodec, multihash
from pymediainfo import MediaInfo
from webutil.models import EncryptedProperty, WriteOnceBlobProperty
import webutil.util
from .mst import MST
from .repo import Repo
from .server import server
from . import storage
from .storage import Action, Block, Storage, SUBSCRIBE_REPOS_NSID
from . import util
from .util import (
dag_cbor_cid,
tid_to_int,
DEACTIVATED,
DELETED,
TOMBSTONED,
InactiveRepo,
)
logger = logging.getLogger(__name__)

# Already-fetched remote blobs whose type is in BLOB_REFETCH_TYPES are
# refetched once their last fetch is older than this. Configured in (possibly
# fractional) days via the BLOB_REFETCH_DAYS environment variable.
BLOB_REFETCH_AGE = timedelta(days=float(os.getenv('BLOB_REFETCH_DAYS', 7)))

# Top-level MIME types, ie the part before the slash, of remote blobs that we
# refetch. Comma-separated in the BLOB_REFETCH_TYPES environment variable.
BLOB_REFETCH_TYPES = tuple(os.getenv('BLOB_REFETCH_TYPES', 'image').split(','))

# Upper limit on remote blob size, in bytes, regardless of lexicon maxSize.
# https://github.com/bluesky-social/social-app/blob/8ac63d780d38c14f0963859dec5d123836adb913/src/lib/constants.ts#L191
BLOB_MAX_BYTES = int(os.getenv('BLOB_MAX_BYTES', 10 ** 8))

# https://bsky.app/profile/bsky.app/post/3lk26lxn6sk2u
VIDEO_MAX_DURATION = timedelta(seconds=3 * 60)

# MemcacheSequences reserves this many sequence numbers in the datastore at a
# time, and reserves the next batch once fewer than MEMCACHE_SEQUENCE_BUFFER of
# the current batch remain.
MEMCACHE_SEQUENCE_BATCH = int(os.getenv('MEMCACHE_SEQUENCE_BATCH', 1000))
MEMCACHE_SEQUENCE_BUFFER = int(os.getenv('MEMCACHE_SEQUENCE_BUFFER', 100))

# Explicit timeout for datastore reads, since ndb queries seem to have no
# default timeout.
# https://github.com/snarfed/bridgy-fed/issues/2367#issuecomment-3969792063
QUERY_TIMEOUT = timedelta(seconds=30)
[docs]
class CommitOp(ndb.Model):
    """Repo operations - creates, updates, deletes - included in a commit.

    Used in a :class:`StructuredProperty` inside :class:`AtpBlock`; not stored
    directly in the datastore.

    https://googleapis.dev/python/python-ndb/latest/model.html#google.cloud.ndb.model.StructuredProperty

    :meth:`AtpBlock.from_block` builds these from :class:`storage.CommitOp`
    objects, and :meth:`AtpBlock.to_block` converts them back.
    """
    # lowercased name of a storage.Action member, eg 'create'
    action = ndb.StringProperty(required=True, choices=('create', 'update', 'delete'))
    # repo path of the record this op affects (presumably collection/rkey;
    # copied as-is from storage.CommitOp.path)
    path = ndb.StringProperty(required=True)
    # both CIDs are stored as base32-encoded strings
    cid = ndb.StringProperty() # unset for deletes
    prev_cid = ndb.StringProperty() # unset for creates
[docs]
class AtpRepo(ndb.Model):
    r"""An ATProto repo.

    Key name is DID. Only stores the repo's metadata. Blocks are stored in
    :class:`AtpBlock`\s.

    Attributes:
    * handles (str): repeated, optional. The first one is used as the repo's
      handle when loading.
    * head (str): base32-encoded CID of the repo's head commit
    * signing_key (ec.EllipticCurvePrivateKey): read-only, parsed from
      ``encrypted_signing_key``
    * rotation_key (ec.EllipticCurvePrivateKey): read-only, parsed from
      ``encrypted_rotation_key``, or None if that's unset
    * status (str): one of DEACTIVATED, DELETED, or TOMBSTONED; unset
      (presumably meaning active) otherwise
    """
    handles = ndb.StringProperty(repeated=True)
    head = ndb.StringProperty(required=True)

    # these are both secp256k1 private keys, PEM-encoded bytes
    # https://atproto.com/specs/cryptography
    encrypted_signing_key = EncryptedProperty(required=True)
    encrypted_rotation_key = EncryptedProperty()

    status = ndb.StringProperty(choices=(DEACTIVATED, DELETED, TOMBSTONED))
    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    # OLD. Do not reuse.
    # signing_key_pem = ndb.BlobProperty()
    # rotation_key_pem = ndb.BlobProperty()

    @property
    def signing_key(self):
        """Parses :attr:`encrypted_signing_key` as an unencrypted PEM private key.

        Returns:
          ec.EllipticCurvePrivateKey:
        """
        return serialization.load_pem_private_key(self.encrypted_signing_key,
                                                  password=None)

    @property
    def rotation_key(self):
        """Parses :attr:`encrypted_rotation_key`, if set.

        Returns:
          ec.EllipticCurvePrivateKey or None: None if no rotation key is stored
        """
        if self.encrypted_rotation_key:
            return serialization.load_pem_private_key(self.encrypted_rotation_key,
                                                      password=None)
[docs]
class AtpBlock(ndb.Model):
    """A data record, MST node, repo commit, or other event.

    Key name is the DAG-CBOR base32 CID of the data.

    Events should have a fully-qualified ``$type`` field that's one of the
    ``message`` types in ``com.atproto.sync.subscribeRepos``, eg
    ``com.atproto.sync.subscribeRepos#tombstone``.

    Properties:
    * repo (google.cloud.ndb.Key): key of the first :class:`AtpRepo` that
      included this block. Its key name is that repo's DID.
    * encoded (bytes): DAG-CBOR encoded value
    * seq (int): sequence number for the subscribeRepos event stream
    * ops (CommitOps): operations if this is a commit
    """
    repo = ndb.KeyProperty(AtpRepo, required=True)
    encoded = WriteOnceBlobProperty(required=True)
    seq = ndb.IntegerProperty(required=True)
    ops = ndb.StructuredProperty(CommitOp, repeated=True)
    # stored naive; from_block converts incoming times to UTC before dropping
    # the tzinfo
    created = ndb.DateTimeProperty(auto_now_add=True)

    @property
    def decoded(self):
        """The DAG-CBOR decoded value of :attr:`encoded`."""
        return dag_cbor.decode(self.encoded)

    @property
    def cid(self):
        """CID: this block's CID, parsed from its key name."""
        return CID.decode(self.key.id())

    @staticmethod
    def create(*, repo_did, data, seq):
        """Writes a new AtpBlock to the datastore.

        If the block already exists in the datastore, leave it untouched.
        Notably, leave its sequence number as is, since it will be lower than
        this current sequence number.

        Args:
          repo_did (str):
          data (dict): value
          seq (int):

        Returns:
          :class:`AtpBlock`
        """
        assert seq > 0
        encoded = dag_cbor.encode(data)
        digest = multihash.digest(encoded, 'sha2-256')
        cid = CID('base58btc', 1, 'dag-cbor', digest)

        repo_key = ndb.Key(AtpRepo, repo_did)
        # get_or_insert only writes if no entity with this key exists yet
        atp_block = AtpBlock.get_or_insert(cid.encode('base32'), repo=repo_key,
                                           encoded=encoded, seq=seq)
        assert atp_block.seq <= seq
        return atp_block

    def to_block(self):
        """Converts to :class:`Block`.

        :attr:`Block.repo` is set to the repo's DID string, not its ndb key,
        so that :meth:`from_block` can round-trip it.

        Returns:
          Block
        """
        ops = [storage.CommitOp(action=Action[op.action.upper()], path=op.path,
                                cid=CID.decode(op.cid) if op.cid else None,
                                prev_cid=CID.decode(op.prev_cid) if op.prev_cid else None)
               for op in self.ops]
        # Block.repo is a DID string; from_block wraps it in ndb.Key(AtpRepo, ...)
        repo_did = self.repo.id() if self.repo else None
        return Block(cid=self.cid, encoded=self.encoded, seq=self.seq, ops=ops,
                     time=self.created, repo=repo_did)

    @classmethod
    def from_block(cls, block):
        """Converts a :class:`Block` to an :class:`AtpBlock`.

        Args:
          block (Block): its ``repo``, if set, is a DID string

        Returns:
          AtpBlock:
        """
        ops = [CommitOp(action=op.action.name.lower(), path=op.path,
                        cid=op.cid.encode('base32') if op.cid else None,
                        prev_cid=op.prev_cid.encode('base32') if op.prev_cid else None)
               for op in (block.ops or [])]
        # the datastore property is naive, so normalize to UTC first
        created = block.time.astimezone(timezone.utc).replace(tzinfo=None)
        repo_key = ndb.Key(AtpRepo, block.repo) if block.repo else None
        return AtpBlock(id=block.cid.encode('base32'), encoded=block.encoded,
                        repo=repo_key, seq=block.seq, ops=ops, created=created)
[docs]
class AtpSequence(ndb.Model):
    """A sequence number for a given event stream NSID.

    Sequence numbers are monotonically increasing, starting at 1.
    :class:`DatastoreSequences` allocates them one at a time, without gaps.
    :class:`MemcacheSequences` reserves them here in batches, so it may skip
    some, eg if its memcache counter is lost and reinitialized from this
    entity. ATProto doesn't require gapless sequences.

    Key name is XRPC method NSID.

    At first, I considered using datastore allocated ids for sequence numbers,
    but they're not guaranteed to be monotonically increasing, so I switched to
    this.
    """
    # the next sequence number to hand out. MemcacheSequences advances this by
    # whole batches, so there it's an upper bound on what's been allocated.
    next = ndb.IntegerProperty(required=True)
    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)
[docs]
class NdbMixin:
    """Mixin holding the ndb client settings that :func:`ndb_context` uses.

    Methods of subclasses decorated with :func:`ndb_context` read
    :attr:`ndb_client` and :attr:`ndb_context_kwargs` to open a new ndb context
    whenever none is already active.
    """
    ndb_client = None
    ndb_context_kwargs = None

    def __init__(self, *, ndb_client=None, ndb_context_kwargs=None, **kwargs):
        """Constructor.

        Args:
          ndb_client (google.cloud.ndb.Client): required despite the default;
            used when there isn't already an ndb context active
          ndb_context_kwargs (dict): optional, used when creating a new ndb
            context. Falsy values become an empty dict.
          kwargs: forwarded to the next ``__init__`` in the MRO
        """
        super().__init__(**kwargs)

        assert ndb_client
        self.ndb_client = ndb_client
        self.ndb_context_kwargs = ndb_context_kwargs if ndb_context_kwargs else {}
[docs]
def ndb_context(fn):
    """Decorator that runs a method inside an ndb context.

    Reuses the currently active context if there is one. Otherwise opens a new
    one from the instance's ``ndb_client``, passing ``ndb_context_kwargs``.
    Must be used on :class:`NdbMixin` subclasses' methods.
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        active = context.get_context(raise_context_error=False)
        if active:
            manager = active.use()
        else:
            manager = self.ndb_client.context(**self.ndb_context_kwargs)

        with manager:
            return fn(self, *args, **kwargs)

    return wrapper
[docs]
class DatastoreSequences(storage.Sequences, NdbMixin):
    """Sequence numbers allocated directly from the datastore.

    Each NSID's counter is a single :class:`AtpSequence` entity, which is read
    and incremented in its own transaction on every allocation.
    """

    @ndb_context
    # propagation=context.TransactionOptions.INDEPENDENT is important here so that we
    # don't include this in heavy, long-running commit transactions, since it's a
    # single-row bottleneck! (the default is join=True.)
    @ndb.transactional(propagation=context.TransactionOptions.INDEPENDENT, join=None)
    def allocate(self, nsid):
        """Returns the next sequence number for a given NSID.

        Creates the :class:`AtpSequence` entity, starting at 1, the first time
        an NSID is seen.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, next sequence number for this NSID
        """
        counter = AtpSequence.get_or_insert(nsid, next=1)
        allocated = counter.next
        counter.next = allocated + 1
        counter.put()

        logger.info(f'allocated seq {allocated}')
        return allocated

    @ndb_context
    def last(self, nsid):
        """Returns the last allocated sequence number for a given NSID.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, last sequence number for this NSID, or None if we don't know it
        """
        counter = AtpSequence.get_by_id(nsid)
        if not counter:
            return None
        return counter.next - 1
[docs]
class MemcacheSequences(storage.Sequences, NdbMixin):
    """Memcache-backed sequence numbers.

    Allocates sequence numbers from memcache for better performance, backed by
    :class:`AtpSequence`s in the datastore that allocate in batches.

    Attributes:
    * memcache (pymemcache.client.base.Client)
    * max_seqs (dict): maps string nsid to integer max sequence number, the lower
      bound on the AtpSequence's current value. this is the highest seq we can
      allocate from memcache without allocating a new batch from the datastore and
      updating the stored AtpSequence's value. Per instance.
    * max_seqs_lock (threading.Lock): for modifying max_seqs. Per instance.
    """
    memcache = None
    # initialized per instance in __init__. class-level mutable defaults here
    # would be shared, and mutated, by every instance.
    max_seqs = None
    max_seqs_lock = None

    def __init__(self, *, memcache, **kwargs):
        """Constructor.

        Args:
          memcache (pymemcache.client.base.Client): required
          kwargs: passed through, eg ``ndb_client`` for :class:`NdbMixin`
        """
        super().__init__(**kwargs)
        assert memcache
        self.memcache = memcache
        self.max_seqs = {}
        self.max_seqs_lock = threading.Lock()

    @ndb_context
    def allocate(self, nsid):
        """Allocates a single sequence number from memcache.

        The memcache key is ``[nsid]-last-seq``. Its value is the last sequence
        number we've allocated.

        If the memcache counter is missing, it's reinitialized from the stored
        :class:`AtpSequence`. Whenever fewer than
        :const:`MEMCACHE_SEQUENCE_BUFFER` numbers remain below
        :attr:`max_seqs`, a new batch of :const:`MEMCACHE_SEQUENCE_BATCH` is
        reserved in the datastore before returning.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, next sequence number for this NSID
        """
        assert MEMCACHE_SEQUENCE_BATCH > MEMCACHE_SEQUENCE_BUFFER > 1, \
            (MEMCACHE_SEQUENCE_BATCH, MEMCACHE_SEQUENCE_BUFFER)

        key = self._memcache_key(nsid)
        seq = self.memcache.incr(key, 1)

        if seq is None:  # not in memcache
            with self.max_seqs_lock:
                # can't use last() because it looks in memcache
                self.max_seqs[nsid] = AtpSequence.get_or_insert(nsid, next=1).next - 1
                # we'll allocate a new batch below

            # add only stores if the key is still missing, so a concurrent
            # initializer doesn't get clobbered
            if self.memcache.add(key, self.max_seqs[nsid]):
                logger.info(f'initialized memcache sequence counter {key} to {self.max_seqs[nsid]}')
            seq = self.memcache.incr(key, 1)

        @ndb.transactional(propagation=context.TransactionOptions.INDEPENDENT,
                           join=None)
        def alloc_batch():
            # re-read the stored value; another process may have already
            # reserved a batch beyond what we know about
            stored_seq = AtpSequence.get_or_insert(nsid, next=1)
            if stored_seq.next - seq < MEMCACHE_SEQUENCE_BUFFER:
                stored_seq.next = seq + MEMCACHE_SEQUENCE_BATCH
                logger.info(f'allocating {MEMCACHE_SEQUENCE_BATCH} seqs batch for {nsid}, up to {stored_seq.next}')
                stored_seq.put()
            self.max_seqs[nsid] = stored_seq.next

        with self.max_seqs_lock:
            if self.max_seqs.get(nsid, 0) - seq < MEMCACHE_SEQUENCE_BUFFER:
                alloc_batch()

        assert seq and seq <= self.max_seqs[nsid], (seq, self.max_seqs[nsid])
        logger.info(f'allocated seq {seq}')
        return seq

    @ndb_context
    def last(self, nsid):
        """Returns the last allocated sequence number for a given NSID.

        Args:
          nsid (str): the subscription XRPC method for this sequence number

        Returns:
          integer, last sequence number for this NSID, or None if we don't know it
        """
        val = self.memcache.get(self._memcache_key(nsid))
        return int(val) if val else None

    def _memcache_key(self, nsid):
        """Returns the sequence number memcache key for a given NSID."""
        assert nsid
        return f'{nsid}-last-seq'
[docs]
class AtpRemoteBlob(ndb.Model):
    """A blob available at a public HTTP URL that we don't store ourselves.

    Key ID is the URL, truncated if necessary.

    * https://atproto.com/specs/data-model#blob-type
    * https://atproto.com/specs/xrpc#blob-upload-and-download

    TODO:
    * follow redirects, use final URL as key id
    * abstract this in :class:`Storage`
    """
    url = ndb.TextProperty()
    'full length URL'
    # base32-encoded CID of the raw blob bytes
    cid = ndb.StringProperty()
    size = ndb.IntegerProperty()
    mime_type = ndb.StringProperty(required=True, default='application/octet-stream')
    repos = ndb.KeyProperty(repeated=True)

    # only populated if mime_type is image/* or video/*
    # used in images.aspectRatio in app.bsky.embed.images
    # and aspectRatio in app.bsky.embed.video
    width = ndb.IntegerProperty()
    height = ndb.IntegerProperty()

    # only populated if mime_type is video/*
    # used to enforce maximum duration
    duration = ndb.IntegerProperty()
    'in ms'

    last_fetched = ndb.DateTimeProperty(tzinfo=timezone.utc)
    status = ndb.StringProperty(choices=('inactive',))
    'None means active'

    created = ndb.DateTimeProperty(auto_now_add=True)
    updated = ndb.DateTimeProperty(auto_now=True)

    @classmethod
    def get_or_create(cls, *, url=None, repo=None, get_fn=webutil.util.session.get,
                      max_size=None, accept_types=None, name=''):
        """Returns a new or existing :class:`AtpRemoteBlob` for a given URL.

        If there isn't an existing :class:`AtpRemoteBlob`, or if the existing one
        needs to be reloaded, fetches the URL over the network.

        Args:
          url (str)
          repo (AtpRepo): optional
          get_fn (callable): for making HTTP GET requests
          max_size (int, optional): the ``maxSize`` parameter for this blob
            field in its lexicon, if any
          accept_types (sequence of str, optional): the ``accept`` parameter for
            this blob field in its lexicon, if any. The set of allowed MIME types.
          name (str, optional): blob field name in lexicon

        Returns:
          AtpRemoteBlob: existing or newly created blob

        Raises:
          requests.RequestException: if the HTTP request to fetch the blob
            failed, or the blob is already marked inactive
          lexrpc.ValidationError: if the blob is over ``max_size``, its type is
            not in ``accept_types`` or it is a video with a duration above the 3m
            limit
        """
        assert url

        url_key = url
        if len(url_key) > _MAX_KEYPART_BYTES:
            # TODO: handle Unicode chars. naive approach is to UTF-8 encode,
            # truncate, then decode, but that might cut mid character. easier to just
            # hope/assume the URL is already URL-encoded.
            url_key = url[:_MAX_KEYPART_BYTES]
            logger.warning(f'Truncating URL {url} to {_MAX_KEYPART_BYTES} chars: {url_key}')

        # if the blob already exists, just add this repo if necessary and return it
        @ndb.transactional()
        def get_or_insert():
            repos = [repo.key] if repo else []
            blob = cls.get_or_insert(url_key, repos=repos)
            blob.url = url
            if repo and repo.key not in blob.repos:
                blob.repos.append(repo.key)
            blob.put()
            return blob

        blob = get_or_insert()
        if blob.status:
            raise requests.HTTPError(f'Blob {url_key} is {blob.status}')

        blob.maybe_fetch(get_fn=get_fn)
        blob.validate(max_size=max_size, accept_types=accept_types, name=name)
        return blob

    def maybe_fetch(self, get_fn=requests.get):
        """Fetches the blob from its URL and updates its metadata, if necessary.

        Skips fetching if it's already been fetched and its top-level MIME type
        isn't in :const:`BLOB_REFETCH_TYPES`, or if it was last fetched less
        than :const:`BLOB_REFETCH_AGE` ago.

        Args:
          get_fn (callable, optional): for making HTTP GET requests

        Raises:
          requests.HTTPError: if the fetch failed. The blob is also marked
            inactive if the server returned 4xx or it has never been fetched
            successfully before.
          lexrpc.ValidationError: if the response's ``Content-Length`` is over
            :const:`BLOB_MAX_BYTES`
        """
        if ((self.cid or self.last_fetched)
                and self.mime_type.split('/')[0] not in BLOB_REFETCH_TYPES):
            # already fetched, and we don't refetch this type
            return
        elif (self.last_fetched
              and self.last_fetched >= webutil.util.now() - BLOB_REFETCH_AGE):
            # we've (re)fetched this recently
            return

        url = self.url or self.key.id()
        logger.info(f'(Re)fetching blob URL {url}')
        self.last_fetched = webutil.util.now()

        try:
            resp = get_fn(url, stream=True)
            # if this is our first try, give up if it's not serving.
            # otherwise, 4xx is conclusive; others like 5xx aren't
            if resp.status_code // 100 == 4 or (not resp.ok and not self.cid):
                logger.info('Marking blob inactive')
                self.status = 'inactive'
            resp.raise_for_status()
        except OSError as e:
            # requests.RequestException subclasses IOError (an alias of
            # OSError), so this also catches HTTPError from raise_for_status
            if not self.cid:
                self.status = 'inactive'
            raise requests.HTTPError(f"Couldn't fetch blob: {e}") from e
        finally:
            # persist last_fetched, and status if we changed it, either way
            self.put()

        # check type, size
        self.mime_type = (resp.headers.get('Content-Type')
                          or mimetypes.guess_type(url)[0]
                          or 'application/octet-stream')
        length = resp.headers.get('Content-Length')
        logger.info(f'Got {resp.status_code} {self.mime_type} {length} bytes {resp.url}')

        try:
            length = int(length)
        except (TypeError, ValueError):
            # no usable Content-Length. don't fall back to self.size here, since
            # it may be stale from a previous fetch. the body's actual length is
            # recorded below and checked in validate().
            length = None
        else:
            self.size = length

        if length and length > BLOB_MAX_BYTES:
            self.put()
            raise ValidationError(f'{url} Content-Length {length} is over BLOB_MAX_BYTES')

        # calculate CID and update blob
        content = resp.content
        digest = multihash.digest(content, 'sha2-256')
        self.cid = CID('base58btc', 1, 'raw', digest).encode('base32')
        self.size = len(content)
        self.generate_metadata(content)
        self.status = None
        self.put()

    def generate_metadata(self, content):
        """Populates :attr:`width`, :attr:`height`, and :attr:`duration`.

        Only ``image/*`` and ``video/*`` blobs get dimensions, and only
        ``video/*`` blobs get a duration; these fields are cleared for other
        types. Best effort: if the content can't be parsed, logs a warning and
        leaves them unset. Doesn't write to the datastore.

        Args:
          content (bytes): the blob's contents
        """
        self.width = self.height = self.duration = None

        main_type = (self.mime_type or '').split('/')[0]
        if main_type not in ('image', 'video'):
            return

        try:
            media_info = MediaInfo.parse(BytesIO(content))
        except Exception as e:
            # this metadata is optional, and parsing arbitrary remote content
            # with pymediainfo/libmediainfo can fail in many ways
            logger.warning(f"Couldn't parse {self.mime_type} metadata from {self.url}: {e}")
            return

        track_type = main_type.capitalize()  # MediaInfo's 'Image' or 'Video'
        for track in media_info.tracks:
            if track.track_type != track_type:
                continue
            if track.width and track.height:
                self.width = int(track.width)
                self.height = int(track.height)
            if main_type == 'video' and track.duration:
                # MediaInfo durations are in ms, sometimes with a fraction
                self.duration = int(float(track.duration))
            break

    def as_object(self):
        """Returns an ATProto ``blob`` object for this blob.

        https://atproto.com/specs/data-model#blob-type

        Returns:
          dict or None: with ``$type: blob`` and ``ref``, ``mimeType``, and
          ``size`` fields. If :attr:`cid` is unset, returns None
        """
        if self.cid:
            return {
                '$type': 'blob',
                'ref': CID.decode(self.cid),
                'mimeType': self.mime_type,
                'size': self.size,
            }

    def validate(self, max_size=None, accept_types=None, name=''):
        """Checks that this blob satisfies size and type constraints.

        Args:
          max_size (int, optional): the ``maxSize`` parameter for this blob
            field in its lexicon, if any
          accept_types (sequence of str, optional): the ``accept`` parameter for
            this blob field in its lexicon, if any. The set of allowed MIME types.
          name (str, optional): blob field name in lexicon

        Raises:
          lexrpc.ValidationError: if the size, type, or video duration is over
            its limit
        """
        url = self.url or self.key.id()
        server.validate_mime_type(self.mime_type, accept_types, name=name)

        if self.size:
            if self.size > BLOB_MAX_BYTES:
                raise ValidationError(f'{url} size {self.size} is over BLOB_MAX_BYTES')
            elif max_size and self.size > max_size:
                raise ValidationError(f'{url} size {self.size} is over {name} blob maxSize {max_size}')

        if self.duration and timedelta(milliseconds=self.duration) > VIDEO_MAX_DURATION:
            raise ValidationError(f'{url} duration {self.duration / 1000}s is over limit {VIDEO_MAX_DURATION}')
[docs]
class DatastoreStorage(Storage, NdbMixin):
    """Google Cloud Datastore implementation of :class:`Storage`.

    Sequence numbers in :class:`AtpBlock`s are allocated per commit; all blocks
    in a given commit will have the same sequence number.

    See :class:`Storage` for method details.
    """
    def __init__(self, *, sequences=None, ndb_client=None, ndb_context_kwargs=None):
        """Constructor.

        Args:
          sequences (Sequences): optional; defaults to a :class:`DatastoreSequences`
          ndb_client (google.cloud.ndb.Client): used when there isn't already
            an ndb context active
          ndb_context_kwargs (dict): optional, used when creating a new ndb context
        """
        if not sequences:
            sequences = DatastoreSequences(ndb_client=ndb_client,
                                           ndb_context_kwargs=ndb_context_kwargs)

        super().__init__(sequences=sequences, ndb_client=ndb_client,
                         ndb_context_kwargs=ndb_context_kwargs)

    @staticmethod
    def _private_key_pem(key):
        """Serializes a private key to unencrypted PKCS#8 PEM bytes.

        Args:
          key (ec.EllipticCurvePrivateKey)

        Returns:
          bytes:
        """
        return key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    @ndb_context
    def create_repo(self, repo):
        assert repo.did
        assert repo.head

        handles = [repo.handle] if repo.handle else []
        signing_key_pem = self._private_key_pem(repo.signing_key)
        rotation_key_pem = (self._private_key_pem(repo.rotation_key)
                            if repo.rotation_key else None)

        atp_repo = AtpRepo(id=repo.did, handles=handles,
                           head=repo.head.cid.encode('base32'),
                           encrypted_signing_key=signing_key_pem,
                           encrypted_rotation_key=rotation_key_pem,
                           status=repo.status)
        atp_repo.put()
        logger.debug(f'Stored repo {atp_repo}')

    @ndb_context
    def load_repo(self, did_or_handle):
        assert did_or_handle
        atp_repo = (AtpRepo.get_by_id(did_or_handle)
                    or AtpRepo.query(AtpRepo.handles == did_or_handle).get())

        if not atp_repo:
            logger.debug(f"Couldn't find repo for {did_or_handle}")
            return None

        logger.debug(f'Loading repo {atp_repo.key}')
        handle = atp_repo.handles[0] if atp_repo.handles else None
        # the datastore returns naive datetimes; they're UTC
        created = atp_repo.created.replace(tzinfo=timezone.utc) if atp_repo.created else None
        return Repo.load(self, cid=CID.decode(atp_repo.head), handle=handle,
                         status=atp_repo.status, signing_key=atp_repo.signing_key,
                         rotation_key=atp_repo.rotation_key, created=created)

    @ndb_context
    def load_repos(self, after=None, limit=500, minimal=False):
        query = AtpRepo.query()
        if after:
            query = query.filter(AtpRepo.key > AtpRepo(id=after).key)

        # duplicates parts of Repo.load but batches reading blocks from storage
        atp_repos = query.fetch(limit=limit)

        cids = [CID.decode(r.head) for r in atp_repos]
        blocks = self.read_many(cids)  # dict mapping CID to block
        heads = [blocks[cid] for cid in cids]

        if minimal:
            return [Repo(storage=self, head=head, status=atp_repo.status)
                    for atp_repo, head in zip(atp_repos, heads)]

        # MST.load doesn't read from storage
        msts = [MST.load(storage=self, cid=block.decoded['data']) for block in heads]
        return [Repo(storage=self, mst=mst, head=head, status=atp_repo.status,
                     handle=atp_repo.handles[0] if atp_repo.handles else None,
                     signing_key=atp_repo.signing_key,
                     rotation_key=atp_repo.rotation_key)
                for atp_repo, head, mst in zip(atp_repos, heads, msts)]

    @ndb_context
    def _set_repo_status(self, repo, status):
        assert status in (DEACTIVATED, DELETED, TOMBSTONED, None)
        repo.status = status  # in memory only

        @ndb.transactional(join=True)
        def update():
            atp_repo = AtpRepo.get_by_id(repo.did)
            atp_repo.status = status
            atp_repo.put()

        update()

    @ndb_context
    def store_repo(self, repo):
        # @ndb_context like every other storage method, so this works outside
        # an existing ndb context too. join=True makes the write part of the
        # caller's transaction, eg a commit, if there is one.
        @ndb.transactional(join=True)
        def store():
            atp_repo = AtpRepo.get_by_id(repo.did)
            atp_repo.handles = [repo.handle] if repo.handle else []
            atp_repo.status = repo.status
            atp_repo.head = repo.head.cid.encode('base32')
            atp_repo.put()
            logger.debug(f'Stored repo {atp_repo}')

        store()

    @ndb_context
    @ndb.non_transactional()
    def read(self, cid):
        block = AtpBlock.get_by_id(cid.encode('base32'))
        if block:
            return block.to_block()

    @ndb_context
    @ndb.non_transactional()
    def read_many(self, cids):
        # defensive copy in case cids is a generator
        cids = list(cids)
        keys = [ndb.Key(AtpBlock, cid.encode('base32')) for cid in cids]
        return {cid: block.to_block() if block else None
                for cid, block in zip(cids, ndb.get_multi(keys))}

    @ndb_context
    @ndb.non_transactional()
    def read_many_raw(self, cids):
        """Fetches blocks via raw Datastore gRPC, returning only encoded bytes and seq.

        Bypasses ndb model instantiation to reduce CPU overhead. Returns the raw
        DAG-CBOR bytes directly from the Datastore protobuf with no decode/re-encode.

        Args:
          cids (sequence of bytes CIDs): raw binary CID bytes

        Returns:
          dict: {bytes CID: ``(encoded bytes, seq int)``}. CIDs that aren't
          found are omitted, so use ``.get()`` or ``in`` to check for them.
        """
        cid_list = list(cids)
        result = {}

        ctx = context.get_context()
        client = ctx.client
        database = client.database or ''
        header_params = {'project_id': client.project}
        if database:
            header_params['database_id'] = database
        metadata = (gapic_v1.routing_header.to_grpc_metadata(header_params),)
        timeout = QUERY_TIMEOUT.total_seconds()

        # Datastore key.name is a string, so encode binary CIDs to base32 once.
        # build a reverse map so we can look up the original bytes from the
        # base32 name in the response. (response order isn't guaranteed.)
        #
        # can't easily use libipld.decode_cid when processing results below because
        # that only returns raw SHA-256 digest bytes, not full binary CIDs.
        key_to_cid = {}
        pending = []
        for cid in cid_list:
            name = libipld.encode_cid(cid)
            if name in key_to_cid:
                # duplicate; the result is keyed by CID, so one lookup is enough
                continue
            key_to_cid[name] = cid
            pending.append(entity_pb2.Key(
                partition_id=entity_pb2.PartitionId(**header_params),
                path=[entity_pb2.Key.PathElement(kind='AtpBlock', name=name)]))

        # Lookup may defer some keys; keep re-requesting those until none remain
        while pending:
            futures = [
                client.stub.lookup.future(
                    # 1000 is the Datastore API limit per request
                    ds_pb2.LookupRequest(keys=pending[i:i + 1000], **header_params),
                    timeout=timeout,
                    metadata=metadata,
                )
                # TODO: switch to itertools.batched once Cloud Profiler supports
                # Python 3.12 and we don't have to maintain 3.11 compatibility
                for i in range(0, len(pending), 1000)
            ]
            pending = []
            for future in futures:
                response = future.result()
                for entity_result in response.found:
                    entity = entity_result.entity
                    name = entity.key.path[-1].name
                    encoded = bytes(entity.properties['encoded'].blob_value)
                    seq = entity.properties['seq'].integer_value
                    result[key_to_cid[name]] = (encoded, seq)
                pending.extend(response.deferred)

        return result

    # can't use @ndb_context because this is a generator, not a normal function
    def read_blocks_by_seq(self, start=0, repo=None):
        assert start >= 0

        cur_seq = start
        cur_seq_cids = []

        while True:
            # lexrpc event subscription handlers like subscribeRepos call this
            # on a different thread, so if we're there, we need to create a new
            # ndb context
            ctx = context.get_context(raise_context_error=False)
            with (ctx.use() if ctx
                  else self.ndb_client.context(**self.ndb_context_kwargs)):
                try:
                    query = AtpBlock.query(AtpBlock.seq >= cur_seq).order(AtpBlock.seq)
                    if repo:
                        query = query.filter(AtpBlock.repo == AtpRepo(id=repo).key)
                    # unproven hypothesis: need strong consistency to make sure we
                    # get all blocks for a given seq, including commit
                    # https://console.cloud.google.com/errors/detail/CO2g4eLG_tOkZg;service=atproto-hub;time=P1D;refresh=true;locations=global?project=bridgy-federated
                    #
                    # also: ndb queries (via gRPC) seem to have no default timeout,
                    # so in rare cases, it's maybe possible that this query can hang
                    # indefinitely? so we set an explicit timeout. background:
                    # https://github.com/snarfed/bridgy-fed/issues/2327
                    for atp_block in query.iter(read_consistency=ndb.STRONG,
                                                timeout=QUERY_TIMEOUT.total_seconds()):
                        if atp_block.seq != cur_seq:
                            cur_seq = atp_block.seq
                            cur_seq_cids = []
                        # skip blocks already yielded for this seq before a restart
                        if atp_block.key.id() not in cur_seq_cids:
                            cur_seq_cids.append(atp_block.key.id())
                            yield atp_block.to_block()

                    # finished cleanly
                    break

                except ContextError as e:
                    logger.warning(f'lost ndb context! re-querying at {cur_seq}. {e}')
                    # continue loop, restart query

                    # Context.use() resets this to the previous context when it exits,
                    # but that context is bad now, so make sure we get a new one at the
                    # top of the loop
                    context._state.context = context._state.toplevel_context = None

    @ndb_context
    @ndb.non_transactional()
    def has(self, cid):
        return self.read(cid) is not None

    @ndb_context
    @ndb.non_transactional()
    def write(self, repo_did, obj, seq=None):
        if seq is None:
            seq = self.sequences.allocate(SUBSCRIBE_REPOS_NSID)
        return AtpBlock.create(repo_did=repo_did, data=obj, seq=seq).to_block()

    @ndb_context
    @ndb.non_transactional()
    def write_blocks(self, blocks):
        # defensive copy in case blocks is a generator, since we iterate it more
        # than once below; otherwise the second pass would silently write nothing
        blocks = list(blocks)
        if not blocks:
            return

        block_ids = [b.cid.encode('base32') for b in blocks]
        keys = [ndb.Key(AtpBlock, block_id) for block_id in block_ids]
        existing = AtpBlock.query(AtpBlock.key.IN(keys)).fetch(keys_only=True)
        existing_ids = {key.id() for key in existing}

        new = [AtpBlock.from_block(block)
               for block, block_id in zip(blocks, block_ids)
               if block_id not in existing_ids]
        logger.debug(f' {len(new)} new {len(existing)} existing')
        ndb.transaction(lambda: ndb.put_multi(new), join=True)

    @ndb_context
    @ndb.transactional(retries=10)
    def _commit(self, *args, **kwargs):
        """Just runs :meth:`Storage._commit` in a transaction."""
        return super()._commit(*args, **kwargs)