Skip to content

Commit 745117b

Browse files
Jammy2211Jammy2211
authored andcommitted
Fix Sample.kwargs mixed string/tuple key bug
Sample.__init__ previously converted string kwargs to tuple paths only when the key contained a dot. A model with both nested params (e.g. 'ellipses.11.centre.centre_0') and top-level dotless params (e.g. 'dummy_0') produced a mixed-type kwargs dict. Sample.is_path_kwargs then inspected only the first key, misclassified the sample, and parameter_lists_for_paths raised KeyError when looking up dotless keys as single-element tuples. The fix is to drop the "." in key guard so every string key is uniformly converted to a tuple — single-name keys become single-element tuples that match the path produced by model.all_paths. The dict() / database JSON round-trip remains symmetric because dict() joins tuples back to dotted strings on serialize. Reported by Sam in the aggregator-to-database flow.
1 parent e211f9e commit 745117b

5 files changed

Lines changed: 34 additions & 11 deletions

File tree

autofit/non_linear/samples/sample.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535
self.log_prior = log_prior
3636
self.weight = weight
3737
self.kwargs = {
38-
tuple(key.split(".")) if isinstance(key, str) and "." in key else key: value
38+
tuple(key.split(".")) if isinstance(key, str) else key: value
3939
for key, value in (kwargs or dict()).items()
4040
}
4141

test_autofit/analysis/test_latent_variables.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_set_database_paths(session, latent_samples):
4747
latent_samples=latent_samples,
4848
)
4949
loaded = database_paths.load_latent_samples()
50-
assert loaded.max_log_likelihood_sample.kwargs == {"fwhm": 7.0644601350928475}
50+
assert loaded.max_log_likelihood_sample.kwargs == {("fwhm",): 7.0644601350928475}
5151

5252

5353
@pytest.fixture(name="latent_samples")
@@ -73,7 +73,7 @@ def make_latent_samples():
7373

7474

7575
def test_compute_latent_samples(latent_samples):
76-
assert latent_samples.sample_list[0].kwargs == {"fwhm": 7.0644601350928475}
76+
assert latent_samples.sample_list[0].kwargs == {("fwhm",): 7.0644601350928475}
7777
assert latent_samples.model.instance_from_vector([1.0]).fwhm == 1.0
7878

7979

@@ -113,7 +113,7 @@ def test_compute_latent_samples_skips_fit_exception_samples():
113113
),
114114
)
115115
assert len(latent_samples.sample_list) == 1
116-
assert latent_samples.sample_list[0].kwargs == {"fwhm": 7.0644601350928475}
116+
assert latent_samples.sample_list[0].kwargs == {("fwhm",): 7.0644601350928475}
117117

118118

119119
def test_info(latent_samples):

test_autofit/database/paths/test_samples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_serialise_sample(sample):
4949
sample = m.Object.from_object(
5050
sample
5151
)()
52-
assert "centre" in sample.kwargs
52+
assert ("centre",) in sample.kwargs
5353

5454

5555
def test_load_samples(

test_autofit/non_linear/samples/test_efficient.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ def test_values(efficient):
4646

4747
sample_1, sample_2 = samples.sample_list
4848

49-
assert sample_1.kwargs["centre"] == 1.0
50-
assert sample_1.kwargs["normalization"] == 2.0
51-
assert sample_1.kwargs["sigma"] == 3.0
49+
assert sample_1.kwargs[("centre",)] == 1.0
50+
assert sample_1.kwargs[("normalization",)] == 2.0
51+
assert sample_1.kwargs[("sigma",)] == 3.0
5252

53-
assert sample_2.kwargs["centre"] == 4.0
54-
assert sample_2.kwargs["normalization"] == 5.0
55-
assert sample_2.kwargs["sigma"] == 6.0
53+
assert sample_2.kwargs[("centre",)] == 4.0
54+
assert sample_2.kwargs[("normalization",)] == 5.0
55+
assert sample_2.kwargs[("sigma",)] == 6.0
5656

5757

5858
def test_database(efficient, session):

test_autofit/non_linear/samples/test_samples.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,26 @@ def test__addition_of_samples__raises_error_if_model_mismatch(samples_x5):
241241

242242
with pytest.raises(af.exc.SamplesException):
243243
samples_x5 + samples_different_model
244+
245+
246+
def test__sample_kwargs__mixed_dotted_and_dotless_string_keys():
247+
sample = af.Sample(
248+
log_likelihood=1.0,
249+
log_prior=0.0,
250+
weight=1.0,
251+
kwargs={
252+
"mock_class_1.one": 10.0,
253+
"dummy_0": 99.0,
254+
},
255+
)
256+
257+
assert all(isinstance(key, tuple) for key in sample.kwargs)
258+
assert sample.is_path_kwargs is True
259+
assert sample.kwargs[("dummy_0",)] == 99.0
260+
assert sample.kwargs[("mock_class_1", "one")] == 10.0
261+
262+
paths = [
263+
[("mock_class_1", "one")],
264+
[("dummy_0",)],
265+
]
266+
assert sample.parameter_lists_for_paths(paths) == [10.0, 99.0]

0 commit comments

Comments
 (0)