Skip to content

Commit de71e2a

Browse files
committed
more explicit handling of dimension
1 parent db8a991 commit de71e2a

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/pymatgen/analysis/phase_diagram.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2313,10 +2313,14 @@ def get_plot(
23132313
Returns:
23142314
go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend.
23152315
"""
2316-
fig = None
2317-
data = []
2316+
if self._dim not in {1, 2, 3, 4}:
2317+
raise ValueError(
2318+
f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got {self._dim}D "
2319+
)
23182320

23192321
if self.backend == "plotly":
2322+
data: list = []
2323+
23202324
if self._dim != 1:
23212325
data.append(self._create_plotly_lines())
23222326

@@ -2334,7 +2338,7 @@ def get_plot(
23342338
if self._dim != 1 and not (self._dim == 3 and self.ternary_style == "2d"):
23352339
data.append(self._create_plotly_stable_labels(label_stable))
23362340

2337-
if fill and self._dim in [3, 4]:
2341+
if fill and self._dim in {3, 4}:
23382342
data.extend(self._create_plotly_fill())
23392343

23402344
data.extend([stable_marker_plot, unstable_marker_plot])
@@ -2346,20 +2350,22 @@ def get_plot(
23462350
fig.layout = self._create_plotly_figure_layout()
23472351
fig.update_layout(coloraxis_colorbar={"yanchor": "top", "y": 0.05, "x": 1})
23482352

2349-
elif self.backend == "matplotlib":
2350-
if self._dim <= 3:
2351-
fig = self._get_matplotlib_2d_plot(
2353+
return fig
2354+
2355+
if self.backend == "matplotlib":
2356+
if self._dim in {1, 2, 3}:
2357+
return self._get_matplotlib_2d_plot(
23522358
label_stable,
23532359
label_unstable,
23542360
ordering,
23552361
energy_colormap,
23562362
ax=ax,
23572363
process_attributes=process_attributes,
23582364
)
2359-
elif self._dim == 4:
2360-
fig = self._get_matplotlib_3d_plot(label_stable, ax=ax)
2365+
if self._dim == 4:
2366+
return self._get_matplotlib_3d_plot(label_stable, ax=ax)
23612367

2362-
return fig
2368+
return None
23632369

23642370
def show(self, *args, **kwargs) -> None:
23652371
"""

0 commit comments

Comments
 (0)