Skip to content

Commit 1f1e3a4

Browse files
committed
Fix indexing issue for appending to NetCDF files
1 parent 3865bb0 commit 1f1e3a4

File tree

1 file changed

+33
-3
lines changed

1 file changed

+33
-3
lines changed

src/mdcraft/openmm/file.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def __init__(
5555
self._nc = file
5656
self._nc.set_always_mask(False)
5757

58-
if mode in {"a", "r+", "w"}:
58+
if mode.startswith(("a", "r+", "w", "x")):
5959
self._restart = restart
6060
else:
6161
self._restart = self._nc.Conventions == "AMBERRESTART"
62-
self._frame = self.get_num_frames()
62+
self._frame = self.get_num_frames() if hasattr(self._nc, "Conventions") else 0
6363

6464
def get_dimensions(
6565
self, frames: Union[int, list[int], slice] = None, units: bool = True
@@ -116,6 +116,11 @@ def get_num_frames(self) -> int:
116116
Number of frames.
117117
"""
118118

119+
if not hasattr(self._nc, "Conventions"):
120+
raise RuntimeError(
121+
"The NetCDF file is not a valid AMBER NetCDF "
122+
"trajectory or does not contain any data."
123+
)
119124
return self._nc.dimensions["frame"].size
120125

121126
def get_num_atoms(self) -> int:
@@ -128,7 +133,12 @@ def get_num_atoms(self) -> int:
128133
Number of atoms.
129134
"""
130135

131-
return self._nc.dimensions["atom"].size
136+
if not hasattr(self._nc, "Conventions"):
137+
raise RuntimeError(
138+
"The NetCDF file is not a valid AMBER NetCDF "
139+
"trajectory or does not contain any data."
140+
)
141+
return self._nc.dimensions["atom"].size if hasattr(self._nc, "Conventions") else 0
132142

133143
def get_times(
134144
self, frames: Union[int, list[int], slice] = None, units: bool = True
@@ -153,6 +163,11 @@ def get_times(
153163
**Reference unit**: :math:`\\mathrm{ps}`.
154164
"""
155165

166+
if not hasattr(self._nc, "Conventions"):
167+
raise RuntimeError(
168+
"The NetCDF file is not a valid AMBER NetCDF "
169+
"trajectory or does not contain any data."
170+
)
156171
times = (
157172
self._nc.variables["time"][:]
158173
if frames is None
@@ -185,6 +200,11 @@ def get_positions(
185200
**Reference unit**: :math:`\\mathrm{Å}`.
186201
"""
187202

203+
if not hasattr(self._nc, "Conventions"):
204+
raise RuntimeError(
205+
"The NetCDF file is not a valid AMBER NetCDF "
206+
"trajectory or does not contain any data."
207+
)
188208
positions = (
189209
self._nc.variables["coordinates"][:]
190210
if frames is None
@@ -218,6 +238,11 @@ def get_velocities(
218238
**Reference unit**: :math:`\\mathrm{Å/ps}`.
219239
"""
220240

241+
if not hasattr(self._nc, "Conventions"):
242+
raise RuntimeError(
243+
"The NetCDF file is not a valid AMBER NetCDF "
244+
"trajectory or does not contain any data."
245+
)
221246
if "velocities" not in self._nc.variables:
222247
wmsg = (
223248
"The NetCDF file does not contain information about "
@@ -259,6 +284,11 @@ def get_forces(
259284
**Reference unit**: :math:`\\mathrm{Å/ps}`.
260285
"""
261286

287+
if not hasattr(self._nc, "Conventions"):
288+
raise RuntimeError(
289+
"The NetCDF file is not a valid AMBER NetCDF "
290+
"trajectory or does not contain any data."
291+
)
262292
if "forces" not in self._nc.variables:
263293
wmsg = (
264294
"The NetCDF file does not contain information about "

0 commit comments

Comments
 (0)