Skip to content

Commit 7264008

Browse files
committed
Refactors excluded points handling in diffraction minimization
Moves the logic for updating excluded points from minimization to ExcludedRegions, ensuring better modularity.
1 parent 000edc8 commit 7264008

File tree

2 files changed

+31
-23
lines changed

2 files changed

+31
-23
lines changed

src/easydiffraction/analysis/minimization.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -121,21 +121,6 @@ def _residual_function(self,
121121
# Sync parameters back to objects
122122
self.minimizer._sync_result_to_parameters(parameters, engine_params)
123123

124-
# Update the excluded points in experiments
125-
# TODO: This should not be handled here every time
126-
# This implementation is just very quick and dirty
127-
for experiment in experiments._items.values():
128-
experiment.datastore.pattern.excluded = np.full(experiment.datastore.pattern.x.shape,
129-
fill_value=False,
130-
dtype=bool) # Reset excluded points
131-
excluded_regions = experiment.excluded_regions._items # List of excluded regions
132-
if excluded_regions: # If there are any excluded regions
133-
for idx, point in enumerate(experiment.datastore.pattern.x): # Set excluded points
134-
for region in excluded_regions.values():
135-
if region.minimum.value <= point <= region.maximum.value:
136-
experiment.datastore.pattern.excluded[idx] = True
137-
break
138-
139124
# Prepare weights for joint fitting
140125
num_expts: int = len(experiments.ids)
141126
if weights is None:
@@ -161,13 +146,7 @@ def _residual_function(self,
161146
y_meas: np.ndarray = experiment.datastore.pattern.meas
162147
y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su
163148
excluded: np.ndarray = experiment.datastore.pattern.excluded
164-
# TODO: Excluded points must be handled differently.
165-
# They should not contribute to the residuals.
166-
diff = np.where(
167-
excluded,
168-
0.0, # Excluded points contribute zero to the residuals
169-
(y_meas - y_calc) / y_meas_su # Normalized residuals
170-
)
149+
diff = ((y_meas - y_calc) / y_meas_su)[~excluded] # Exclude points that are marked as excluded
171150
diff *= np.sqrt(weight) # Residuals are squared before going into reduced chi-squared
172151
residuals.extend(diff)
173152

src/easydiffraction/experiments/collections/excluded_regions.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from typing import Type
23

34
from easydiffraction.core.objects import (
@@ -44,7 +45,7 @@ def __init__(self,
4445

4546
class ExcludedRegions(Collection):
4647
"""
47-
Collection of LinkedPhase instances.
48+
Collection of ExcludedRegion instances.
4849
"""
4950
@property
5051
def _type(self) -> str:
@@ -53,3 +54,31 @@ def _type(self) -> str:
5354
@property
5455
def _child_class(self) -> Type[ExcludedRegion]:
5556
return ExcludedRegion
57+
58+
def on_item_added(self, item: ExcludedRegion) -> None:
59+
"""
60+
Called when a new item is added to the collection.
61+
"""
62+
# Update the excluded points in experiments
63+
# TODO: This implementation is very quick and dirty
64+
# It should be improved to only update the points that are affected
65+
# by the new excluded region, not all of them
66+
67+
#expt_name = self.datablock_id
68+
#minimum = item.minimum.value
69+
#maximum = item.maximum.value
70+
71+
experiment = self._parent
72+
excluded_regions = experiment.excluded_regions._items # List of excluded regions
73+
74+
if excluded_regions: # If there are any excluded regions
75+
pattern = experiment.datastore.pattern
76+
pattern.excluded = np.full(pattern.x.shape,
77+
fill_value=False,
78+
dtype=bool) # Reset excluded points
79+
80+
for idx, point in enumerate(pattern.x): # Set excluded points
81+
for region in excluded_regions.values():
82+
if region.minimum.value <= point <= region.maximum.value:
83+
experiment.datastore.pattern.excluded[idx] = True
84+
break

0 commit comments

Comments
 (0)