diff --git a/mne/source_estimate.py b/mne/source_estimate.py index deeb3a43ede..02eec4ec15e 100644 --- a/mne/source_estimate.py +++ b/mne/source_estimate.py @@ -1884,7 +1884,7 @@ class SourceEstimate(_BaseSurfaceSourceEstimate): """ @verbose - def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): + def save(self, fname, ftype="auto", overwrite=False, verbose=None): """Save the source estimates to a file. Parameters @@ -1894,18 +1894,29 @@ def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): spaces are obtained by adding ``"-lh.stc"`` and ``"-rh.stc"`` (or ``"-lh.w"`` and ``"-rh.w"``) to the stem provided, for the left and the right hemisphere, respectively. - ftype : str - File format to use. Allowed values are ``"stc"`` (default), - ``"w"``, and ``"h5"``. The ``"w"`` format only supports a single - time point. + ftype : "auto" | "stc" | "w" | "h5" + File format to use. If "auto", the file format will be inferred from the + file extension if possible. Other allowed values are ``"stc"``, ``"w"``, and + ``"h5"``. The ``"w"`` format only supports a single time point. %(overwrite)s .. versionadded:: 1.0 %(verbose)s """ fname = str(_check_fname(fname=fname, overwrite=True)) # checked below + if ftype == "auto": + if fname.endswith((".stc", "-lh.stc", "-rh.stc")): + ftype = "stc" + elif fname.endswith((".w", "-lh.w", "-rh.w")): + ftype = "w" + elif fname.endswith(".h5"): + ftype = "h5" + else: + logger.info( + "Cannot infer file type from `fname`; falling back to `.stc` format" + ) + ftype = "stc" _check_option("ftype", ftype, ["stc", "w", "h5"]) - lh_data = self.data[: len(self.lh_vertno)] rh_data = self.data[-len(self.rh_vertno) :] @@ -1918,6 +1929,8 @@ def save(self, fname, ftype="stc", *, overwrite=False, verbose=None): "real numbers before saving." ) logger.info("Writing STC to disk...") + if fname.endswith(".stc"): + fname = fname[:-4] fname_l = str(_check_fname(fname + "-lh.stc", overwrite=overwrite)) fname_r = str(_check_fname(fname + "-rh.stc", overwrite=overwrite)) _write_stc( diff --git a/mne/tests/test_source_estimate.py b/mne/tests/test_source_estimate.py index e4fa5a36b25..fba47d27cb4 100644 --- a/mne/tests/test_source_estimate.py +++ b/mne/tests/test_source_estimate.py @@ -482,7 +482,6 @@ def test_io_stc(tmp_path): stc = _fake_stc() stc.save(tmp_path / "tmp.stc") stc2 = read_source_estimate(tmp_path / "tmp.stc") - assert_array_almost_equal(stc.data, stc2.data) assert_array_almost_equal(stc.tmin, stc2.tmin) assert_equal(len(stc.vertices), len(stc2.vertices))