Skip to content

Commit 79372d7

Browse files
committed
Only update IntervalTrees when we need them
1 parent 6c11d2e commit 79372d7

File tree

5 files changed

+168
-41
lines changed

5 files changed

+168
-41
lines changed

python/gtirb/byteinterval.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
import itertools
21
import typing
32
from uuid import UUID
43

5-
from intervaltree import IntervalTree
64
from sortedcontainers import SortedDict
75

86
from .block import ByteBlock, CodeBlock, DataBlock
7+
from .lazyintervaltree import LazyIntervalTree
98
from .node import Node, _NodeMessage
109
from .proto import ByteInterval_pb2, SymbolicExpression_pb2
1110
from .symbolicexpression import SymAddrAddr, SymAddrConst, SymbolicExpression
@@ -162,15 +161,19 @@ def __init__(
162161
raise ValueError("initialized_size must be <= size!")
163162

164163
super().__init__(uuid=uuid)
165-
self._interval_tree: "IntervalTree[int, ByteBlock]" = IntervalTree()
166164
self._section: typing.Optional["Section"] = None
167165
self.address = address
168166
self.size = size
169167
self.contents = bytearray(contents)
170168
self.initialized_size = initialized_size
171-
self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(
172-
self, blocks
169+
170+
# Both blocks and _interval_tree must exist before adding any blocks.
171+
self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(self)
172+
self._interval_tree = LazyIntervalTree[int, ByteBlock](
173+
self.blocks, _offset_interval
173174
)
175+
self.blocks.update(blocks)
176+
174177
self._symbolic_expressions = ByteInterval._SymbolicExprDict(
175178
self, symbolic_expressions
176179
)
@@ -186,20 +189,14 @@ def _index_add_multiple(
186189
old_blocks: typing.Collection[ByteBlock],
187190
new_blocks: typing.Collection[ByteBlock],
188191
) -> None:
189-
if len(old_blocks) < len(new_blocks):
190-
self._interval_tree = IntervalTree(
191-
_offset_interval(block)
192-
for block in itertools.chain(old_blocks, new_blocks)
193-
)
194-
else:
195-
for block in new_blocks:
196-
self._index_add(block)
192+
for block in new_blocks:
193+
self._interval_tree.add(block)
197194

198195
def _index_add(self, block: ByteBlock) -> None:
199-
self._interval_tree.add(_offset_interval(block))
196+
self._interval_tree.add(block)
200197

201198
def _index_discard(self, block: ByteBlock) -> None:
202-
self._interval_tree.discard(_offset_interval(block))
199+
self._interval_tree.discard(block)
203200

204201
@property
205202
def initialized_size(self) -> int:
@@ -444,7 +441,7 @@ def byte_blocks_on(
444441
return ()
445442

446443
return _nodes_on_interval_tree(
447-
self._interval_tree, addrs, -self.address
444+
self._interval_tree.get(), addrs, -self.address
448445
)
449446

450447
def byte_blocks_at(
@@ -460,7 +457,7 @@ def byte_blocks_at(
460457
return ()
461458

462459
return _nodes_at_interval_tree(
463-
self._interval_tree, addrs, -self.address
460+
self._interval_tree.get(), addrs, -self.address
464461
)
465462

466463
def code_blocks_on(
@@ -524,7 +521,9 @@ def byte_blocks_on_offset(
524521
:param offsets: Either a ``range`` object or a single offset.
525522
"""
526523

527-
return _nodes_on_interval_tree_offset(self._interval_tree, offsets)
524+
return _nodes_on_interval_tree_offset(
525+
self._interval_tree.get(), offsets
526+
)
528527

529528
def byte_blocks_at_offset(
530529
self, offsets: typing.Union[int, range]
@@ -535,7 +534,9 @@ def byte_blocks_at_offset(
535534
:param offsets: Either a ``range`` object or a single offset.
536535
"""
537536

538-
return _nodes_at_interval_tree_offset(self._interval_tree, offsets)
537+
return _nodes_at_interval_tree_offset(
538+
self._interval_tree.get(), offsets
539+
)
539540

540541
def code_blocks_on_offset(
541542
self, offsets: typing.Union[int, range]

python/gtirb/lazyintervaltree.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Implements a simple wrapper that lazily initializes and updates an
3+
IntervalTree.
4+
5+
GTIRB uses IntervalTrees to accelerate certain operations. However, these
6+
operations are not always needed for a given GTIRB object or by a given GTIRB
7+
analysis. To prevent scripts that do not need the IntervalTrees from wasting
8+
time updating the data structures, the LazyIntervalTree in this module delays
9+
instantiating or updating the tree. Instead, it queues the updates so they can
10+
be rapidly applied when the script invokes an operation that requires an
11+
up-to-date tree.
12+
"""
13+
14+
import enum
15+
from typing import (
16+
Collection,
17+
Generic,
18+
Iterator,
19+
List,
20+
Optional,
21+
Protocol,
22+
Tuple,
23+
TypeVar,
24+
)
25+
26+
from intervaltree import Interval, IntervalTree
27+
28+
_K = TypeVar("_K")
29+
_Kco = TypeVar("_Kco", covariant=True)
30+
_V = TypeVar("_V")
31+
32+
33+
class _EventType(enum.Enum):
34+
"""Whether an interval is to be added or discarded."""
35+
36+
ADDED = enum.auto()
37+
DISCARDED = enum.auto()
38+
39+
40+
class IntervalBuilder(Protocol[_Kco, _V]):
    """Structural type for callables that map a value to its interval."""

    def __call__(self, value: _V, /) -> Optional["Interval[_Kco, _V]"]:
        """Return the interval for ``value``.

        :param value: the value to compute an interval for (positional-only)
        :returns: the interval, or ``None`` when no interval is available
            for that particular value
        """
        ...
48+
49+
50+
class LazyIntervalTree(Generic[_K, _V]):
    """Simple wrapper to lazily initialize and update an IntervalTree.

    The underlying IntervalTree can be retrieved by calling get(). This will
    ensure that the tree is up-to-date with all intermediate modifications
    before returning it.

    In many algorithms, the tree may receive large numbers of modifications,
    adding and removing the same intervals several times before querying. In
    these cases, it may be faster to rebuild the tree from scratch rather than
    perform all of the intermediate modifications. For this reason, get() is
    not guaranteed to always return the same tree object. That is, the tree
    returned by get() should not be cached; calling get() may return a new tree
    rather than updating the tree it returned previously.

    Until get() has been called for the first time no tree exists, so
    modifications are not recorded at all: the first get() builds the tree
    directly from the value collection. This keeps callers that never query
    the tree from accumulating an ever-growing queue of pending events.
    """

    def __init__(
        self,
        values: Collection[_V],
        make_interval: IntervalBuilder[_K, _V],
    ):
        """Create a new lazy tree.

        :param values: collection of values from which the tree can be rebuilt
        :param make_interval: callable to get an interval for a value
        """
        # None until the first call to get() builds the tree.
        self._interval_index: Optional["IntervalTree[_K, _V]"] = None
        # Modifications made since the last get(), in the order they occurred.
        self._interval_events: List[Tuple[_EventType, "Interval[_K, _V]"]] = []
        self._value_collection = values
        self._make_interval = make_interval

    def _queue(self, event: _EventType, value: _V) -> None:
        """Record a pending modification for ``value`` if one is needed."""
        if self._interval_index is None:
            # No tree has been built yet; get() will construct one from the
            # value collection, so a queued event would never be replayed.
            return
        # The interval is computed now, not at replay time, so a discard
        # captures the interval the value had when it was removed.
        interval = self._make_interval(value)
        if interval is not None:
            self._interval_events.append((event, interval))

    def _current_intervals(self) -> Iterator["Interval[_K, _V]"]:
        """Yield the interval of every value that currently has one."""
        for value in self._value_collection:
            interval = self._make_interval(value)
            if interval is not None:
                yield interval

    def add(self, value: _V) -> None:
        """Add a value to the tree.

        Does nothing if no interval is available for the value.
        """
        self._queue(_EventType.ADDED, value)

    def discard(self, value: _V) -> None:
        """Remove a value from the tree.

        Does nothing if the interval with that value is not present.
        """
        self._queue(_EventType.DISCARDED, value)

    def get(self) -> "IntervalTree[_K, _V]":
        """Get the most up-to-date tree reflecting all pending updates.

        :returns: the current tree; it may be a different object from the one
            returned by an earlier call, so do not cache it
        """
        index = self._interval_index
        pending = self._interval_events
        if index is None or len(self._value_collection) <= len(pending):
            # Either there is no tree yet, or replaying the queue would cost
            # at least as many updates as constructing a new tree (one update
            # for each value), so build from scratch.
            index = IntervalTree(self._current_intervals())
            self._interval_index = index
        else:
            # There are fewer updates than constructing a new tree would use.
            for event, interval in pending:
                if event == _EventType.ADDED:
                    index.add(interval)
                else:
                    index.discard(interval)
        pending.clear()
        return index

python/gtirb/section.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
from enum import Enum
44
from uuid import UUID
55

6-
from intervaltree import IntervalTree
7-
86
from .block import ByteBlock, CodeBlock, DataBlock
97
from .byteinterval import ByteInterval, SymbolicExpressionElement
8+
from .lazyintervaltree import LazyIntervalTree
109
from .node import Node, _NodeMessage
1110
from .proto import Section_pb2
1211
from .util import (
@@ -102,24 +101,27 @@ def __init__(
102101
"""
103102

104103
super().__init__(uuid)
105-
self._interval_index: "IntervalTree[int,ByteInterval]" = IntervalTree()
106104
self._module: typing.Optional["Module"] = None
107105
self.name = name
108-
self.byte_intervals = Section._ByteIntervalSet(self, byte_intervals)
106+
107+
# Both byte_intervals and _interval_index must exist before adding any
108+
# intervals.
109+
self.byte_intervals = Section._ByteIntervalSet(self)
110+
self._interval_index = LazyIntervalTree[int, ByteInterval](
111+
self.byte_intervals, _address_interval
112+
)
113+
self.byte_intervals.update(byte_intervals)
114+
109115
self.flags = set(flags)
110116

111117
# Use the property setter to ensure correct invariants.
112118
self.module = module
113119

114120
def _index_add(self, byte_interval: ByteInterval) -> None:
115-
address_interval = _address_interval(byte_interval)
116-
if address_interval:
117-
self._interval_index.add(address_interval)
121+
self._interval_index.add(byte_interval)
118122

119123
def _index_discard(self, byte_interval: ByteInterval) -> None:
120-
address_interval = _address_interval(byte_interval)
121-
if address_interval:
122-
self._interval_index.discard(address_interval)
124+
self._interval_index.discard(byte_interval)
123125

124126
@classmethod
125127
def _decode_protobuf(
@@ -233,8 +235,9 @@ def address(self) -> typing.Optional[int]:
233235
size, so it will be ``None`` in that case.
234236
"""
235237

236-
if 0 < len(self._interval_index) == len(self.byte_intervals):
237-
return self._interval_index.begin()
238+
index = self._interval_index.get()
239+
if 0 < len(index) == len(self.byte_intervals):
240+
return index.begin()
238241

239242
return None
240243

@@ -251,8 +254,9 @@ def size(self) -> typing.Optional[int]:
251254
it has no address or size, so it will be ``None`` in that case.
252255
"""
253256

254-
if 0 < len(self._interval_index) == len(self.byte_intervals):
255-
return self._interval_index.span() - 1
257+
index = self._interval_index.get()
258+
if 0 < len(index) == len(self.byte_intervals):
259+
return index.span() - 1
256260

257261
return None
258262

@@ -265,7 +269,7 @@ def byte_intervals_on(
265269
:param addrs: Either a ``range`` object or a single address.
266270
"""
267271

268-
return _nodes_on_interval_tree(self._interval_index, addrs)
272+
return _nodes_on_interval_tree(self._interval_index.get(), addrs)
269273

270274
def byte_intervals_at(
271275
self, addrs: typing.Union[int, range]
@@ -276,7 +280,7 @@ def byte_intervals_at(
276280
:param addrs: Either a ``range`` object or a single address.
277281
"""
278282

279-
return _nodes_at_interval_tree(self._interval_index, addrs)
283+
return _nodes_at_interval_tree(self._interval_index.get(), addrs)
280284

281285
def byte_blocks_on(
282286
self, addrs: typing.Union[int, range]

python/stubs/intervaltree/interval.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Generic, TypeVar
22

3-
PointT = TypeVar("PointT")
4-
DataT = TypeVar("DataT")
3+
PointT = TypeVar("PointT", covariant=True)
4+
DataT = TypeVar("DataT", covariant=True)
55

66
class Interval(Generic[PointT, DataT]):
77
begin: PointT

python/tests/test_blocks_at_offset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,32 @@ class BlocksAtOffsetTests(unittest.TestCase):
1010
def test_blocks_at_offset_simple(self):
1111
ir, m, s, bi = create_interval_etc(address=None, size=4)
1212

13+
# Ensure we always have a couple blocks in the index beyond what we
14+
# are querying so that we don't just rebuild the tree from scratch
15+
# every time.
1316
code_block = gtirb.CodeBlock(offset=0, size=1, byte_interval=bi)
1417
code_block2 = gtirb.CodeBlock(offset=1, size=1, byte_interval=bi)
18+
code_block3 = gtirb.CodeBlock(offset=2, size=1, byte_interval=bi)
1519

1620
found = set(bi.byte_blocks_at_offset(0))
1721
self.assertEqual(found, {code_block})
1822

1923
# Change the offset to verify we update the index
20-
code_block.offset = 2
24+
code_block.offset = 3
2125
found = set(bi.byte_blocks_at_offset(0))
2226
self.assertEqual(found, set())
2327

24-
found = set(bi.byte_blocks_at_offset(2))
28+
found = set(bi.byte_blocks_at_offset(3))
2529
self.assertEqual(found, {code_block})
2630

2731
# Discard the block to verify we update the index
2832
bi.blocks.discard(code_block)
29-
found = set(bi.byte_blocks_at_offset(2))
33+
found = set(bi.byte_blocks_at_offset(3))
3034
self.assertEqual(found, set())
3135

3236
# Now add it back to verify we update the index
3337
bi.blocks.add(code_block)
34-
found = set(bi.byte_blocks_at_offset(2))
38+
found = set(bi.byte_blocks_at_offset(3))
3539
self.assertEqual(found, {code_block})
3640

3741
def test_blocks_at_offset_overlapping(self):

0 commit comments

Comments
 (0)