Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/lightcurvelynx/graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
of a host galaxy.
"""

import copy

import numpy as np
from astropy.io import ascii
from astropy.table import Table
Expand Down Expand Up @@ -187,15 +189,19 @@ def copy(self):
"""
new_state = GraphState(num_samples=self.num_samples)
new_state.num_parameters = self.num_parameters
new_state.sample_offset = self.sample_offset
for node_name, node_vars in self.states.items():
new_state.states[node_name] = {}
for var_name, var_value in node_vars.items():
if self.num_samples == 1:
new_state.states[node_name][var_name] = var_value
else:
new_state.states[node_name][var_name] = var_value.copy()
for node_name, fixed_vars in self.fixed_vars.items():
new_state.fixed_vars[node_name] = fixed_vars.copy()

# Copy over the sets of fixed variables. This is a single set of strings
# per node, so we just need to copy the sets.
for node_name, var_set in self.fixed_vars.items():
new_state.fixed_vars[node_name] = copy.deepcopy(var_set)
return new_state

@staticmethod
Expand Down Expand Up @@ -502,13 +508,20 @@ def extract_single_sample(self, sample_num):
# Make a copy of the GraphState with exactly one sample.
new_state = GraphState(1)
new_state.num_parameters = self.num_parameters
new_state.sample_offset = self.sample_offset
for node_name in self.states:
new_state.states[node_name] = {}
for var_name, value in self.states[node_name].items():
if self.num_samples == 1:
new_state.states[node_name][var_name] = value
else:
new_state.states[node_name][var_name] = value[sample_num]

# Copy over the sets of fixed variables. This is a single set of strings
# per node, so we just need to copy the sets.
for node_name, var_set in self.fixed_vars.items():
new_state.fixed_vars[node_name] = copy.deepcopy(var_set)

return new_state

def extract_parameters(self, params):
Expand Down
22 changes: 21 additions & 1 deletion tests/lightcurvelynx/test_graph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,16 @@ def test_create_single_sample_graph_state():
new_state = state.extract_single_sample(0)
assert len(new_state) == 3
assert new_state.num_samples == 1
assert new_state.sample_offset == 0
assert state["a"]["v1"] == 1.0
assert state["a"]["v2"] == 2.0
assert state["b"]["v1"] == 3.0

# Both nodes are in the fixed vars (empty sets).
assert len(new_state.fixed_vars) == 2
assert "a" in new_state.fixed_vars
assert "b" in new_state.fixed_vars

# We can overwrite settings.
state.set("a", "v1", 10.0)
assert len(state) == 3
Expand Down Expand Up @@ -284,6 +290,15 @@ def test_graph_state_copy():
state.set("b", "v1", 3.0)

state2 = state.copy()

# Test that we copied over the meta-data
assert state2.num_parameters == 3
assert state2.sample_offset == 0
assert len(state2.fixed_vars) == 2 # empty set for each node
assert "a" in state2.fixed_vars
assert "b" in state2.fixed_vars

# Test with single values.
state2.set("a", "v1", 10.0)
state2.set("a", "v2", 20.0)
state2.set("b", "v1", 30.0)
Expand All @@ -299,12 +314,17 @@ def test_graph_state_copy():
assert state2["b"]["v1"] == 30.0

# Test with arrays.
state = GraphState(3)
state = GraphState(3, sample_offset=2)
state.set("a", "v1", np.array([1.0, 2.0, 3.0]))
state.set("a", "v2", np.array([2.0, 3.0, 4.0]))
state.set("b", "v1", np.array([3.0, 4.0, 5.0]))

# Test that we copied over the meta-data
state2 = state.copy()
assert state2.num_parameters == 3
assert state2.sample_offset == 2
assert len(state2.fixed_vars) == 2 # empty set for each node

state2["a.v1"][1] = 10.0
state2["a.v2"][0] = 20.0
state2["b.v1"][2] = 30.0
Expand Down