@@ -2313,10 +2313,14 @@ def get_plot(
23132313 Returns:
23142314 go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend.
23152315 """
2316- fig = None
2317- data = []
2316+ if self ._dim not in {1 , 2 , 3 , 4 }:
2317+ raise ValueError (
2318+ f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got { self ._dim } D "
2319+ )
23182320
23192321 if self .backend == "plotly" :
2322+ data : list = []
2323+
23202324 if self ._dim != 1 :
23212325 data .append (self ._create_plotly_lines ())
23222326
@@ -2334,7 +2338,7 @@ def get_plot(
23342338 if self ._dim != 1 and not (self ._dim == 3 and self .ternary_style == "2d" ):
23352339 data .append (self ._create_plotly_stable_labels (label_stable ))
23362340
2337- if fill and self ._dim in [ 3 , 4 ] :
2341+ if fill and self ._dim in { 3 , 4 } :
23382342 data .extend (self ._create_plotly_fill ())
23392343
23402344 data .extend ([stable_marker_plot , unstable_marker_plot ])
@@ -2346,20 +2350,22 @@ def get_plot(
23462350 fig .layout = self ._create_plotly_figure_layout ()
23472351 fig .update_layout (coloraxis_colorbar = {"yanchor" : "top" , "y" : 0.05 , "x" : 1 })
23482352
2349- elif self .backend == "matplotlib" :
2350- if self ._dim <= 3 :
2351- fig = self ._get_matplotlib_2d_plot (
2353+ return fig
2354+
2355+ if self .backend == "matplotlib" :
2356+ if self ._dim in {1 , 2 , 3 }:
2357+ return self ._get_matplotlib_2d_plot (
23522358 label_stable ,
23532359 label_unstable ,
23542360 ordering ,
23552361 energy_colormap ,
23562362 ax = ax ,
23572363 process_attributes = process_attributes ,
23582364 )
2359- elif self ._dim == 4 :
2360- fig = self ._get_matplotlib_3d_plot (label_stable , ax = ax )
2365+ if self ._dim == 4 :
2366+ return self ._get_matplotlib_3d_plot (label_stable , ax = ax )
23612367
2362- return fig
2368+ return None
23632369
23642370 def show (self , * args , ** kwargs ) -> None :
23652371 """
0 commit comments