Skip to content
64 changes: 61 additions & 3 deletions docs/notebooks/advanced_sampling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,36 @@
"source": [
"from lightcurvelynx.math_nodes.given_sampler import GivenValueList\n",
"\n",
"brightness_dist = GivenValueList([18.0, 20.0, 22.0])\n",
"brightness_dist = GivenValueList([18.0, 20.0, 22.0, 25.0])\n",
"model = ConstantSEDModel(brightness=brightness_dist, node_label=\"test\")\n",
"state = model.sample_parameters(num_samples=3)\n",
"print(state[\"test\"][\"brightness\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `GivenValueList` is the only stateful parameterized node. For testing purposes, if you query the node multiple times, it will return the next unsampled items from the list."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state = model.sample_parameters(num_samples=1)\n",
"print(state[\"test\"][\"brightness\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because it is stateful, the `GivenValueList` does **not** support parallel execution. A parallel simulation will fail with an error if a `GivenValueList` is used."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -245,7 +269,41 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The `in_order` flag tells the node whether to extract the rows in order (`True`) or randomly with replacement (`False`).\n",
"The `in_order` flag tells the node whether to extract the rows in order (`True`) or randomly with replacement (`False`). Note that the `TableSampler` is **not** stateful. If called multiple times (with `in_order=True`), it will return the first N rows from the table each time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state = table_node.sample_parameters(num_samples=3)\n",
"print(state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If users want to perform multiple sequential simulations using different parts of the table, they will need to use the ``sample_offset`` parameter."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"state = table_node.sample_parameters(num_samples=3, sample_offset=3)\n",
"print(state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, most users should not need to ever set ``sample_offset`` directly. If you are running a parallelized simulation, the software sets and uses this value behind the scenes to ensure that every worker is operating on a different part of the table.\n",
"\n",
"As with other node types, we can use the dot notation to use these values as input for other models. For example, let’s assume that the 'B' column corresponds to Brightness, 'A' corresponds to RA, and 'C' is not used."
]
Expand Down Expand Up @@ -284,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
"version": "3.13.8"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion src/lightcurvelynx/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class GraphState:
Default: 0
"""

def __init__(self, num_samples=1, sample_offset=0):
def __init__(self, num_samples=1, *, sample_offset=0):
if num_samples < 1:
raise ValueError(
f"Invalid number of samples for GraphState ({num_samples}). Must be a positive integer."
Expand Down
81 changes: 43 additions & 38 deletions src/lightcurvelynx/math_nodes/given_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class GivenValueList(FunctionNode):

Note
----
This node does not support parallel sampling. It will return the same
sequence of values for for each shard.
This is a stateful node that keeps track of the next index to return
that cannot be used in parallel sampling.

Attributes
----------
Expand All @@ -89,14 +89,7 @@ def __init__(self, values, **kwargs):
super().__init__(self._non_func, **kwargs)

def __getstate__(self):
"""We override the default pickling behavior to add a warning because we do
not correctly support parallel sampling of in-order values.
"""
warnings.warn(
"GivenValueList does not support distributed computation. Each shard will "
"return the same sequence of values. We recommend using the GivenValueSampler."
)
return self.__dict__.copy()
raise RuntimeError("GivenValueList cannot be pickled. This node does not support parallel sampling.")

def reset(self):
"""Reset the next index to use."""
Expand All @@ -122,19 +115,29 @@ def compute(self, graph_state, rng_info=None, **kwargs):
The result of the computation. This return value is provided so that testing
functions can easily access the results.
"""
sample_ind = self.next_ind
if graph_state.sample_offset is not None:
sample_ind += graph_state.sample_offset

if graph_state.num_samples == 1:
if self.next_ind >= len(self.values):
raise IndexError()
if sample_ind >= len(self.values):
raise IndexError(
f"GivenValueList ran out of entries to sample. Index {sample_ind} out "
f"of bounds for a list with {len(self.values)} entries."
)

results = self.values[self.next_ind]
results = self.values[sample_ind]
self.next_ind += 1
else:
end_ind = self.next_ind + graph_state.num_samples
end_ind = sample_ind + graph_state.num_samples
if end_ind > len(self.values):
raise IndexError()
raise IndexError(
f"GivenValueList ran out of entries to sample. Index {sample_ind} out "
f"of bounds for a list with {len(self.values)} entries."
)

results = self.values[self.next_ind : end_ind]
self.next_ind = end_ind
results = self.values[sample_ind:end_ind]
self.next_ind += graph_state.num_samples

# Save and return the results.
self._save_results(results, graph_state)
Expand Down Expand Up @@ -246,8 +249,9 @@ class TableSampler(FunctionNode):

Note
----
This node does not support "in order" parallel sampling. It will
return the same sequence of values for each shard.
This is NOT a stateful node. When in_order=True the node will always
return the first N rows of the table, where N is the number of samples
requested.

Parameters
----------
Expand All @@ -266,15 +270,13 @@ class TableSampler(FunctionNode):
in_order : bool
Return the given data in order of the rows (True). If False, performs
random sampling with replacement. Default: False
next_ind : int
The next index to sample for in order sampling.
num_values : int
The total number of items from which to draw the data.
"""

def __init__(self, data, in_order=False, **kwargs):
self.next_ind = 0
self.in_order = in_order
self._last_start_index = -1

if isinstance(data, dict):
self.data = Table(data)
Expand Down Expand Up @@ -304,18 +306,6 @@ def __init__(self, data, in_order=False, **kwargs):
"The index of the selected row in the table.",
)

def __getstate__(self):
"""We override the default pickling behavior to add a warning when in_order is true
because we do not correctly support parallel sampling of in-order values.
"""
if self.in_order:
warnings.warn(
"TableSampler with in_order=True does not support distributed computation. "
"Each shard will return the same sequence of values. We recommend setting "
"in_order=False for distributed sampling."
)
return self.__dict__.copy()

def __len__(self):
"""Return the number of items in the table."""
return self._num_values
Expand Down Expand Up @@ -346,13 +336,28 @@ def compute(self, graph_state, rng_info=None, **kwargs):
"""
# Compute the indices to sample.
if self.in_order:
start_ind = 0
if graph_state.sample_offset is not None:
start_ind += graph_state.sample_offset

if start_ind == self._last_start_index:
warnings.warn(
"TableSampler in_order sampling called multiple times with the same sample_offset. "
"This may indicate unintended behavior, because the same parameter values are used "
"multiple times instead of iterating over the table. Consider to set different "
"sample_offset values for different objects or chunks."
)
self._last_start_index = start_ind

# Check that we have enough points left to sample.
end_index = self.next_ind + graph_state.num_samples
end_index = start_ind + graph_state.num_samples
if end_index > len(self.data):
raise IndexError()
raise IndexError(
f"TableSampler ran out of entries to sample. Index {end_index} out "
f"of bounds for a table with {len(self.data)} entries."
)

sample_inds = np.arange(self.next_ind, end_index)
self.next_ind = end_index
sample_inds = np.arange(start_ind, end_index)
else:
sample_inds = self.get_param(graph_state, "selected_table_index")

Expand Down
91 changes: 58 additions & 33 deletions tests/lightcurvelynx/math_nodes/test_given_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ def test_binary_sampler():
with pytest.raises(ValueError):
_ = BinarySampler(1.5)

# We can pickle the BinarySampler.
node_str = pickle.dumps(binary_node)
assert node_str is not None


def test_given_value_list():
"""Test that we can retrieve numbers from a GivenValueList."""
Expand Down Expand Up @@ -87,11 +83,32 @@ def test_given_value_list():
with pytest.raises(ValueError):
_ = GivenValueList([])

# We should not pickle a GivenValueList that is in order.
with pytest.warns(UserWarning):
# GivenValueList cannot be used in distributed computation.
with pytest.raises(RuntimeError):
_ = pickle.dumps(given_node)


def test_given_value_list_offset():
"""Test that we can retrieve numbers from a GivenValueList when using a sample offset."""
given_node = GivenValueList([1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5])

# Check that we generate the correct result and save it in the GraphState.
state1 = GraphState(num_samples=2, sample_offset=3)
results = given_node.compute(state1)
assert np.array_equal(results, [2.5, 3.0])
assert np.array_equal(given_node.get_param(state1, "function_node_result"), [2.5, 3.0])

state2 = GraphState(num_samples=1, sample_offset=3)
results = given_node.compute(state2)
assert results == -1.0
assert given_node.get_param(state2, "function_node_result") == -1.0

# Check that GivenValueList raises an error when it has run out of samples.
state3 = GraphState(num_samples=2, sample_offset=3)
with pytest.raises(IndexError):
_ = given_node.compute(state3)


def test_test_given_value_list_compound():
"""Test that we can use the GivenValueList as input into another node."""
values = [1.0, 1.5, 2.0, 2.5, 3.0, -1.0, 3.5, 4.0, 10.0, -2.0]
Expand Down Expand Up @@ -151,10 +168,6 @@ def test_given_value_sampler():
with pytest.raises(ValueError):
_ = GivenValueSampler([1, 3, 5], weights=[0.5, 0.5])

# We can pickle the GivenValueSampler.
node_str = pickle.dumps(given_node)
assert node_str is not None


def test_given_value_sampler_int():
"""Test that we can retrieve numbers from a GivenValueSampler representing a range."""
Expand Down Expand Up @@ -240,33 +253,28 @@ def test_table_sampler(test_data_type):
assert np.allclose(state["node"]["B"], [1, 1])
assert np.allclose(state["node"]["C"], [3, 4])

state = table_node.sample_parameters(num_samples=1)
# We can sample a single value. Note that the node is not
# stateful, so it always returns the first N rows. We produce
# a warning if we detect that the same offset is being used
# multiple times.
with pytest.warns(UserWarning):
state = table_node.sample_parameters(num_samples=1)
assert len(state) == 3
assert state["node"]["A"] == 3
assert state["node"]["A"] == 1
assert state["node"]["B"] == 1
assert state["node"]["C"] == 5
assert state["node"]["C"] == 3

state = table_node.sample_parameters(num_samples=4)
# We can sample later values using a forced offset. No warning
# should be produced here since the offset is different.
state = table_node.sample_parameters(num_samples=4, sample_offset=2)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [4, 5, 6, 7])
assert np.allclose(state["node"]["A"], [3, 4, 5, 6])
assert np.allclose(state["node"]["B"], [1, 1, 1, 1])
assert np.allclose(state["node"]["C"], [6, 7, 8, 9])
assert np.allclose(state["node"]["C"], [5, 6, 7, 8])

# We go past the end of the data.
with pytest.raises(IndexError):
_ = table_node.sample_parameters(num_samples=4)

# We can reset and sample from the beginning.
table_node.reset()
state = table_node.sample_parameters(num_samples=2)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [1, 2])
assert np.allclose(state["node"]["B"], [1, 1])
assert np.allclose(state["node"]["C"], [3, 4])

# We should pickle a TableSampler that is in order.
with pytest.warns(UserWarning):
_ = pickle.dumps(table_node)
_ = table_node.sample_parameters(num_samples=100, sample_offset=5)


def test_table_sampler_fail():
Expand All @@ -280,6 +288,27 @@ def test_table_sampler_fail():
_ = TableSampler({"a": [], "b": []})


def test_table_sampler_offset():
"""Test that we can retrieve numbers from a TableSampler with an offset."""
data = {
"A": [1, 2, 3, 4, 5, 6, 7, 8],
"B": [1, 1, 1, 1, 1, 1, 1, 1],
"C": [3, 4, 5, 6, 7, 8, 9, 10],
}

# Create the table sampler from the data.
table_node = TableSampler(data, in_order=True, node_label="node")
state = table_node.sample_parameters(num_samples=2, sample_offset=4)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [5, 6])
assert np.allclose(state["node"]["B"], [1, 1])
assert np.allclose(state["node"]["C"], [7, 8])

# We go past the end of the data.
with pytest.raises(IndexError):
_ = table_node.sample_parameters(num_samples=4, sample_offset=14)


def test_table_sampler_randomized():
"""Test that we can retrieve numbers from a TableSampler."""
raw_data_dict = {
Expand Down Expand Up @@ -320,7 +349,3 @@ def test_table_sampler_randomized():

# We always sample consistent ROWS of a and b.
assert np.all(b_vals - a_vals == 1)

# We can pickle a randomized TableSampler.
node_str = pickle.dumps(table_node)
assert node_str is not None
Loading