"""AT Protocol utility for diffing two MSTs.
Heavily based on:
https://github.com/bluesky/atproto/blob/main/packages/repo/src/mst/diff.ts
Huge thanks to the Bluesky team for working in the public, in open source, and to
Daniel Holmgren and Devin Ivy for this code specifically!
"""
from collections import namedtuple
import logging
from .mst import Leaf, MST, Walker
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def mst_diff(cur, prev=None):
    """Computes the changes that turn one MST into another.

    Two :class:`Walker` instances step through ``prev`` and ``cur`` side by
    side. Leaves found only in ``cur`` are recorded as adds, leaves found only
    in ``prev`` as deletes, and leaves with the same key but different values
    as updates. Subtree pointers seen on only one side are recorded as new or
    removed CIDs; subtrees whose pointers are equal are stepped over.

    Args:
      cur (MST): the tree to diff *to*
      prev (MST): optional, the tree to diff *from*. If falsy (eg ``None``),
        every entry in ``cur`` is reported as new via :func:`null_diff`.

    Returns:
      Diff:

    Raises:
      RuntimeError: if the two current nodes match none of the handled
        leaf/tree combinations
    """
    # presumably makes sure the root pointer is computed up front -- TODO confirm
    cur.get_pointer()
    if not prev:
        return null_diff(cur)
    prev.get_pointer()

    diff = Diff()

    def note_added(node):
        # node exists only on the cur side
        if isinstance(node, Leaf):
            diff.record_add(node.key, node.value)
        else:
            diff.record_new_cid(node.pointer)

    def note_removed(node):
        # node exists only on the prev side
        if isinstance(node, Leaf):
            diff.record_delete(node.key, node.value)
        else:
            diff.record_removed_cid(node.pointer)

    old_walker = Walker(prev)
    new_walker = Walker(cur)

    while not (old_walker.status.done and new_walker.status.done):
        old_done = old_walker.status.done
        new_done = new_walker.status.done

        # one side is exhausted: drain the other, logging everything it yields
        if old_done and not new_done:
            note_added(new_walker.status.cur)
            new_walker.advance()
            continue
        if new_done and not old_done:
            note_removed(old_walker.status.cur)
            old_walker.advance()
            continue
        if old_done or new_done:
            break

        old_node = old_walker.status.cur
        new_node = new_walker.status.cur
        if not (old_node and new_node):
            break

        old_is_leaf = isinstance(old_node, Leaf)
        new_is_leaf = isinstance(new_node, Leaf)

        # two leaves: same key means a possible update and both sides move
        # on; otherwise log the lower key and move only that side
        if old_is_leaf and new_is_leaf:
            if old_node.key == new_node.key:
                if old_node.value != new_node.value:
                    diff.record_update(old_node.key, old_node.value,
                                       new_node.value)
                old_walker.advance()
                new_walker.advance()
            elif old_node.key < new_node.key:
                diff.record_delete(old_node.key, old_node.value)
                old_walker.advance()
            else:
                diff.record_add(new_node.key, new_node.value)
                new_walker.advance()
            continue

        # bring both walkers onto the same layer. the higher walker descends
        # if it points at a tree; if it points at a leaf, the lower walker
        # advances instead so that it can catch up.
        old_layer = old_walker.layer()
        new_layer = new_walker.layer()
        if old_layer > new_layer:
            if old_is_leaf:
                note_added(new_node)
                new_walker.advance()
            else:
                diff.record_removed_cid(old_node.pointer)
                old_walker.step_into()
            continue
        if old_layer < new_layer:
            if new_is_leaf:
                note_removed(old_node)
                old_walker.advance()
            else:
                diff.record_new_cid(new_node.pointer)
                new_walker.step_into()
            continue

        old_is_tree = isinstance(old_node, MST)
        new_is_tree = isinstance(new_node, MST)

        # same layer, two subtrees: equal pointers are stepped over,
        # differing ones are entered to find the nested differences
        if old_is_tree and new_is_tree:
            if old_node.pointer == new_node.pointer:
                old_walker.step_over()
                new_walker.step_over()
            else:
                diff.record_new_cid(new_node.pointer)
                diff.record_removed_cid(old_node.pointer)
                old_walker.step_into()
                new_walker.step_into()
            continue

        # same layer, one leaf and one subtree: descend into the subtree
        if old_is_leaf and new_is_tree:
            diff.record_new_cid(new_node.pointer)
            new_walker.step_into()
            continue
        if old_is_tree and new_is_leaf:
            diff.record_removed_cid(old_node.pointer)
            old_walker.step_into()
            continue

        raise RuntimeError('Unidentifiable case in diff walk')

    return diff
def null_diff(tree):
    """Builds the diff of a single MST against nothing.

    Every leaf yielded by ``tree.walk()`` is recorded as an add, and every
    other entry has its ``pointer`` recorded as a new CID.

    Args:
      tree (MST)

    Returns:
      Diff:
    """
    result = Diff()
    for node in tree.walk():
        if not isinstance(node, Leaf):
            result.record_new_cid(node.pointer)
            continue
        result.record_add(node.key, node.value)
    return result
# A single record-level change. ``defaults=[None]`` applies to the rightmost
# field, so ``prev`` is optional and is only set for updates.
Change = namedtuple('Change', [
    'key',   # str
    'cid',   # CID
    'prev',  # CID
], defaults=[None])


class Diff:
    """A diff between two MSTs.

    Attributes:
      adds (dict): maps str to :class:`Change`
      updates (dict): maps str to :class:`Change`
      deletes (dict): maps str to :class:`Change`
      new_cids (set of :class:`CID`)
      removed_cids (set of :class:`CID`)
    """

    def __init__(self):
        """Creates an empty diff."""
        self.adds = {}
        self.updates = {}
        self.deletes = {}
        self.new_cids = set()
        self.removed_cids = set()

    @staticmethod
    def of(cur, prev=None):
        """Generates a diff between two MSTs; delegates to :func:`mst_diff`.

        Args:
          cur (MST)
          prev (MST): optional

        Returns:
          Diff:
        """
        return mst_diff(cur, prev)

    def record_add(self, key, cid):
        """Records a newly added key and marks its CID as new.

        Args:
          key (str)
          cid (CID)
        """
        self.adds[key] = Change(key=key, cid=cid)
        self.new_cids.add(cid)

    def record_update(self, key, prev, cid):
        """Records a key whose value changed and marks the new CID as new.

        Args:
          key (str)
          prev (CID): the old value
          cid (CID): the new value
        """
        self.updates[key] = Change(key=key, cid=cid, prev=prev)
        self.new_cids.add(cid)

    def record_delete(self, key, cid):
        """Records a removed key.

        Args:
          key (str)
          cid (CID): the value the key had before it was removed
        """
        self.deletes[key] = Change(key=key, cid=cid)

    def record_new_cid(self, cid):
        """Marks a CID as new, cancelling an earlier removal of the same CID.

        Args:
          cid (CID)
        """
        if cid in self.removed_cids:
            self.removed_cids.remove(cid)
        else:
            self.new_cids.add(cid)

    def record_removed_cid(self, cid):
        """Marks a CID as removed, cancelling an earlier add of the same CID.

        Args:
          cid (CID)
        """
        if cid in self.new_cids:
            self.new_cids.remove(cid)
        else:
            self.removed_cids.add(cid)

    def add_diff(self, diff):
        """Merges a later diff into this one, in place.

        Changes that cancel out are collapsed: an add in ``diff`` for a key
        this diff deleted becomes an update (or nothing at all, if the CID is
        unchanged), and a delete in ``diff`` for a key this diff added removes
        the add. Keys that this diff has not seen yet are simply recorded.

        Only ``new_cids`` is merged from ``diff``; its ``removed_cids`` are
        not carried over.

        Args:
          diff (Diff)
        """
        for add in diff.adds.values():
            # membership test, not self.deletes[key]: indexing a dict with a
            # missing key raises KeyError instead of returning a falsy value
            if add.key in self.deletes:
                deleted = self.deletes.pop(add.key)
                if deleted.cid != add.cid:
                    self.record_update(add.key, deleted.cid, add.cid)
            else:
                self.record_add(add.key, add.cid)

        for update in diff.updates.values():
            self.record_update(update.key, update.prev, update.cid)
            # the key may or may not have an earlier add/delete to supersede
            self.adds.pop(update.key, None)
            self.deletes.pop(update.key, None)

        for deleted in diff.deletes.values():
            if deleted.key in self.adds:
                # added and then deleted: net effect is no change
                del self.adds[deleted.key]
            else:
                self.updates.pop(deleted.key, None)
                self.record_delete(deleted.key, deleted.cid)

        self.new_cids |= diff.new_cids

    def updated_keys(self):
        """Returns every key touched by this diff.

        Returns:
          dict: maps str key to :class:`Change`, merged from ``adds``,
          ``updates`` and ``deletes``. If a key is in more than one of them,
          the later mapping in that order wins. Iterating the result yields
          the keys.
        """
        return self.adds | self.updates | self.deletes