Skip to content

Commit d0a6dfe

Browse files
committed
feat(zettaset): enable zettaset-wise resolution
1 parent b6178bd commit d0a6dfe

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

deepem/data/dataset/multi_zettaset.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ def _initialize_zettasets(
6262
if "path" not in spec or not spec["path"].startswith("gs://"):
6363
raise ValueError(f"Invalid zettaset specification for '{name}': missing or invalid 'path'.")
6464
zettaset_path = spec["path"]
65-
print(f"Zettaset {name} [{zettaset_path}]")
66-
zettasets[name] = Zettaset(zettaset_path, "", zettaset_resolution)
65+
print(f"Zettaset `{name}` from [{zettaset_path}]")
66+
resolution = tuple(spec.get("resolution", zettaset_resolution))
67+
print(f"{resolution=}")
68+
zettasets[name] = Zettaset(zettaset_path, "", resolution)
6769
return zettasets
6870

6971

@@ -94,6 +96,7 @@ def _process_sample(
9496
zettaset_padding: tuple[int, int, int] = (0, 0, 0),
9597
zettaset_padding_spec: dict[str, tuple[int, int, int]] = {},
9698
zettaset_mask: bool = True,
99+
zettaset_resolution: tuple[int, int, int] | None = None,
97100
**kwargs,
98101
) -> dict[str, dict[str, np.ndarray]]:
99102
if not is_valid_format(data_id):
@@ -116,11 +119,15 @@ def _process_sample(
116119
padding = zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding))
117120
no_mask = zettaset_spec.get("no_mask", not zettaset_mask)
118121

122+
# Determine resolution: zettaset-specific overrides zettaset_resolution
123+
resolution = tuple(zettaset_spec.get("resolution", zettaset_resolution))
124+
119125
print(f"Sample [{data_id}]")
120126
return {data_id: load_sample(
121127
zettaset.samples[sample_name],
122128
padding,
123129
no_mask,
130+
resolution,
124131
**kwargs,
125132
)}
126133

@@ -129,8 +136,8 @@ def load_sample(
129136
sample: Sample,
130137
padding: tuple[int, int, int] = (0, 0, 0),
131138
no_mask: bool = False,
139+
resolution: tuple[int, int, int] | None = None,
132140
zettaset_lookup: dict[str, str] | None = None,
133-
zettaset_resolution: tuple[int, int, int] | None = None,
134141
requires_binarize: list[str] = [],
135142
zettaset_share_mask: str | None = None,
136143
**kwargs
@@ -143,7 +150,7 @@ def convert_array(arr: ArrayLike) -> np.ndarray:
143150
dset: dict[str, np.ndarray] = {}
144151

145152
# Bbox with padding
146-
resolution = zettaset_resolution or sample.base_resolution
153+
resolution = resolution or sample.base_resolution
147154
bbox = sample.bbox * (sample.base_resolution / np.array(resolution))
148155
xyz_padding = (
149156
tuple(reversed(padding))

0 commit comments

Comments
 (0)