Skip to content

Commit 8199147

Browse files
committed
Very quick and dirty implementation of excluded regions
1 parent 1ffc2ad commit 8199147

File tree

11 files changed

+2121
-4
lines changed

11 files changed

+2121
-4
lines changed

src/easydiffraction/analysis/minimization.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,21 @@ def _residual_function(self,
121121
# Sync parameters back to objects
122122
self.minimizer._sync_result_to_parameters(parameters, engine_params)
123123

124+
# Update the excluded points in experiments
125+
# TODO: This should not be handled here every time
126+
# This implementation is just very quick and dirty
127+
for experiment in experiments._items.values():
128+
experiment.datastore.pattern.excluded = np.full(experiment.datastore.pattern.x.shape,
129+
fill_value=False,
130+
dtype=bool) # Reset excluded points
131+
excluded_regions = experiment.excluded_regions._items # List of excluded regions
132+
if excluded_regions: # If there are any excluded regions
133+
for idx, point in enumerate(experiment.datastore.pattern.x): # Set excluded points
134+
for region in excluded_regions.values():
135+
if region.minimum.value <= point <= region.maximum.value:
136+
experiment.datastore.pattern.excluded[idx] = True
137+
break
138+
124139
# Prepare weights for joint fitting
125140
num_expts: int = len(experiments.ids)
126141
if weights is None:
@@ -145,7 +160,14 @@ def _residual_function(self,
145160
called_by_minimizer=True) # True False
146161
y_meas: np.ndarray = experiment.datastore.pattern.meas
147162
y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su
148-
diff: np.ndarray = (y_meas - y_calc) / y_meas_su
163+
excluded: np.ndarray = experiment.datastore.pattern.excluded
164+
# TODO: Excluded points must be handled differently.
165+
# They should not contribute to the residuals.
166+
diff = np.where(
167+
excluded,
168+
0.0, # Excluded points contribute zero to the residuals
169+
(y_meas - y_calc) / y_meas_su # Normalized residuals
170+
)
149171
diff *= np.sqrt(weight) # Residuals are squared before going into reduced chi-squared
150172
residuals.extend(diff)
151173

src/easydiffraction/experiments/collections/datastore.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, experiment: Experiment) -> None:
1616
self.meas: Optional[np.ndarray] = None
1717
self.meas_su: Optional[np.ndarray] = None
1818
self.bkg: Optional[np.ndarray] = None
19+
self.excluded: Optional[np.ndarray] = None # Excluded points
1920
self._calc: Optional[np.ndarray] = None # Cached calculated intensities
2021

2122
@property
@@ -33,6 +34,7 @@ class PowderPattern(Pattern):
3334
"""
3435
Specialized pattern for powder diffraction (can be extended in the future).
3536
"""
37+
# TODO: Check if this class is needed or if it can be merged with Pattern
3638
def __init__(self, experiment: Experiment) -> None:
3739
super().__init__(experiment)
3840
# Additional powder-specific initialization if needed
@@ -49,7 +51,7 @@ def __init__(self, sample_form: str, experiment: Experiment) -> None:
4951
if sample_form == "powder":
5052
self.pattern: Pattern = PowderPattern(experiment)
5153
elif sample_form == "single_crystal":
52-
self.pattern: Pattern = Pattern(experiment)
54+
self.pattern: Pattern = Pattern(experiment) # TODO: Find better name for single crystal pattern
5355
else:
5456
raise ValueError(f"Unknown sample form '{sample_form}'")
5557

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from typing import Type
2+
3+
from easydiffraction.core.objects import (
4+
Parameter,
5+
Descriptor,
6+
Component,
7+
Collection
8+
)
9+
10+
11+
class ExcludedRegion(Component):
12+
@property
13+
def category_key(self) -> str:
14+
return "excluded_region"
15+
16+
@property
17+
def cif_category_key(self) -> str:
18+
return "excluded_region"
19+
20+
def __init__(self,
21+
minimum: float,
22+
maximum: float):
23+
super().__init__()
24+
25+
self.minimum = Descriptor(
26+
value=minimum,
27+
name="minimum",
28+
cif_name="minimum"
29+
)
30+
self.maximum = Parameter(
31+
value=maximum,
32+
name="maximum",
33+
cif_name="maximum"
34+
)
35+
36+
# Select which of the input parameters is used for the
37+
# as ID for the whole object
38+
self._entry_id = f'{minimum}-{maximum}'
39+
40+
# Lock further attribute additions to prevent
41+
# accidental modifications by users
42+
self._locked = True
43+
44+
45+
class ExcludedRegions(Collection):
46+
"""
47+
Collection of LinkedPhase instances.
48+
"""
49+
@property
50+
def _type(self) -> str:
51+
return "category" # datablock or category
52+
53+
@property
54+
def _child_class(self) -> Type[ExcludedRegion]:
55+
return ExcludedRegion

src/easydiffraction/experiments/experiment.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from easydiffraction.experiments.collections.linked_phases import LinkedPhases
1717
from easydiffraction.experiments.collections.background import BackgroundFactory
18+
from easydiffraction.experiments.collections.excluded_regions import ExcludedRegions
1819
from easydiffraction.experiments.collections.datastore import DatastoreFactory
1920

2021
from easydiffraction.utils.formatting import paragraph, warning
@@ -129,6 +130,10 @@ def as_cif(self, max_points: Optional[int] = None) -> str:
129130
if hasattr(self, "background") and self.background:
130131
cif_lines += ["", self.background.as_cif()]
131132

133+
# Excluded regions
134+
if hasattr(self, "excluded_regions") and self.excluded_regions:
135+
cif_lines += ["", self.excluded_regions.as_cif()]
136+
132137
# Measured data
133138
if hasattr(self, "datastore") and hasattr(self.datastore, "pattern"):
134139
cif_lines.append("")
@@ -191,7 +196,8 @@ def __init__(self,
191196
beam_mode=self.type.beam_mode.value,
192197
profile_type=self._peak_profile_type)
193198

194-
self.linked_phases = LinkedPhases()
199+
self.linked_phases: LinkedPhases = LinkedPhases()
200+
self.excluded_regions: ExcludedRegions = ExcludedRegions()
195201

196202
@abstractmethod
197203
def _load_ascii_data_to_experiment(self, data_path: str) -> None:

src/easydiffraction/plotting/plotting.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ def plot_meas(self,
140140
x_min=x_min,
141141
x_max=x_max)
142142

143+
# Exclude points based on the pattern's excluded mask
144+
excluded = pattern.excluded
145+
x = x[~excluded]
146+
y_meas = y_meas[~excluded]
147+
143148
y_series = [y_meas]
144149
y_labels = ['meas']
145150

@@ -176,6 +181,11 @@ def plot_calc(self,
176181
x_min=x_min,
177182
x_max=x_max)
178183

184+
# Exclude points based on the pattern's excluded mask
185+
excluded = pattern.excluded
186+
x = x[~excluded]
187+
y_calc = y_calc[~excluded]
188+
179189
y_series = [y_calc]
180190
y_labels = ['calc']
181191

@@ -220,6 +230,12 @@ def plot_meas_vs_calc(self,
220230
x_min=x_min,
221231
x_max=x_max)
222232

233+
# Exclude points based on the pattern's excluded mask
234+
excluded = pattern.excluded
235+
x = x[~excluded]
236+
y_meas = y_meas[~excluded]
237+
y_calc = y_calc[~excluded]
238+
223239
y_series = [y_meas, y_calc]
224240
y_labels = ['meas', 'calc']
225241

tests/functional_tests/fitting/test_powder-diffraction_multiphase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_single_fit_neutron_pd_tof_mcstas_lbco_si() -> None:
8484

8585
# Compare fit quality
8686
assert_almost_equal(project.analysis.fit_results.reduced_chi_square,
87-
desired=2.87,
87+
desired=1.79, # 2.87
8888
decimal=1)
8989

9090

tutorials/cryst-struct_pd-neut-tof_multphase-LBCO-Si_McStas.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,12 @@
211211
# %%
212212
project.analysis.current_minimizer = 'lmfit (leastsq)'
213213

214+
# %% [markdown]
215+
# ### Set Excluded Regions
216+
217+
# %%
218+
project.experiments['mcstas'].excluded_regions.add(minimum=108000, maximum=200000)
219+
214220
# %% [markdown]
215221
# ### Set Fitting Parameters
216222
#

0 commit comments

Comments
 (0)