Commit ad010e5

Allow stateful samplers to work in a distributed context (#623)
* Add a sample offset
* Allow stateful samplers to work in parallel runs
* Readd parallel run warning
* Make TableSampler stateless
* Update documentation and tests
* Address PR comments
1 parent f065d17 commit ad010e5

File tree

5 files changed, +188 -91 lines changed

docs/notebooks/advanced_sampling.ipynb

Lines changed: 61 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,12 +110,36 @@
110110
"source": [
111111
"from lightcurvelynx.math_nodes.given_sampler import GivenValueList\n",
112112
"\n",
113-
"brightness_dist = GivenValueList([18.0, 20.0, 22.0])\n",
113+
"brightness_dist = GivenValueList([18.0, 20.0, 22.0, 25.0])\n",
114114
"model = ConstantSEDModel(brightness=brightness_dist, node_label=\"test\")\n",
115115
"state = model.sample_parameters(num_samples=3)\n",
116116
"print(state[\"test\"][\"brightness\"])"
117117
]
118118
},
119+
{
120+
"cell_type": "markdown",
121+
"metadata": {},
122+
"source": [
123+
"The `GivenValueList` is the only stateful parameterized node. For testing purposes, if you query the node multiple times, each call returns the next unsampled items from the list."
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"metadata": {},
130+
"outputs": [],
131+
"source": [
132+
"state = model.sample_parameters(num_samples=1)\n",
133+
"print(state[\"test\"][\"brightness\"])"
134+
]
135+
},
136+
{
137+
"cell_type": "markdown",
138+
"metadata": {},
139+
"source": [
140+
"Because it is stateful, the `GivenValueList` does **not** support parallel execution. A parallel simulation will fail with an error if it uses a `GivenValueList`."
141+
]
142+
},
119143
{
120144
"cell_type": "markdown",
121145
"metadata": {},
@@ -245,7 +269,41 @@
245269
"cell_type": "markdown",
246270
"metadata": {},
247271
"source": [
248-
"The `in_order` flag tells the node whether to extract the rows in order (`True`) or randomly with replacement (`False`).\n",
272+
"The `in_order` flag tells the node whether to extract the rows in order (`True`) or randomly with replacement (`False`). Note that the `TableSampler` is **not** stateful. If called multiple times (with `in_order=True`), it will return the first N rows from the table each time."
273+
]
274+
},
275+
{
276+
"cell_type": "code",
277+
"execution_count": null,
278+
"metadata": {},
279+
"outputs": [],
280+
"source": [
281+
"state = table_node.sample_parameters(num_samples=3)\n",
282+
"print(state)"
283+
]
284+
},
285+
{
286+
"cell_type": "markdown",
287+
"metadata": {},
288+
"source": [
289+
"If users want to perform multiple sequential simulations using different parts of the table, then they will need to use the ``sample_offset`` parameter."
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": null,
295+
"metadata": {},
296+
"outputs": [],
297+
"source": [
298+
"state = table_node.sample_parameters(num_samples=3, sample_offset=3)\n",
299+
"print(state)"
300+
]
301+
},
302+
{
303+
"cell_type": "markdown",
304+
"metadata": {},
305+
"source": [
306+
"However, most users should never need to set ``sample_offset`` directly. If you are running a parallelized simulation, the software sets and uses this value behind the scenes so that every worker operates on a different part of the table.\n",
249307
"\n",
250308
"As with other node types, we can use the dot notation to use these values as input for other models. For example, let’s assume that the 'B' column corresponds to Brightness, 'A' corresponds to RA, and 'C' is not used."
251309
]
@@ -284,7 +342,7 @@
284342
"name": "python",
285343
"nbconvert_exporter": "python",
286344
"pygments_lexer": "ipython3",
287-
"version": "3.10.4"
345+
"version": "3.13.8"
288346
}
289347
},
290348
"nbformat": 4,
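
The notebook text above says that a parallel run assigns each worker a different part of the table through `sample_offset`. The commit page does not show how those offsets are computed, so the following is only a minimal sketch of one way it could be done. The helper name `chunk_offsets` and the even-split policy are assumptions made for illustration and are not part of lightcurvelynx.

```python
def chunk_offsets(total_samples, num_workers):
    """Split total_samples into contiguous (offset, count) chunks.

    Hypothetical helper for illustration only. Each chunk's offset plays
    the role of ``sample_offset`` for one worker.
    """
    base, extra = divmod(total_samples, num_workers)
    chunks = []
    offset = 0
    for worker in range(num_workers):
        # The first `extra` workers each take one additional sample.
        count = base + (1 if worker < extra else 0)
        if count > 0:
            chunks.append((offset, count))
        offset += count
    return chunks


# Ten table rows split across three workers: no row is used twice.
table_rows = list(range(10))
for offset, count in chunk_offsets(len(table_rows), 3):
    print(offset, table_rows[offset : offset + count])
```

Because each chunk is described only by an offset and a count, a worker needs no shared state to know which rows it owns.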

src/lightcurvelynx/graph_state.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ class GraphState:
5555
Default: 0
5656
"""
5757

58-
def __init__(self, num_samples=1, sample_offset=0):
58+
def __init__(self, num_samples=1, *, sample_offset=0):
5959
if num_samples < 1:
6060
raise ValueError(
6161
f"Invalid number of samples for GraphState ({num_samples}). Must be a positive integer."
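
The only change in this file is the bare `*` in the signature, which makes `sample_offset` keyword-only. The sketch below shows the effect using a hypothetical stand-in class (`StateSketch` is not the real `GraphState`; it only copies the signature shown in the diff).

```python
class StateSketch:
    """Hypothetical stand-in that mirrors the new __init__ signature."""

    def __init__(self, num_samples=1, *, sample_offset=0):
        if num_samples < 1:
            raise ValueError(f"Invalid number of samples ({num_samples}).")
        self.num_samples = num_samples
        self.sample_offset = sample_offset


# Passing the offset by keyword works.
ok = StateSketch(2, sample_offset=3)

# Passing it positionally is rejected by Python itself.
try:
    StateSketch(2, 3)
    rejected = False
except TypeError:
    rejected = True

print(ok.sample_offset, rejected)
```

Forcing the keyword keeps call sites explicit, so a second positional integer cannot be silently interpreted as an offset.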

src/lightcurvelynx/math_nodes/given_sampler.py

Lines changed: 43 additions & 38 deletions
@@ -69,8 +69,8 @@ class GivenValueList(FunctionNode):
6969
7070
Note
7171
----
72-
This node does not support parallel sampling. It will return the same
73-
sequence of values for for each shard.
72+
This is a stateful node that keeps track of the next index to return.
73+
It cannot be used in parallel sampling.
7474
7575
Attributes
7676
----------
@@ -89,14 +89,7 @@ def __init__(self, values, **kwargs):
8989
super().__init__(self._non_func, **kwargs)
9090

9191
def __getstate__(self):
92-
"""We override the default pickling behavior to add a warning because we do
93-
not correctly support parallel sampling of in-order values.
94-
"""
95-
warnings.warn(
96-
"GivenValueList does not support distributed computation. Each shard will "
97-
"return the same sequence of values. We recommend using the GivenValueSampler."
98-
)
99-
return self.__dict__.copy()
92+
raise RuntimeError("GivenValueList cannot be pickled. This node does not support parallel sampling.")
10093

10194
def reset(self):
10295
"""Reset the next index to use."""
@@ -122,19 +115,29 @@ def compute(self, graph_state, rng_info=None, **kwargs):
122115
The result of the computation. This return value is provided so that testing
123116
functions can easily access the results.
124117
"""
118+
sample_ind = self.next_ind
119+
if graph_state.sample_offset is not None:
120+
sample_ind += graph_state.sample_offset
121+
125122
if graph_state.num_samples == 1:
126-
if self.next_ind >= len(self.values):
127-
raise IndexError()
123+
if sample_ind >= len(self.values):
124+
raise IndexError(
125+
f"GivenValueList ran out of entries to sample. Index {sample_ind} out "
126+
f"of bounds for a list with {len(self.values)} entries."
127+
)
128128

129-
results = self.values[self.next_ind]
129+
results = self.values[sample_ind]
130130
self.next_ind += 1
131131
else:
132-
end_ind = self.next_ind + graph_state.num_samples
132+
end_ind = sample_ind + graph_state.num_samples
133133
if end_ind > len(self.values):
134-
raise IndexError()
134+
raise IndexError(
135+
f"GivenValueList ran out of entries to sample. Index {sample_ind} out "
136+
f"of bounds for a list with {len(self.values)} entries."
137+
)
135138

136-
results = self.values[self.next_ind : end_ind]
137-
self.next_ind = end_ind
139+
results = self.values[sample_ind:end_ind]
140+
self.next_ind += graph_state.num_samples
138141

139142
# Save and return the results.
140143
self._save_results(results, graph_state)
@@ -246,8 +249,9 @@ class TableSampler(FunctionNode):
246249
247250
Note
248251
----
249-
This node does not support "in order" parallel sampling. It will
250-
return the same sequence of values for each shard.
252+
This is NOT a stateful node. When in_order=True the node will always
253+
return the first N rows of the table, where N is the number of samples
254+
requested.
251255
252256
Parameters
253257
----------
@@ -266,15 +270,13 @@ class TableSampler(FunctionNode):
266270
in_order : bool
267271
Return the given data in order of the rows (True). If False, performs
268272
random sampling with replacement. Default: False
269-
next_ind : int
270-
The next index to sample for in order sampling.
271273
num_values : int
272274
The total number of items from which to draw the data.
273275
"""
274276

275277
def __init__(self, data, in_order=False, **kwargs):
276-
self.next_ind = 0
277278
self.in_order = in_order
279+
self._last_start_index = -1
278280

279281
if isinstance(data, dict):
280282
self.data = Table(data)
@@ -304,18 +306,6 @@ def __init__(self, data, in_order=False, **kwargs):
304306
"The index of the selected row in the table.",
305307
)
306308

307-
def __getstate__(self):
308-
"""We override the default pickling behavior to add a warning when in_order is true
309-
because we do not correctly support parallel sampling of in-order values.
310-
"""
311-
if self.in_order:
312-
warnings.warn(
313-
"TableSampler with in_order=True does not support distributed computation. "
314-
"Each shard will return the same sequence of values. We recommend setting "
315-
"in_order=False for distributed sampling."
316-
)
317-
return self.__dict__.copy()
318-
319309
def __len__(self):
320310
"""Return the number of items in the table."""
321311
return self._num_values
@@ -346,13 +336,28 @@ def compute(self, graph_state, rng_info=None, **kwargs):
346336
"""
347337
# Compute the indices to sample.
348338
if self.in_order:
339+
start_ind = 0
340+
if graph_state.sample_offset is not None:
341+
start_ind += graph_state.sample_offset
342+
343+
if start_ind == self._last_start_index:
344+
warnings.warn(
345+
"TableSampler in_order sampling called multiple times with the same sample_offset. "
346+
"This may indicate unintended behavior, because the same parameter values are used "
347+
"multiple times instead of iterating over the table. Consider setting different "
348+
"sample_offset values for different objects or chunks."
349+
)
350+
self._last_start_index = start_ind
351+
349352
# Check that we have enough points left to sample.
350-
end_index = self.next_ind + graph_state.num_samples
353+
end_index = start_ind + graph_state.num_samples
351354
if end_index > len(self.data):
352-
raise IndexError()
355+
raise IndexError(
356+
f"TableSampler ran out of entries to sample. Index {end_index} out "
357+
f"of bounds for a table with {len(self.data)} entries."
358+
)
353359

354-
sample_inds = np.arange(self.next_ind, end_index)
355-
self.next_ind = end_index
360+
sample_inds = np.arange(start_ind, end_index)
356361
else:
357362
sample_inds = self.get_param(graph_state, "selected_table_index")
358363
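
The `TableSampler` change above replaces a stored `next_ind` with an index range computed from the offset on every call. A minimal standalone sketch of that idea follows. The function `in_order_indices` is a hypothetical illustration, not the library's code, and it uses plain Python ranges instead of `np.arange`.

```python
def in_order_indices(num_rows, num_samples, sample_offset=0):
    """Return the row indices for one stateless, in-order draw.

    Hypothetical illustration: the result depends only on the arguments,
    never on earlier calls.
    """
    start = sample_offset
    end = start + num_samples
    if end > num_rows:
        raise IndexError(
            f"Ran out of entries to sample. Index {end} out of bounds "
            f"for a table with {num_rows} entries."
        )
    return list(range(start, end))


first = in_order_indices(8, 3)                   # rows 0, 1, 2
again = in_order_indices(8, 3)                   # the same rows again
later = in_order_indices(8, 3, sample_offset=3)  # rows 3, 4, 5
print(first, again, later)
```

Since repeated calls with the same offset return the same rows, the real node warns when it sees the same starting index twice in a row, which this sketch leaves out.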

tests/lightcurvelynx/math_nodes/test_given_sampler.py

Lines changed: 58 additions & 33 deletions
@@ -43,10 +43,6 @@ def test_binary_sampler():
4343
with pytest.raises(ValueError):
4444
_ = BinarySampler(1.5)
4545

46-
# We can pickle the BinarySampler.
47-
node_str = pickle.dumps(binary_node)
48-
assert node_str is not None
49-
5046

5147
def test_given_value_list():
5248
"""Test that we can retrieve numbers from a GivenValueList."""
@@ -87,11 +83,32 @@ def test_given_value_list():
8783
with pytest.raises(ValueError):
8884
_ = GivenValueList([])
8985

90-
# We should not pickle a GivenValueList that is in order.
91-
with pytest.warns(UserWarning):
86+
# GivenValueList cannot be used in distributed computation.
87+
with pytest.raises(RuntimeError):
9288
_ = pickle.dumps(given_node)
9389

9490

91+
def test_given_value_list_offset():
92+
"""Test that we can retrieve numbers from a GivenValueList when using a sample offset."""
93+
given_node = GivenValueList([1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5])
94+
95+
# Check that we generate the correct result and save it in the GraphState.
96+
state1 = GraphState(num_samples=2, sample_offset=3)
97+
results = given_node.compute(state1)
98+
assert np.array_equal(results, [2.5, 3.0])
99+
assert np.array_equal(given_node.get_param(state1, "function_node_result"), [2.5, 3.0])
100+
101+
state2 = GraphState(num_samples=1, sample_offset=3)
102+
results = given_node.compute(state2)
103+
assert results == -1.0
104+
assert given_node.get_param(state2, "function_node_result") == -1.0
105+
106+
# Check that GivenValueList raises an error when it has run out of samples.
107+
state3 = GraphState(num_samples=2, sample_offset=3)
108+
with pytest.raises(IndexError):
109+
_ = given_node.compute(state3)
110+
111+
95112
def test_test_given_value_list_compound():
96113
"""Test that we can use the GivenValueList as input into another node."""
97114
values = [1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5, 4.0, 10.0, -2.0]
@@ -151,10 +168,6 @@ def test_given_value_sampler():
151168
with pytest.raises(ValueError):
152169
_ = GivenValueSampler([1, 3, 5], weights=[0.5, 0.5])
153170

154-
# We can pickle the GivenValueSampler.
155-
node_str = pickle.dumps(given_node)
156-
assert node_str is not None
157-
158171

159172
def test_given_value_sampler_int():
160173
"""Test that we can retrieve numbers from a GivenValueSampler representing a range."""
@@ -240,33 +253,28 @@ def test_table_sampler(test_data_type):
240253
assert np.allclose(state["node"]["B"], [1, 1])
241254
assert np.allclose(state["node"]["C"], [3, 4])
242255

243-
state = table_node.sample_parameters(num_samples=1)
256+
# We can sample a single value. Note that the node is not
257+
# stateful, so it always returns the first N rows. We produce
258+
# a warning if we detect that the same offset is being used
259+
# multiple times.
260+
with pytest.warns(UserWarning):
261+
state = table_node.sample_parameters(num_samples=1)
244262
assert len(state) == 3
245-
assert state["node"]["A"] == 3
263+
assert state["node"]["A"] == 1
246264
assert state["node"]["B"] == 1
247-
assert state["node"]["C"] == 5
265+
assert state["node"]["C"] == 3
248266

249-
state = table_node.sample_parameters(num_samples=4)
267+
# We can sample later values using a forced offset. No warning
268+
# should be produced here since the offset is different.
269+
state = table_node.sample_parameters(num_samples=4, sample_offset=2)
250270
assert len(state) == 3
251-
assert np.allclose(state["node"]["A"], [4, 5, 6, 7])
271+
assert np.allclose(state["node"]["A"], [3, 4, 5, 6])
252272
assert np.allclose(state["node"]["B"], [1, 1, 1, 1])
253-
assert np.allclose(state["node"]["C"], [6, 7, 8, 9])
273+
assert np.allclose(state["node"]["C"], [5, 6, 7, 8])
254274

255275
# We go past the end of the data.
256276
with pytest.raises(IndexError):
257-
_ = table_node.sample_parameters(num_samples=4)
258-
259-
# We can reset and sample from the beginning.
260-
table_node.reset()
261-
state = table_node.sample_parameters(num_samples=2)
262-
assert len(state) == 3
263-
assert np.allclose(state["node"]["A"], [1, 2])
264-
assert np.allclose(state["node"]["B"], [1, 1])
265-
assert np.allclose(state["node"]["C"], [3, 4])
266-
267-
# We should pickle a TableSampler that is in order.
268-
with pytest.warns(UserWarning):
269-
_ = pickle.dumps(table_node)
277+
_ = table_node.sample_parameters(num_samples=100, sample_offset=5)
270278

271279

272280
def test_table_sampler_fail():
@@ -280,6 +288,27 @@ def test_table_sampler_fail():
280288
_ = TableSampler({"a": [], "b": []})
281289

282290

291+
def test_table_sampler_offset():
292+
"""Test that we can retrieve numbers from a TableSampler with an offset."""
293+
data = {
294+
"A": [1, 2, 3, 4, 5, 6, 7, 8],
295+
"B": [1, 1, 1, 1, 1, 1, 1, 1],
296+
"C": [3, 4, 5, 6, 7, 8, 9, 10],
297+
}
298+
299+
# Create the table sampler from the data.
300+
table_node = TableSampler(data, in_order=True, node_label="node")
301+
state = table_node.sample_parameters(num_samples=2, sample_offset=4)
302+
assert len(state) == 3
303+
assert np.allclose(state["node"]["A"], [5, 6])
304+
assert np.allclose(state["node"]["B"], [1, 1])
305+
assert np.allclose(state["node"]["C"], [7, 8])
306+
307+
# We go past the end of the data.
308+
with pytest.raises(IndexError):
309+
_ = table_node.sample_parameters(num_samples=4, sample_offset=14)
310+
311+
283312
def test_table_sampler_randomized():
284313
"""Test that we can retrieve numbers from a TableSampler."""
285314
raw_data_dict = {
@@ -320,7 +349,3 @@ def test_table_sampler_randomized():
320349

321350
# We always sample consistent ROWS of a and b.
322351
assert np.all(b_vals - a_vals == 1)
323-
324-
# We can pickle a randomized TableSampler.
325-
node_str = pickle.dumps(table_node)
326-
assert node_str is not None
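
The updated `test_given_value_list` expects `pickle.dumps` to raise `RuntimeError`. The sketch below reproduces that mechanism with a hypothetical class (`UnpicklableCounter` is illustrative and not part of lightcurvelynx): an exception raised inside `__getstate__` propagates out of `pickle.dumps`. Many parallel frameworks serialize objects with pickle to send them to workers, so this is one way a stateful node can refuse to be distributed.

```python
import pickle


class UnpicklableCounter:
    """Hypothetical stateful object that refuses to be pickled."""

    def __init__(self, values):
        self.values = list(values)
        self.next_ind = 0

    def __getstate__(self):
        # pickle calls this hook to collect the object's state, so raising
        # here aborts serialization.
        raise RuntimeError("UnpicklableCounter cannot be pickled.")


node = UnpicklableCounter([1.0, 1.5, 2.0])
try:
    pickle.dumps(node)
    pickled = True
except RuntimeError as err:
    pickled = False
    message = str(err)

print(pickled)
```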
