Skip to content

Commit 7025e17

Browse files
authored
Merge pull request #203 from scipp/load_dream_csv
Add dream.load_geant4_csv
2 parents 1ff1045 + d63d207 commit 7025e17

File tree

8 files changed

+287
-0
lines changed

8 files changed

+287
-0
lines changed

environment.yml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies:
1717
- ipykernel==6.25.1
1818
- ipywidgets==8.1.0
1919
- nbsphinx=0.9.2
20+
- pandas=2.0.3
2021
- pandoc=3.1.3
2122
- pip=23.2.1
2223
- plopp=23.09.0

src/ess/dream/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
11
# SPDX-License-Identifier: BSD-3-Clause
22
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
4+
from . import data
5+
from .io import load_geant4_csv
6+
7+
__all__ = ['data', 'load_geant4_csv']

src/ess/dream/data.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
_version = '1'
4+
5+
__all__ = ['get_path']
6+
7+
8+
def _make_pooch():
9+
import pooch
10+
11+
return pooch.create(
12+
path=pooch.os_cache('ess/dream'),
13+
env='ESS_DREAM_DATA_DIR',
14+
base_url='https://public.esss.dk/groups/scipp/ess/dream/{version}/',
15+
version=_version,
16+
registry={
17+
'data_dream_with_sectors.csv.zip': 'md5:52ae6eb3705e5e54306a001bc0ae85d8',
18+
},
19+
)
20+
21+
22+
_pooch = _make_pooch()
23+
24+
25+
def get_path(name: str) -> str:
26+
"""
27+
Return the path to a data file bundled with scippneutron.
28+
29+
This function only works with example data and cannot handle
30+
paths to custom files.
31+
"""
32+
return _pooch.fetch(name)

src/ess/dream/io/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
4+
from .geant4 import load_geant4_csv
5+
6+
__all__ = ['load_geant4_csv']

src/ess/dream/io/geant4.py

+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
4+
import os
5+
from io import BytesIO, StringIO
6+
from typing import Dict, Optional, Union
7+
8+
import numpy as np
9+
import scipp as sc
10+
11+
MANTLE_DETECTOR_ID = sc.index(7)
12+
HIGH_RES_DETECTOR_ID = sc.index(8)
13+
ENDCAPS_DETECTOR_IDS = tuple(map(sc.index, (3, 4, 5, 6)))
14+
15+
16+
def load_geant4_csv(
17+
filename: Union[str, os.PathLike, StringIO, BytesIO]
18+
) -> sc.DataGroup:
19+
"""Load a GEANT4 CSV file for DREAM.
20+
21+
Parameters
22+
----------
23+
filename:
24+
Path to the GEANT4 CSV file.
25+
26+
Returns
27+
-------
28+
:
29+
A :class:`scipp.DataGroup` containing the loaded events.
30+
"""
31+
events = _load_raw_events(filename)
32+
detectors = _split_detectors(events)
33+
for det in detectors.values():
34+
_adjust_coords(det)
35+
detectors = _group(detectors)
36+
37+
return sc.DataGroup({'instrument': sc.DataGroup(detectors)})
38+
39+
40+
def _load_raw_events(
41+
filename: Union[str, os.PathLike, StringIO, BytesIO]
42+
) -> sc.DataArray:
43+
table = sc.io.load_csv(filename, sep='\t', header_parser='bracket', data_columns=[])
44+
table = table.rename_dims(row='event')
45+
return sc.DataArray(
46+
sc.ones(sizes=table.sizes, with_variances=True, unit='counts'),
47+
coords=table.coords,
48+
)
49+
50+
51+
def _adjust_coords(da: sc.DataArray) -> None:
52+
da.coords['wavelength'] = da.coords.pop('lambda')
53+
da.coords['position'] = sc.spatial.as_vectors(
54+
da.coords.pop('x_pos'), da.coords.pop('y_pos'), da.coords.pop('z_pos')
55+
)
56+
57+
58+
def _group(detectors: Dict[str, sc.DataArray]) -> Dict[str, sc.DataArray]:
59+
elements = ('module', 'segment', 'counter', 'wire', 'strip')
60+
61+
def group(key: str, da: sc.DataArray) -> sc.DataArray:
62+
if key == 'high_resolution':
63+
# Only the HR detector has sectors.
64+
return da.group('sector', *elements)
65+
res = da.group(*elements)
66+
res.bins.coords.pop('sector', None)
67+
return res
68+
69+
return {key: group(key, da) for key, da in detectors.items()}
70+
71+
72+
def _split_detectors(
73+
data: sc.DataArray, detector_id_name: str = 'det ID'
74+
) -> Dict[str, sc.DataArray]:
75+
groups = data.group(
76+
sc.concat(
77+
[MANTLE_DETECTOR_ID, HIGH_RES_DETECTOR_ID, *ENDCAPS_DETECTOR_IDS],
78+
dim=detector_id_name,
79+
)
80+
)
81+
detectors = {}
82+
if (
83+
mantle := _extract_detector(groups, detector_id_name, MANTLE_DETECTOR_ID)
84+
) is not None:
85+
detectors['mantle'] = mantle.copy()
86+
if (
87+
high_res := _extract_detector(groups, detector_id_name, HIGH_RES_DETECTOR_ID)
88+
) is not None:
89+
detectors['high_resolution'] = high_res.copy()
90+
91+
endcaps_list = [
92+
det
93+
for i in ENDCAPS_DETECTOR_IDS
94+
if (det := _extract_detector(groups, detector_id_name, i)) is not None
95+
]
96+
if endcaps_list:
97+
endcaps = sc.concat(endcaps_list, data.dim)
98+
endcaps = endcaps.bin(
99+
z_pos=sc.array(
100+
dims=['z_pos'],
101+
values=[-np.inf, 0.0, np.inf],
102+
unit=endcaps.coords['z_pos'].unit,
103+
)
104+
)
105+
detectors['endcap_backward'] = endcaps[0].bins.concat().value.copy()
106+
detectors['endcap_forward'] = endcaps[1].bins.concat().value.copy()
107+
108+
return detectors
109+
110+
111+
def _extract_detector(
112+
detector_groups: sc.DataArray, detector_id_name: str, detector_id: sc.Variable
113+
) -> Optional[sc.DataArray]:
114+
try:
115+
return detector_groups[detector_id_name, detector_id].value
116+
except IndexError:
117+
return None

tests/dream/__init__.py

Whitespace-only changes.

tests/dream/io/__init__.py

Whitespace-only changes.

tests/dream/io/geant4_test.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# SPDX-License-Identifier: BSD-3-Clause
2+
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3+
4+
import zipfile
5+
from io import BytesIO
6+
from typing import Optional, Set
7+
8+
import numpy as np
9+
import pytest
10+
import scipp as sc
11+
import scipp.testing
12+
13+
from ess.dream import data, load_geant4_csv
14+
15+
16+
@pytest.fixture(scope='module')
17+
def file_path():
18+
return data.get_path('data_dream_with_sectors.csv.zip')
19+
20+
21+
# Load file into memory only once
22+
@pytest.fixture(scope='module')
23+
def load_file(file_path):
24+
with zipfile.ZipFile(file_path, 'r') as archive:
25+
return archive.read(archive.namelist()[0])
26+
27+
28+
@pytest.fixture(scope='function')
29+
def file(load_file):
30+
return BytesIO(load_file)
31+
32+
33+
def assert_index_coord(
34+
coord: sc.Variable, *, values: Optional[Set[int]] = None
35+
) -> None:
36+
assert coord.ndim == 1
37+
assert coord.unit is None
38+
assert coord.dtype == 'int64'
39+
if values is not None:
40+
assert set(np.unique(coord.values)) == values
41+
42+
43+
def test_load_geant4_csv_loads_expected_structure(file):
44+
loaded = load_geant4_csv(file)
45+
assert isinstance(loaded, sc.DataGroup)
46+
assert loaded.keys() == {'instrument'}
47+
48+
instrument = loaded['instrument']
49+
assert isinstance(instrument, sc.DataGroup)
50+
assert instrument.keys() == {
51+
'mantle',
52+
'high_resolution',
53+
'endcap_forward',
54+
'endcap_backward',
55+
}
56+
57+
58+
@pytest.mark.parametrize(
59+
'key', ('mantle', 'high_resolution', 'endcap_forward', 'endcap_backward')
60+
)
61+
def test_load_gean4_csv_set_weights_to_one(file, key):
62+
detector = load_geant4_csv(file)['instrument'][key]
63+
events = detector.bins.constituents['data'].data
64+
sc.testing.assert_identical(
65+
events, sc.ones(sizes=events.sizes, with_variances=True, unit='counts')
66+
)
67+
68+
69+
def test_load_geant4_csv_mantle_has_expected_coords(file):
70+
# Only testing ranges that will not change in the future
71+
mantle = load_geant4_csv(file)['instrument']['mantle']
72+
assert_index_coord(mantle.coords['module'])
73+
assert_index_coord(mantle.coords['segment'])
74+
assert_index_coord(mantle.coords['counter'])
75+
assert_index_coord(mantle.coords['wire'], values=set(range(1, 33)))
76+
assert_index_coord(mantle.coords['strip'], values=set(range(1, 257)))
77+
assert 'sector' not in mantle.coords
78+
79+
assert 'sector' not in mantle.bins.coords
80+
assert 'tof' in mantle.bins.coords
81+
assert 'wavelength' in mantle.bins.coords
82+
assert 'position' in mantle.bins.coords
83+
84+
85+
def test_load_geant4_csv_endcap_backward_has_expected_coords(file):
86+
endcap = load_geant4_csv(file)['instrument']['endcap_backward']
87+
assert_index_coord(endcap.coords['module'])
88+
assert_index_coord(endcap.coords['segment'])
89+
assert_index_coord(endcap.coords['counter'])
90+
assert_index_coord(endcap.coords['wire'], values=set(range(1, 17)))
91+
assert_index_coord(endcap.coords['strip'], values=set(range(1, 17)))
92+
assert 'sector' not in endcap.coords
93+
94+
assert 'sector' not in endcap.bins.coords
95+
assert 'tof' in endcap.bins.coords
96+
assert 'wavelength' in endcap.bins.coords
97+
assert 'position' in endcap.bins.coords
98+
99+
100+
def test_load_geant4_csv_endcap_forward_has_expected_coords(file):
101+
endcap = load_geant4_csv(file)['instrument']['endcap_forward']
102+
assert_index_coord(endcap.coords['module'])
103+
assert_index_coord(endcap.coords['segment'])
104+
assert_index_coord(endcap.coords['counter'])
105+
assert_index_coord(endcap.coords['wire'], values=set(range(1, 17)))
106+
assert_index_coord(endcap.coords['strip'], values=set(range(1, 17)))
107+
assert 'sector' not in endcap.coords
108+
109+
assert 'sector' not in endcap.bins.coords
110+
assert 'tof' in endcap.bins.coords
111+
assert 'wavelength' in endcap.bins.coords
112+
assert 'position' in endcap.bins.coords
113+
114+
115+
def test_load_geant4_csv_high_resolution_has_expected_coords(file):
116+
hr = load_geant4_csv(file)['instrument']['high_resolution']
117+
assert_index_coord(hr.coords['module'])
118+
assert_index_coord(hr.coords['segment'])
119+
assert_index_coord(hr.coords['counter'])
120+
assert_index_coord(hr.coords['wire'], values=set(range(1, 17)))
121+
assert_index_coord(hr.coords['strip'], values=set(range(1, 33)))
122+
assert_index_coord(hr.coords['sector'], values=set(range(1, 5)))
123+
124+
assert 'tof' in hr.bins.coords
125+
assert 'wavelength' in hr.bins.coords
126+
assert 'position' in hr.bins.coords

0 commit comments

Comments
 (0)