Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions iavl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,21 @@
from hexbytes import HexBytes

from . import dbm, diff
from .iavl import NodeDB, Tree
from .utils import (decode_fast_node, diff_iterators, encode_stdint,
fast_node_key, get_node, get_root_node,
iavl_latest_version, iter_fast_nodes, iter_iavl_tree,
load_commit_infos, root_key, store_prefix)
from .iavl import NodeDB, Tree, delete_version
from .utils import (
decode_fast_node,
diff_iterators,
encode_stdint,
fast_node_key,
get_node,
get_root_node,
iavl_latest_version,
iter_fast_nodes,
iter_iavl_tree,
load_commit_infos,
root_key,
store_prefix,
)
from .visualize import visualize_iavl, visualize_pruned_nodes


Expand Down Expand Up @@ -417,14 +427,14 @@ def test_state_round_trip(db, store, start_version):
tree = Tree(ndb, pversion)
diff.apply_change_set(tree, changeset)
tmp = tree.save_version(dry_run=True)
if (root.hash or hashlib.sha256().digest()) == tmp:
if (root or hashlib.sha256().digest()) == tmp:
print(v, len(changeset), "ok")
else:
print(
v,
len(changeset),
"fail",
binascii.hexlify(root.hash).decode(),
binascii.hexlify(root).decode(),
binascii.hexlify(tmp).decode(),
)

Expand All @@ -433,7 +443,7 @@ def iter_state_changes(
db: dbm.DBM, ndb: NodeDB, start_version=0, end_version=None, prefix=b""
):
pversion = ndb.prev_version(start_version) or 0
prev_root = ndb.get_root_node(pversion)
prev_root = ndb.get_root_hash(pversion)
it = db.iteritems()
it.seek(prefix + root_key(start_version))
for k, hash in it:
Expand All @@ -443,11 +453,10 @@ def iter_state_changes(
if end_version is not None and v >= end_version:
break

root = ndb.get(hash)
yield pversion, v, root, diff.state_changes(ndb.get, prev_root, root)
yield pversion, v, hash, diff.state_changes(ndb.get, pversion, prev_root, hash)

pversion = v
prev_root = root
prev_root = hash


@cli.command()
Expand All @@ -469,8 +478,8 @@ def visualize_pruning(db, store, version):
ndb = NodeDB(db, prefix=prefix)
predecessor = ndb.prev_version(version) or 0
successor = ndb.next_version(version)
root1 = ndb.get_root_node(version)
root2 = ndb.get_root_node(successor)
root1 = ndb.get_root_hash(version)
root2 = ndb.get_root_hash(successor)

touched_nodes = set()

Expand All @@ -479,14 +488,14 @@ def trace_get(hash):
return ndb.get(hash)

deleted = set()
for orphaned, _ in diff.diff_tree(
for n in delete_version(
trace_get,
version,
predecessor,
root1,
root2,
diff.DiffOptions.for_pruning(predecessor),
):
for n in orphaned:
deleted.add(n.hash)
deleted.add(n.hash)

print(
"delete version:",
Expand All @@ -504,7 +513,6 @@ def trace_get(hash):
len(touched_nodes),
file=sys.stderr,
)
touched_nodes.update([root1.hash, root2.hash])
g = visualize_pruned_nodes(successor, touched_nodes, deleted, ndb)
print(g.source)

Expand Down
228 changes: 36 additions & 192 deletions iavl/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
tree diff algorithm between two versions
"""
import binascii
import itertools
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Callable, List, NamedTuple, Optional, Tuple
from typing import List, Tuple

from cprotobuf import Field, ProtoEntity, decode_primitive, encode_primitive

from .iavl import PersistedNode, Tree

GetNode = Callable[bytes, Optional[PersistedNode]]
from .utils import GetNode, visit_iavl_nodes


class Op(IntEnum):
Expand All @@ -22,182 +19,6 @@ class Op(IntEnum):
ChangeSet = List[Change]


@dataclass
class Layer:
"""
Represent one layer of nodes at the same height

pending_nodes: because one of the children's height could be height-2, need to keep
it in the pending list temporarily.
"""

height: int = 0
nodes: List[PersistedNode] = field(default_factory=list)
pending_nodes: List[PersistedNode] = field(default_factory=list)

@classmethod
def root(cls, root):
return cls(
height=root.height,
nodes=[root],
)

@classmethod
def empty(cls, height):
return cls(height=height)

def next_layer(self, get_node: GetNode, predecessor):
"""
travel to next layer
"""
assert self.height > 0
nodes = []
pending_nodes = []
for node in self.nodes:
left = get_node(node.left_node_ref)
if left.version > predecessor:
if left.height == self.height - 1:
nodes.append(left)
else:
pending_nodes.append(left)

right = get_node(node.right_node_ref)
if right.version > predecessor:
if right.height == self.height - 1:
nodes.append(right)
else:
pending_nodes.append(right)

self.height -= 1

# merge sorted lists
self.nodes = nodes
self.nodes += self.pending_nodes
self.nodes.sort(key=lambda n: n.key)
self.pending_nodes = pending_nodes

def is_empty(self):
return not self.nodes and not self.pending_nodes


def diff_sorted(nodes1, nodes2):
"""
Contract: input list is sorted by node.key
return: (common, orphaned, new)
"""
i1 = i2 = 0
common = []
orphaned = []
new = []
while True:
if i1 > len(nodes1) - 1:
new += nodes2[i2:]
break
if i2 > len(nodes2) - 1:
orphaned += nodes1[i1:]
break
k1 = nodes1[i1].key
k2 = nodes2[i2].key
if nodes1[i1].hash == nodes2[i2].hash:
common.append(nodes1[i1])
i1 += 1
i2 += 1
elif k1 == k2:
# overriden by same key
orphaned.append(nodes1[i1])
new.append(nodes2[i2])
i1 += 1
i2 += 1
elif k1 < k2:
# proceed to next node in nodes1 until catch up with nodes2
orphaned.append(nodes1[i1])
i1 += 1
else:
# proceed to next node in nodes2 until catch up with nodes1
new.append(nodes2[i2])
i2 += 1
return common, orphaned, new


class DiffOptions(NamedTuple):
# predecessor will skip the subtrees at or before the predecessor from both trees.
predecessor: int
# in prune mode, the diff process stop as soon as orphaned nodes becomes empty.
prune_mode: bool

@classmethod
def full(cls):
"do a full diff, can be used for extracting state changes"
return cls(predecessor=0, prune_mode=False)

@classmethod
def for_pruning(cls, predecessor: int):
"do an optimized diff for pruning versions"
return cls(predecessor=predecessor, prune_mode=True)


def diff_tree(
get_node: GetNode, root1: PersistedNode, root2: PersistedNode, opts: DiffOptions
):
"""
diff two versions of the iavl tree.
yields (orphaned, new)

predecessor can help to skip more subtrees when finding orphaned nodes, we don't
need to traverse the subtrees that's created at or before predecessor in that case.
"""

# skipping nodes created at or before predecessor
if root1 is not None and root1.version <= opts.predecessor:
root1 = None
if root2 is not None and root2.version <= opts.predecessor:
root2 = None

# nothing to do if both tree are empty
if root1 is None and root2 is None:
return

# if one is empty, create an empty layer with the same height as the other tree.
if root1 is None:
l1 = Layer.empty(root2.height)
l2 = Layer.root(root2)
elif root2 is None:
l1 = Layer.root(root1)
l2 = Layer.empty(root1.height)
else:
l1 = Layer.root(root1)
l2 = Layer.root(root2)

while l1.height > l2.height:
yield l1.nodes, []
l1.next_layer(get_node, opts.predecessor)

while l2.height > l1.height:
yield [], l2.nodes
l2.next_layer(get_node, opts.predecessor)

while True:
# l1 l2 at the same height now
_, orphaned, new = diff_sorted(l1.nodes, l2.nodes)

yield orphaned, new

if l1.height == 0:
break

# don't visit the common sub-trees
l1.nodes = orphaned
l2.nodes = new

if opts.prune_mode and l1.is_empty():
# nothing else to see in tree1, no more orphaned nodes, only new ones,
# that's enough for pruning mode.
break

l1.next_layer(get_node, opts.predecessor)
l2.next_layer(get_node, opts.predecessor)


def split_operations(nodes1, nodes2) -> ChangeSet:
"""
Contract: input nodes are all leaf nodes, sorted by node.key
Expand Down Expand Up @@ -237,25 +58,48 @@ def split_operations(nodes1, nodes2) -> ChangeSet:
return result


def state_changes(get_node: GetNode, root1: PersistedNode, root2: PersistedNode):
def state_changes(get_node: GetNode, version, root, successor_root):
"""
extract state changes from the tree diff result
extract state changes from two versions of the iavl tree.

first traverse the successor version to find the shared sub-root nodes
and new leaf nodes, then traverse the target version to find the orphaned leaf
nodes, then extract kv pair operations from it.

return: [(key, op, arg)]
arg: original value if op==Delete
new value if op==Insert
(original value, new value) if op==Update
"""
for orphaned, new in diff_tree(get_node, root1, root2, DiffOptions.full()):
# the nodes are on the same height, and we only care about leaf nodes here
try:
node = next(itertools.chain(orphaned, new))
except StopIteration:
continue

if node.height == 0:
return split_operations(orphaned, new)
return []
shared = set()
new = []
if successor_root:

def successor_prune(n: PersistedNode) -> (bool, bool):
b = n.version <= version
return b, b

for n in visit_iavl_nodes(get_node, successor_prune, successor_root):
if n.version <= version:
shared.add(n.hash)
elif n.is_leaf():
new.append(n)

def prune(n: PersistedNode) -> (bool, bool):
b = n.hash in shared
return b, b

if root:
orphaned = [
n
for n in visit_iavl_nodes(get_node, prune, root)
if n.is_leaf() and n.hash not in shared
]
else:
orphaned = []

return split_operations(orphaned, new)


def apply_change_set(tree: Tree, changeset: ChangeSet):
Expand Down
Loading