Source code for arroba.mst

"""Bluesky / AT Protocol Merkle search tree implementation.

* https://atproto.com/guides/data-repos
* https://atproto.com/lexicons/com-atproto-sync
* https://hal.inria.fr/hal-02303490/document

Heavily based on:
https://github.com/bluesky-social/atproto/blob/main/packages/repo/src/mst/mst.ts

Huge thanks to the Bluesky team for working in the public, in open source, and to
Daniel Holmgren and Devin Ivy for this code specifically!

From that file:

This is an implementation of a Merkle Search Tree (MST)
The data structure is described here: https://hal.inria.fr/hal-02303490/document
The MST is an ordered, insert-order-independent, deterministic tree.
Keys are laid out in alphabetic order.
The key insight of an MST is that each key is hashed and starting 0s are counted
to determine which layer it falls on (5 zeros for ~32 fanout).
This is a merkle tree, so each subtree is referred to by it's hash (CID).
When a leaf is changed, ever tree on the path to that leaf is changed as well,
thereby updating the root hash.

For atproto, we use SHA-256 as the key hashing algorithm, and ~4 fanout
(2-bits of zero per layer).

A couple notes on CBOR encoding:

There are never two neighboring subtrees.
Therefore, we can represent a node as an array of
leaves & pointers to their right neighbor (possibly null),
along with a pointer to the left-most subtree (also possibly null).

Most keys in a subtree will have overlap.
We do compression on prefixes by describing keys as:
* the length of the prefix that it shares in common with the preceding key
* the rest of the string

For example:

If the first leaf in a tree is ``bsky/posts/abcdefg`` and the second is
``bsky/posts/abcdehi``, then the first will be described as ``prefix: 0, key:
'bsky/posts/abcdefg'``, and the second will be described as ``prefix: 16, key:
'hi'``.
"""
from collections import namedtuple
import copy
from hashlib import sha256
import logging
from os.path import commonprefix
import re

import dag_cbor
from multiformats import CID

from .storage import Block, Storage
from .util import dag_cbor_cid

logger = logging.getLogger(__name__)

# this is treeEntry in mst.ts
Entry = namedtuple('Entry', [
    'p',  # int, length of prefix that this key shares with the prev key
    'k',  # bytes, the rest of the key outside the shared prefix
    'v',  # str CID, value
    't',  # str CID, next subtree (to the right of leaf), or None
])

Data = namedtuple('Data', [
    'l',  # str CID, left-most subtree, or None
    'e',  # list of Entry
])

Leaf = namedtuple('Leaf', [
    'key',    # str, record key
    'value',  # CID
])


[docs]class MST: """Merkle search tree class. Attributes: storage (Storage): entries (sequence of MST and Leaf) layer (int): this MST's layer in the root MST pointer (CID): outdated_pointer (bool): whether pointer needs to be recalculated """ storage = None entries = None layer = None pointer = None outdated_pointer = False
[docs] def __init__(self, *, storage=None, entries=None, pointer=None, layer=None): """Constructor. Args: storage (Storage) entries (sequence of MST and Leaf) pointer (CID) layer (int) Returns: MST: """ self.storage = storage self.entries = entries self.pointer = pointer self.layer = layer
@classmethod def load(cls, *, storage=None, cid=None): return MST(storage=storage, entries=None, pointer=cid, layer=None)
[docs] @classmethod def create(cls, *, storage=None, entries=None, layer=None): """ Args: storage (Storage) entries (sequence of MST and Leaf) layer (int) Returns: MST """ if not entries: entries = [] pointer = cid_for_entries(entries) return MST(storage=storage, entries=entries, pointer=pointer, layer=layer)
# def from_data(storage, data, opts): # """ # Returns: # MST: # """ # entries = deserialize_node_data(data) # pointer = cid_for_cbor(data) # return MST(entries=entries, pointer=pointer) def __eq__(self, other): if isinstance(other, MST): return self.get_pointer() == other.get_pointer() def __unicode__(self): return f'MST with pointer {self.get_pointer()}' def __repr__(self): return f'MST(storage={self.storage}, entries=..., pointer={self.get_pointer()}, layer={self.get_layer()})' # Immutability # -------------------
[docs] def new_tree(self, entries): """We never mutate an MST, we just return a new MST with updated values. Args: entries (sequence of MST and Leaf) Returns: MST: """ mst = MST(storage=self.storage, entries=entries, pointer=self.pointer, layer=self.layer) mst.outdated_pointer = True return mst
# Getters (lazy load) # -------------------
[docs] def get_entries(self): """ We don't want to load entries of every subtree, just the ones we need. Returns: sequence of MST and Leaf: """ if self.entries is not None: return copy.copy(self.entries) if self.pointer: data = Data(**self.storage.read(self.pointer).decoded) first_leaf = layer = None if data.e: layer = leading_zeros_on_hash(data.e[0]['k']) self.entries = deserialize_node_data(storage=self.storage, data=data, layer=layer) return self.entries raise RuntimeError('No entries or CID provided')
[docs] def get_pointer(self): """Returns this MST's root CID pointer. Calculates it if necessary. We don't hash the node on every mutation for performance reasons. Instead we keep track of whether the pointer is outdated and only (recursively) calculate when needed. Returns: CID: """ if not self.outdated_pointer: return self.pointer outdated = False entries = self.get_entries() for e in entries: if isinstance(e, MST) and e.outdated_pointer: outdated = True e.get_pointer() if outdated: entries = self.get_entries() self.pointer = cid_for_entries(entries) self.outdated_pointer = False return self.pointer
[docs] def get_layer(self): """Returns this MST's layer, and sets ``self.layer``. In most cases, we get the layer of a node from a hint on creation. In the case of the topmost node in the tree, we look for a key in the node & determine the layer. In the case where we don't find one, we recurse down until we do. If we still can't find one, then we have an empty tree and the node is layer 0. Returns: int: """ self.layer = self.attempt_get_layer() if self.layer is None: self.layer = 0 return self.layer
[docs] def attempt_get_layer(self): """Returns this MST's layer, and sets ``self.layer``. Returns: int or None: """ if self.layer is not None: return self.layer entries = self.get_entries() layer = layer_for_entries(entries) if layer is None: for entry in entries: if isinstance(entry, MST): child_layer = entry.attempt_get_layer() if child_layer is not None: layer = child_layer + 1 break if layer is not None: self.layer = layer return layer
# Core functionality # -------------------
[docs] def get_unstored_blocks(self): """Return the necessary blocks to persist the MST to repo storage. Returns: (CID root, dict mapping CID to Block) tuple: """ unstored = {} pointer = self.get_pointer() if self.storage.has(pointer): return pointer, unstored entries = self.get_entries() data = serialize_node_data(entries) block = Block(decoded=data._asdict()) unstored[block.cid] = block for entry in entries: if isinstance(entry, MST): _, blocks = entry.get_unstored_blocks() unstored.update(blocks) return pointer, unstored
[docs] def add(self, key, value=None, known_zeros=None): """Adds a new leaf for the given key/value pair. Args: key (str) value (CID) known_zeros (int) Returns: MST: Raises: ValueError: if a leaf with that key already exists """ ensure_valid_key(key) key_zeros = known_zeros or leading_zeros_on_hash(key) layer = self.get_layer() new_leaf = Leaf(key=key, value=value) if key_zeros == layer: # it belongs in self layer index = self.find_gt_or_equal_leaf_index(key) found = self.at_index(index) if isinstance(found, Leaf) and found.key == key: raise ValueError(f'There is already a value at key: {key}') prev_node = self.at_index(index - 1) if not prev_node or isinstance(prev_node, Leaf): # if entry before is a leaf, (or we're on far left) we can just splice in return self.splice_in(new_leaf, index) else: # else we try to split the subtree around the key left, right = prev_node.split_around(key) return self.replace_with_split(index - 1, left, new_leaf, right) elif key_zeros < layer: # it belongs on a lower layer index = self.find_gt_or_equal_leaf_index(key) prev_node = self.at_index(index - 1) if prev_node and isinstance(prev_node, MST): # if entry before is a tree, we add it to that tree new_subtree = prev_node.add(key, value, key_zeros) return self.update_entry(index - 1, new_subtree) else: sub_tree = self.create_child() new_subtree = sub_tree.add(key, value, key_zeros) return self.splice_in(new_subtree, index) else: # key_zeros > layer # it belongs on a higher layer, push the rest of the tree down left, right = self.split_around(key) # if the newly added key has >=2 more leading zeros than the current # highest layer then we need to add structural nodes between as well layer = self.get_layer() extra_layers_to_add = key_zeros - layer # intentionally starting at 1, first layer is taken care of by split for i in range(1, extra_layers_to_add): if left: left = left.create_parent() if right: right = right.create_parent() updated = [] if left: updated.append(left) updated.append(Leaf(key=key, value=value)) if right: updated.append(right) new_root = MST.create(storage=self.storage, entries=updated, layer=key_zeros) new_root.outdated_pointer = True return new_root
[docs] def get(self, key): """Gets the value at the given key. Args: key (str) Returns: CID or None: """ index = self.find_gt_or_equal_leaf_index(key) found = self.at_index(index) if found and isinstance(found, Leaf) and found.key == key: return found.value prev = self.at_index(index - 1) if prev and isinstance(prev, MST): return prev.get(key)
[docs] def update(self, key, value): """Edits the value at the given key. Args: key (str) value (CID) Returns: MST: Raises: KeyError: if key doesn't exist """ ensure_valid_key(key) index = self.find_gt_or_equal_leaf_index(key) found = self.at_index(index) if found and isinstance(found, Leaf) and found.key == key: return self.update_entry(index, Leaf(key=key, value=value)) prev = self.at_index(index - 1) if prev and isinstance(prev, MST): updated_tree = prev.update(key, value) return self.update_entry(index - 1, updated_tree) raise KeyError(f'Could not find a record with key: {key}')
[docs] def delete(self, key): """Deletes the value at the given key. Args: key (str) Returns: MST Raises: KeyError: if key doesn't exist """ return self.delete_recurse(key).trim_top()
[docs] def delete_recurse(self, key): """Deletes the value and subtree, if any, at the given key. Args: key (str): Returns: MST """ index = self.find_gt_or_equal_leaf_index(key) found = self.at_index(index) # if found, remove it on self level if isinstance(found, Leaf) and found.key == key: prev = self.at_index(index - 1) next = self.at_index(index + 1) if isinstance(prev, MST) and isinstance(next, MST): merged = prev.append_merge(next) return self.new_tree( self.slice(0, index - 1) + [merged] + self.slice(index + 2) ) else: return self.remove_entry(index) # else recurse down to find it prev = self.at_index(index - 1) if isinstance(prev, MST): subtree = prev.delete_recurse(key) if subtree.get_entries(): return self.update_entry(index - 1, subtree) else: return self.remove_entry(index - 1) raise KeyError(f'Could not find a record with key: {key}')
# Simple Operations # -------------------
[docs] def update_entry(self, index, entry): """Updates an entry in place. Args: index (int) entry (MST or Leaf) Returns: MST: """ return self.new_tree( entries=self.slice(0, index) + [entry] + self.slice(index + 1))
[docs] def remove_entry(self, index): """Removes the entry at a given index. Args: index (int) Returns: MST: """ return self.new_tree(entries=self.slice(0, index) + self.slice(index + 1))
[docs] def append(self, entry): """Appends an entry to the end of the node. Args: entry (MST or Leaf) Returns: MST: """ return self.new_tree(self.get_entries() + [entry])
[docs] def prepend(self, entry): """Prepends an entry to the start of the node. Args: entry (MST or Leaf) Returns: MST: """ return self.new_tree([entry] + self.get_entries())
[docs] def at_index(self, index): """Returns the entry at a given index. Args: index (int) Returns: MST or Leaf or None: """ entries = self.get_entries() if 0 <= index < len(entries): return entries[index]
[docs] def slice(self, start=None, end=None): """Returns a slice of this node. Args: start (int): optional, inclusive end (int): optional, exclusive Returns: sequence of MST and Leaf: """ return self.get_entries()[start:end]
[docs] def splice_in(self, entry, index): """Inserts an entry at a given index. Args: entry (MST or Leaf) index (int) Returns: MST: """ return self.new_tree(self.slice(0, index) + [entry] + self.slice(index))
[docs] def replace_with_split(self, index, left=None, leaf=None, right=None): """Replaces an entry with [ Maybe(tree), Leaf, Maybe(tree) ]. Args: index (int): left (MST or Leaf): leaf (Leaf): right (MST or Leaf): Returns: MST: """ updated = self.slice(0, index) if left: updated.append(left) updated.append(leaf) if right: updated.append(right) updated.extend(self.slice(index + 1)) return self.new_tree(updated)
[docs] def trim_top(self): """Trims the top and return its subtree, if necessary. Only if the topmost node in the tree only points to another tree. Otherwise, does nothing. Returns: MST: """ entries = self.get_entries() if len(entries) == 1 and isinstance(entries[0], MST): return entries[0].trim_top() else: return self
# Subtree & Splits # -------------------
[docs] def split_around(self, key): """Recursively splits a subtree around a given key. Args: key (str) Returns: (MST or None, MST or None) tuple: """ index = self.find_gt_or_equal_leaf_index(key) # split tree around key left_data = self.slice(0, index) right_data = self.slice(index) left = self.new_tree(left_data) right = self.new_tree(right_data) # if the far right of the left side is a subtree, # we need to split it on the key as well last_in_left = left_data[-1] if left_data else None if isinstance(last_in_left, MST): left = left.remove_entry(len(left_data) -1) split = last_in_left.split_around(key) if split[0]: left = left.append(split[0]) if split[1]: right = right.prepend(split[1]) return [ left if left.get_entries() else None, right if right.get_entries() else None, ]
[docs] def append_merge(self, to_merge): """Merges another tree with this one. The simple merge case where every key in the right tree is greater than every key in the left tree. Used primarily for deletes. Args: to_merge (MST) Returns: MST: """ assert self.get_layer() == to_merge.get_layer(), \ 'Trying to merge two nodes from different layers of the MST' self_entries = self.get_entries() to_merge_entries = to_merge.get_entries() last_in_left = self_entries[-1] first_in_right = to_merge_entries[0] if isinstance(last_in_left, MST) and isinstance(first_in_right, MST): merged = last_in_left.append_merge(first_in_right) return self.new_tree( list(self_entries[:-1]) + [merged] + to_merge_entries[1:]) else: return self.new_tree(self_entries + to_merge_entries)
# Create relatives # -------------------
[docs] def create_child(self): """ Returns: MST: """ return MST.create(storage=self.storage, entries=[], layer=self.get_layer() - 1)
[docs] def create_parent(self): """ Returns: MST: """ parent = MST.create(storage=self.storage, entries=[self], layer=self.get_layer() + 1) parent.outdated_pointer = True return parent
# Finding insertion points # -------------------
[docs] def find_gt_or_equal_leaf_index(self, key): """Finds the index of the first leaf node greater than or equal to value. Args: key (str) Returns: int: """ entries = self.get_entries() for i, entry in enumerate(entries): if isinstance(entry, Leaf) and entry.key >= key: return i # if we can't find it, we're on the end return len(entries)
# List operations (partial tree traversal) # -------------------
[docs] def walk_leaves_from(self, key): """Walk tree starting at key. Generator for leaves in the tree, starting at a given rkey. Args: key (str): Generates: Leaf """ index = self.find_gt_or_equal_leaf_index(key) entries = self.get_entries() if index > 0: prev = entries[index - 1] if prev and isinstance(prev, MST): for e in prev.walk_leaves_from(key): yield e for entry in entries[index:]: if isinstance(entry, Leaf): yield entry else: for e in entry.walk_leaves_from(key): yield e
[docs] def list(self, after=None, before=None): """Returns entries, optionally bounded within an rkey range. Args: after (str): rkey, optional before (str): rkey, optional Returns: sequence of Leaf: """ vals = [] for leaf in self.walk_leaves_from(after or ''): if leaf.key == after: continue if before and leaf.key >= before: break vals.append(leaf) return vals
[docs] def list_with_prefix(self, prefix): """Returns entries with a given rkey prefix. Args: prefix (str): rkey prefix Returns: sequence of Leaf """ vals = [] for leaf in self.walk_leaves_from(prefix): if not leaf.key.startswith(prefix): break vals.append(leaf) return vals
# Full tree traversal # -------------------
[docs] def walk(self): """Walk full tree, depth first, and emit nodes. Returns: generator of MST and Leaf: """ yield self for entry in self.get_entries(): if isinstance(entry, MST): for e in entry.walk(): yield e else: yield entry
# Walk full tree & emit nodes, consumer can bail at any point by returning False # def paths(): # """ # Returns: # sequence of MST and Leaf # """ # paths = [] # for entry in self.get_entries(): # if isinstance(entry, Leaf): # paths.append([entry]) # if isinstance(entry, MST): # sub_paths = entry.paths() # paths.extend([entry] + p for p in sub_paths) # # return paths
[docs] def all_nodes(self): """Walks the tree and returns all nodes. Returns: sequence of MST and Leaf: """ return list(self.walk())
# Walks tree & returns all cids # def all_cids(): # """ # Returns: # CidSet # """ # cids = CidSet() # for entry in self.get_entries(): # if isinstance(entry, Leaf): # cids.add(entry.value) # else: # subtree_cids = entry.all_cids() # cids.add_set(subtree_cids) # cids.add(self.get_pointer()) # return cids
[docs] def leaves(self): """Walks tree and returns all leaves. Returns: sequence of Leaf: """ return [entry for entry in self.walk() if isinstance(entry, Leaf)]
[docs] def leaf_count(self): """Returns the total number of leaves in this MST. Returns: int: """ return len(self.leaves())
# Reachable tree traversal # ------------------- # Walk reachable branches of tree & emit nodes, consumer can bail at any # point by returning False # def walk_reachable(): AsyncIterable<NodeEntry>: # yield self # for entry in self.get_entries(): # if isinstance(entry, MST): # try: # for e in entry.walk_reachable(): # yield e # catch (err): # if err instanceof MissingBlockError: # continue # else: # raise err # else: # yield entry # def reachable_leaves(): # """ # Returns: # Leaf[] # """ # leaves: Leaf[] = [] # for entry in self.walk_reachable(): # if isinstance(entry, Leaf): # leaves.append(entry) # return leaves # Sync Protocol
[docs] def load_all(self): """Generator. Used in :func:`xrpc_sync.get_checkout`. (The bluesky-social/atproto TS code calls this ``writeToCarStream``.) Returns: generator of (CID, bytes) tuples """ leaves = set() # CIDs to_fetch = set() # CIDs pointer = self.get_pointer() assert pointer to_fetch.add(pointer) while to_fetch: blocks = self.storage.read_many(to_fetch) to_fetch.clear() for cid, block in blocks.items(): yield cid, block.encoded entries = deserialize_node_data(storage=self.storage, data=Data(**block.decoded)) for entry in entries: if isinstance(entry, Leaf): leaves.add(entry.value) else: to_fetch.add(entry.get_pointer()) leaf_blocks = self.storage.read_many(leaves) for cid, block in leaf_blocks.items(): yield cid, block.encoded
# def cids_for_path(self, key): # """Returns the CIDs in a given key path. ??? # # Args: # key (str): # # Returns: # sequence of :class:`CID` # """ # cids: CID[] = [self.get_pointer()] # index = self.find_gt_or_equal_leaf_index(key) # found = self.at_index(index) # if found and isinstance(found, Leaf) and found.key == key: # return cids + [found.value] # prev = self.at_index(index - 1) # if prev and isinstance(prev, MST): # return cids + prev.cids_for_path(key) # return cids
[docs]def leading_zeros_on_hash(key): """Returns the number of leading zeros in a key's hash. Args: key (str or bytes) Returns: int: """ if not isinstance(key, bytes): key = key.encode() # ensure_valid_key enforces that this is ASCII only leading_zeros = 0 for byte in sha256(key).digest(): if byte < 64: leading_zeros += 1 if byte < 16: leading_zeros += 1 if byte < 4: leading_zeros += 1 if byte == 0: leading_zeros += 1 else: break return leading_zeros
[docs]def layer_for_entries(entries): """ Args: entries (MST or Leaf) Returns: int or None: """ for entry in entries: if isinstance(entry, Leaf): return leading_zeros_on_hash(entry.key)
[docs]def deserialize_node_data(*, storage=None, data=None, layer=None): """ Args: storage (Storage) data (Data) Returns: sequence of MST and Leaf: """ entries = [] if (data.l is not None): entries.append(MST(storage=storage, pointer=data.l, layer=layer - 1 if layer else None)) last_key = '' for entry_data in data.e: entry = Entry(**entry_data) key_str = entry.k.decode() key = last_key[:entry.p] + key_str ensure_valid_key(key) entries.append(Leaf(key, entry.v)) last_key = key if entry.t is not None: entries.append(MST(storage=storage, pointer=entry.t, layer=layer - 1 if layer else None)) return entries
[docs]def serialize_node_data(entries): """ Args: entries (sequence of MST and Leaf) Returns: Data: """ l = None i = 0 if entries and isinstance(entries[0], MST): i += 1 l = entries[0].get_pointer() data = Data(l=l, e=[]) last_key = '' while i < len(entries): leaf = entries[i] next = entries[i + 1] if i < len(entries) - 1 else None if not isinstance(leaf, Leaf): raise ValueError('Not a valid node: two subtrees next to each other') i += 1 subtree = None if next and isinstance(next, MST): subtree = next.get_pointer() i += 1 ensure_valid_key(leaf.key) prefix_len = common_prefix_len(last_key, leaf.key) data.e.append(Entry( p=prefix_len, k=leaf.key[prefix_len:].encode('ascii'), v=leaf.value, t=subtree, )._asdict()) last_key = leaf.key return data
[docs]def common_prefix_len(a, b): """ Args: a (str) b (str) Returns: int: """ return len(commonprefix((a, b)))
[docs]def cid_for_entries(entries): """ Args: entries (sequence of MST and Leaf) Returns: CID """ return dag_cbor_cid(serialize_node_data(entries)._asdict())
[docs]def ensure_valid_key(key): """ Args: key (str) Raises: ValueError: if key is not a valid MST key """ valid = re.compile('[a-zA-Z0-9_\-:.]*$') split = key.split('/') if not (len(key) <= 256 and len(split) == 2 and split[0] and split[1] and valid.match(split[0]) and valid.match(split[1]) ): raise ValueError(f'Invalid MST key: {key}')
WalkStatus = namedtuple('WalkStatus', [ 'done', # bool 'cur', # MST or Leaf 'walking', # MST or None if cur is the root of the tree 'index', # int ], defaults=[None, None, None, None])
[docs]class Walker: """Allows walking an MST manually. Attributes: stack (sequence of WalkStatus) status (WalkStatus): current """ stack = None status = None
[docs] def __init__(self, tree): """Constructor. Args: tree (MST) """ self.stack = [] self.status = WalkStatus( done=False, cur=tree, walking=None, index=0, )
[docs] def layer(self): """Returns the curent layer of the node we're on.""" assert not self.status.done, 'Walk is done' if self.status.walking: return self.status.walking.layer or 0 # if cur is the root of the tree, add 1 if isinstance(self.status.cur, MST): return (self.status.cur.layer or 0) + 1 raise RuntimeError('Could not identify layer of walk')
[docs] def step_over(self): """Moves to the next node in the subtree, skipping over the subtree.""" if self.status.done: return # if stepping over the root of the node, we're done if not self.status.walking: self.status = WalkStatus(done=True) return entries = self.status.walking.get_entries() self.status = self.status._replace(index=self.status.index + 1) if self.status.index >= len(entries): if not self.stack: self.status = WalkStatus(done=True) else: self.status = self.stack.pop() self.step_over() else: self.status = self.status._replace(cur=entries[self.status.index])
[docs] def step_into(self): """Steps into a subtree. Raises: RuntimeError: if curently on a leaf """ if self.status.done: return # edge case for very start of walk if not self.status.walking: assert isinstance(self.status.cur, MST), \ 'The root of the tree cannot be a leaf' next = self.status.cur.at_index(0) if not next: self.status = WalkStatus(done=True) else: self.status = WalkStatus( done=False, walking=self.status.cur, cur=next, index=0, ) return if not isinstance(self.status.cur, MST): raise RuntimeError('No tree at pointer, cannot step into') next = self.status.cur.at_index(0) assert next, 'Tried to step into a node with 0 entries which is invalid' self.stack.append(self.status) self.status = WalkStatus( walking=self.status.cur, cur=next, index=0, done=False, )
[docs] def advance(self): """Advances to the next node in the tree. Steps into the curent node if necessary. """ if self.status.done: return if isinstance(self.status.cur, Leaf): self.step_over() else: self.step_into()