From a8f5e08cebf4abab4ff45978f58c69c1c2627f98 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:20:33 -0500 Subject: [PATCH 1/2] Fix a bug in extract_single_sample --- src/lightcurvelynx/graph_state.py | 11 +++++++++-- tests/lightcurvelynx/test_graph_state.py | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/lightcurvelynx/graph_state.py b/src/lightcurvelynx/graph_state.py index d89c4b92..43074496 100644 --- a/src/lightcurvelynx/graph_state.py +++ b/src/lightcurvelynx/graph_state.py @@ -187,6 +187,7 @@ def copy(self): """ new_state = GraphState(num_samples=self.num_samples) new_state.num_parameters = self.num_parameters + new_state.sample_offset = self.sample_offset for node_name, node_vars in self.states.items(): new_state.states[node_name] = {} for var_name, var_value in node_vars.items(): @@ -194,8 +195,8 @@ def copy(self): new_state.states[node_name][var_name] = var_value else: new_state.states[node_name][var_name] = var_value.copy() - for node_name, fixed_vars in self.fixed_vars.items(): - new_state.fixed_vars[node_name] = fixed_vars.copy() + for node_name, var_list in self.fixed_vars.items(): + new_state.fixed_vars[node_name] = var_list.copy() return new_state @staticmethod @@ -502,6 +503,7 @@ def extract_single_sample(self, sample_num): # Make a copy of the GraphState with exactly one sample. new_state = GraphState(1) new_state.num_parameters = self.num_parameters + new_state.sample_offset = self.sample_offset for node_name in self.states: new_state.states[node_name] = {} for var_name, value in self.states[node_name].items(): @@ -509,6 +511,11 @@ def extract_single_sample(self, sample_num): new_state.states[node_name][var_name] = value else: new_state.states[node_name][var_name] = value[sample_num] + + # Copy over the fixed vars information. + for node_name, var_list in self.fixed_vars.items(): + new_state.fixed_vars[node_name] = var_list.copy() + return new_state def extract_parameters(self, params): diff --git a/tests/lightcurvelynx/test_graph_state.py b/tests/lightcurvelynx/test_graph_state.py index 2b36faa5..b547f413 100644 --- a/tests/lightcurvelynx/test_graph_state.py +++ b/tests/lightcurvelynx/test_graph_state.py @@ -78,10 +78,16 @@ def test_create_single_sample_graph_state(): new_state = state.extract_single_sample(0) assert len(new_state) == 3 assert new_state.num_samples == 1 + assert new_state.sample_offset == 0 assert state["a"]["v1"] == 1.0 assert state["a"]["v2"] == 2.0 assert state["b"]["v1"] == 3.0 + # Both nodes are in the fixed vars (empty sets). + assert len(new_state.fixed_vars) == 2 + assert "a" in new_state.fixed_vars + assert "b" in new_state.fixed_vars + # We can overwrite settings. state.set("a", "v1", 10.0) assert len(state) == 3 @@ -284,6 +290,15 @@ def test_graph_state_copy(): state.set("b", "v1", 3.0) state2 = state.copy() + + # Test that we copied over the meta-data + assert state2.num_parameters == 3 + assert state2.sample_offset == 0 + assert len(state2.fixed_vars) == 2 # empty set for each node + assert "a" in state2.fixed_vars + assert "b" in state2.fixed_vars + + # Test with single values. state2.set("a", "v1", 10.0) state2.set("a", "v2", 20.0) state2.set("b", "v1", 30.0) @@ -299,12 +314,17 @@ def test_graph_state_copy(): assert state2["b"]["v1"] == 30.0 # Test with arrays. - state = GraphState(3) + state = GraphState(3, sample_offset=2) state.set("a", "v1", np.array([1.0, 2.0, 3.0])) state.set("a", "v2", np.array([2.0, 3.0, 4.0])) state.set("b", "v1", np.array([3.0, 4.0, 5.0])) + # Test that we copied over the meta-data state2 = state.copy() + assert state2.num_parameters == 3 + assert state2.sample_offset == 2 + assert len(state2.fixed_vars) == 2 # empty set for each node + state2["a.v1"][1] = 10.0 state2["a.v2"][0] = 20.0 state2["b.v1"][2] = 30.0 From 8650f8d286e83f36303d1779821544730177bba4 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:34:58 -0500 Subject: [PATCH 2/2] Address PR comments --- src/lightcurvelynx/graph_state.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/lightcurvelynx/graph_state.py b/src/lightcurvelynx/graph_state.py index 43074496..7685bf76 100644 --- a/src/lightcurvelynx/graph_state.py +++ b/src/lightcurvelynx/graph_state.py @@ -24,6 +24,8 @@ of a host galaxy. """ +import copy + import numpy as np from astropy.io import ascii from astropy.table import Table @@ -195,8 +197,11 @@ def copy(self): new_state.states[node_name][var_name] = var_value else: new_state.states[node_name][var_name] = var_value.copy() - for node_name, var_list in self.fixed_vars.items(): - new_state.fixed_vars[node_name] = var_list.copy() + + # Copy over the sets of fixed variables. This is a single set of strings + # per node, so we just need to copy the sets. + for node_name, var_set in self.fixed_vars.items(): + new_state.fixed_vars[node_name] = copy.deepcopy(var_set) return new_state @staticmethod @@ -512,9 +517,10 @@ def extract_single_sample(self, sample_num): else: new_state.states[node_name][var_name] = value[sample_num] - # Copy over the fixed vars information. - for node_name, var_list in self.fixed_vars.items(): - new_state.fixed_vars[node_name] = var_list.copy() + # Copy over the sets of fixed variables. This is a single set of strings + # per node, so we just need to copy the sets. + for node_name, var_set in self.fixed_vars.items(): + new_state.fixed_vars[node_name] = copy.deepcopy(var_set) return new_state