From b63ca3e2f8faebb2a7d2311d1382e8e89c04d096 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 4 Oct 2025 13:16:39 +0200 Subject: [PATCH 1/5] add types for Simplex cast np.bool to bool --- src/pymatgen/util/coord.py | 45 +++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/pymatgen/util/coord.py b/src/pymatgen/util/coord.py index 2186eec84e3..038e3f490e6 100644 --- a/src/pymatgen/util/coord.py +++ b/src/pymatgen/util/coord.py @@ -367,14 +367,14 @@ class Simplex(MSONable): simplex_dim (int): Dimension of the simplex coordinate space. """ - def __init__(self, coords) -> None: + def __init__(self, coords: Sequence[Sequence[float]]) -> None: """Initialize a Simplex from vertex coordinates. Args: coords ([[float]]): Coords of the vertices of the simplex. e.g. [[1, 2, 3], [2, 4, 5], [6, 7, 8], [8, 9, 10]. """ - self._coords = np.array(coords) + self._coords = np.asarray(coords) self.space_dim, self.simplex_dim = self._coords.shape self.origin = self._coords[-1] if self.space_dim == self.simplex_dim + 1: @@ -382,12 +382,25 @@ def __init__(self, coords) -> None: self._aug = np.concatenate([coords, np.ones((self.space_dim, 1))], axis=-1) self._aug_inv = np.linalg.inv(self._aug) + def __eq__(self, other: object) -> bool: + if not isinstance(other, Simplex): + return NotImplemented + return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords)) + + def __hash__(self) -> int: + return len(self._coords) + + def __repr__(self) -> str: + output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"] + output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords] + return "\n".join(output) + @property def volume(self) -> float: """Volume of the simplex.""" return abs(np.linalg.det(self._aug)) / math.factorial(self.simplex_dim) - def bary_coords(self, point): + def bary_coords(self, point: ArrayLike) -> np.ndarray: """ Args: point (ArrayLike): Point coordinates. @@ -400,7 +413,7 @@ def bary_coords(self, point): except AttributeError as exc: raise ValueError("Simplex is not full-dimensional") from exc - def point_from_bary_coords(self, bary_coords: ArrayLike): + def point_from_bary_coords(self, bary_coords: ArrayLike) -> np.ndarray: """ Args: bary_coords (ArrayLike): Barycentric coordinates (d+1, d). @@ -428,9 +441,14 @@ def in_simplex(self, point: Sequence[float], tolerance: float = 1e-8) -> bool: point (list[float]): Point to test tolerance (float): Tolerance to test if point is in simplex. """ - return (self.bary_coords(point) >= -tolerance).all() - - def line_intersection(self, point1: Sequence[float], point2: Sequence[float], tolerance: float = 1e-8): + return bool((self.bary_coords(point) >= -tolerance).all()) + + def line_intersection( + self, + point1: Sequence[float], + point2: Sequence[float], + tolerance: float = 1e-8, + ) -> list[np.ndarray]: """Compute the intersection points of a line with a simplex. Args: @@ -465,19 +483,6 @@ def line_intersection(self, point1: Sequence[float], point2: Sequence[float], to raise ValueError("More than 2 intersections found") return [self.point_from_bary_coords(b) for b in barys] - def __eq__(self, other: object) -> bool: - if not isinstance(other, Simplex): - return NotImplemented - return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords)) - - def __hash__(self) -> int: - return len(self._coords) - - def __repr__(self) -> str: - output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"] - output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords] - return "\n".join(output) - @property def coords(self) -> np.ndarray: """A copy of the vertex coordinates in the simplex.""" From a5de0e9eebfc4081a78090fc0052784afa57dc30 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 4 Oct 2025 13:27:55 +0200 Subject: [PATCH 2/5] use asarray on what could already be array --- src/pymatgen/analysis/phase_diagram.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index aa7ca7bb275..86c8f5d1c95 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -508,7 +508,7 @@ def all_entries_hulldata(self): [e.composition.get_atomic_fraction(el) for el in self.elements] + [e.energy_per_atom] for e in self.all_entries ] - return np.array(data)[:, 1:] + return np.asarray(data)[:, 1:] @property def unstable_entries(self) -> set[Entry]: @@ -660,7 +660,7 @@ def _get_simplex_intersections(self, c1, c2): for sc in self.simplexes: intersections.extend(sc.line_intersection(c1, c2)) - return np.array(intersections) + return np.asarray(intersections) def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]: """ @@ -1993,7 +1993,7 @@ def fmt(fl): try: mat = [[entry.composition.get_atomic_fraction(el) for el in elements] for entry in face_entries] mat.append(comp_vec2 - comp_vec1) - matrix = np.array(mat).T + matrix = np.asarray(mat).T coeffs = np.linalg.solve(matrix, comp_vec2) x = coeffs[-1] @@ -2556,7 +2556,7 @@ def get_contour_pd_plot(self): """ pd = self._pd entries = pd.qhull_entries - data = np.array(pd.qhull_data) + data = np.asarray(pd.qhull_data) ax = self._get_matplotlib_2d_plot() data[:, 0:2] = triangular_coord(data[:, 0:2]).transpose() @@ -2599,9 +2599,9 @@ def pd_plot_data(self): """ pd = self._pd entries = pd.qhull_entries - data = np.array(pd.qhull_data) - lines = [] - stable_entries = {} + data = np.asarray(pd.qhull_data) + lines: list = [] + stable_entries: dict = {} for line in self.lines: entry1 = entries[line[0]] @@ -2623,7 +2623,7 @@ def pd_plot_data(self): stable_entries[label_coord[1]] = entry2 all_entries = pd.all_entries - all_data = np.array(pd.all_entries_hulldata) + all_data = np.asarray(pd.all_entries_hulldata) unstable_entries = {} stable = pd.stable_entries @@ -2804,7 +2804,7 @@ def _create_plotly_fill(self): ) ] elif self._dim == 3 and self.ternary_style == "3d": - facets = np.array(self._pd.facets) + facets = np.asarray(self._pd.facets) coords = np.array( [ triangular_coord(c) @@ -2847,7 +2847,7 @@ def _create_plotly_fill(self): ) ) elif self._dim == 4: - all_data = np.array(pd.qhull_data) + all_data = np.asarray(pd.qhull_data) fillcolors = itertools.cycle(plotly_layouts["default_fill_colors"]) for _idx, facet in enumerate(pd.facets): xs, ys, zs = [], [], [] @@ -3712,7 +3712,7 @@ def _get_matplotlib_2d_plot( # The follow defines an offset for the annotation text emanating # from the center of the PD. Results in fairly nice layouts for the # most part. - vec = np.array(coords) - center + vec = np.asarray(coords) - center vec = vec / np.linalg.norm(vec) * 10 if np.linalg.norm(vec) != 0 else vec valign = "bottom" if vec[1] > 0 else "top" if vec[0] < -0.01: @@ -3758,7 +3758,7 @@ def _get_matplotlib_2d_plot( for entry, coords in unstable.items(): ehull = self._pd.get_e_above_hull(entry) if ehull is not None and ehull < self.show_unstable: - vec = np.array(coords) - center + vec = np.asarray(coords) - center vec = vec / np.linalg.norm(vec) * 10 if np.linalg.norm(vec) != 0 else vec label = entry.name if energy_colormap is None: @@ -3880,7 +3880,7 @@ def triangular_coord(coord): """ unit_vec = np.array([[1, 0], [0.5, math.sqrt(3) / 2]]) - result = np.dot(np.array(coord), unit_vec) + result = np.dot(np.asarray(coord), unit_vec) return result.transpose() @@ -3902,7 +3902,7 @@ def tet_coord(coord): [0.5, 1 / 3 * math.sqrt(3) / 2, math.sqrt(6) / 3], ] ) - result = np.dot(np.array(coord), unitvec) + result = np.dot(np.asarray(coord), unitvec) return result.transpose() From db8a991c1096666aa2e931d37f4019c5c0c67a6a Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 4 Oct 2025 12:52:54 +0200 Subject: [PATCH 3/5] add type and minor code cleanup for PD plotter add types for PDEntry clean up types for computed_data in PD __init__ add more types for PDPlotter add more types relocate dunder methods to merge --- src/pymatgen/analysis/phase_diagram.py | 420 +++++++++++++++---------- 1 file changed, 247 insertions(+), 173 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 86c8f5d1c95..0afe6306cc8 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -38,11 +38,14 @@ if TYPE_CHECKING: from collections.abc import Collection, Iterator, Sequence from io import StringIO - from typing import Any, Literal + from typing import Any, ClassVar, Literal - from numpy.typing import ArrayLike + from matplotlib.colors import Colormap + from numpy.typing import ArrayLike, NDArray from typing_extensions import Self + from pymatgen.entries.computed_entries import ComputedEntry + logger = logging.getLogger(__name__) with open( @@ -73,7 +76,7 @@ def __init__( energy: float, name: str | None = None, attribute: object = None, - ): + ) -> None: """ Args: composition (Composition): Composition @@ -86,7 +89,7 @@ def __init__( self.name = name or self.reduced_formula self.attribute = attribute - def __repr__(self): + def __repr__(self) -> str: name = "" if self.name != self.reduced_formula: name = f" ({self.name})" @@ -97,12 +100,12 @@ def energy(self) -> float: """The entry's energy.""" return self._energy - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """Get MSONable dict representation of PDEntry.""" return super().as_dict() | {"name": self.name, "attribute": self.attribute} @classmethod - def from_dict(cls, dct: dict) -> Self: + def from_dict(cls, dct: dict[str, Any]) -> Self: """ Args: dct (dict): dictionary representation of PDEntry. @@ -125,7 +128,7 @@ class GrandPotPDEntry(PDEntry): dict. """ - def __init__(self, entry, chempots, name=None): + def __init__(self, entry: PDEntry, chempots: dict[Element, float], name: str | None = None) -> None: """ Args: entry: A PDEntry-like object. @@ -146,6 +149,16 @@ def __init__(self, entry, chempots, name=None): self.original_comp = self._composition self.chempots = chempots + def __repr__(self): + output = [ + ( + f"GrandPotPDEntry with original composition {self.original_entry.composition}, " + f"energy = {self.original_entry.energy:.4f}, " + ), + "chempots = " + ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items()), + ] + return "".join(output) + @property def composition(self) -> Composition: """The composition after removing free species. @@ -156,7 +169,7 @@ def composition(self) -> Composition: return Composition({el: self._composition[el] for el in self._composition.elements if el not in self.chempots}) @property - def chemical_energy(self): + def chemical_energy(self) -> float: """The chemical energy term mu*N in the grand potential. Returns: @@ -169,17 +182,7 @@ def energy(self) -> float: """Grand potential energy.""" return self._energy - self.chemical_energy - def __repr__(self): - output = [ - ( - f"GrandPotPDEntry with original composition {self.original_entry.composition}, " - f"energy = {self.original_entry.energy:.4f}, " - ), - "chempots = " + ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items()), - ] - return "".join(output) - - def as_dict(self): + def as_dict(self) -> dict[str, Any]: """Get MSONable dict representation of GrandPotPDEntry.""" return { "@module": type(self).__module__, @@ -212,7 +215,7 @@ class TransformedPDEntry(PDEntry): """ # Tolerance for determining if amount of a composition is positive. - amount_tol = 1e-5 + amount_tol: ClassVar[float] = 1e-5 def __init__(self, entry, sp_mapping, name=None): """ @@ -236,6 +239,14 @@ def __init__(self, entry, sp_mapping, name=None): if not all(self.rxn.get_coeff(comp) <= TransformedPDEntry.amount_tol for comp in self.sp_mapping): raise TransformedPDEntryError("Only reactions with positive amounts of reactants allowed") + def __repr__(self): + output = [ + f"TransformedPDEntry {self.composition}", + f" with original composition {self.original_entry.composition}", + f", energy = {self.original_entry.energy:.4f}", + ] + return "".join(output) + @property def composition(self) -> Composition: """The composition in the dummy species space. @@ -254,14 +265,6 @@ def composition(self) -> Composition: return Composition(trans_comp) - def __repr__(self): - output = [ - f"TransformedPDEntry {self.composition}", - f" with original composition {self.original_entry.composition}", - f", energy = {self.original_entry.energy:.4f}", - ] - return "".join(output) - def as_dict(self): """Get MSONable dict representation of TransformedPDEntry.""" return { @@ -336,8 +339,8 @@ class PhaseDiagram(MSONable): """ # Tolerance for determining if formation energy is positive. - formation_energy_tol = 1e-11 - numerical_tol = 1e-8 + formation_energy_tol: ClassVar[float] = 1e-11 + numerical_tol: ClassVar[float] = 1e-8 def __init__( self, @@ -366,6 +369,7 @@ def __init__( self.elements = elements self.entries = entries + if computed_data is None: computed_data = self._compute() else: @@ -375,19 +379,32 @@ def __init__( # Update keys to be Element objects in case they are strings in pre-computed data computed_data["el_refs"] = [(Element(el_str), entry) for el_str, entry in computed_data["el_refs"]] + self.computed_data = computed_data - self.facets = computed_data["facets"] - self.simplexes = computed_data["simplexes"] - self.all_entries = computed_data["all_entries"] - self.qhull_data = computed_data["qhull_data"] - self.dim = computed_data["dim"] - self.el_refs = dict(computed_data["el_refs"]) - self.qhull_entries = tuple(computed_data["qhull_entries"]) - self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries) - self._stable_entries = tuple({self.qhull_entries[idx] for idx in set(itertools.chain(*self.facets))}) + + self.facets: list[NDArray[int]] = computed_data["facets"] + self.simplexes: list[Simplex] = computed_data["simplexes"] + self.all_entries: list[PDEntry] = computed_data["all_entries"] + self.qhull_data: np.ndarray = computed_data["qhull_data"] + self.dim: int = computed_data["dim"] + self.el_refs: dict[Element, PDEntry] = dict(computed_data["el_refs"]) + self.qhull_entries: tuple[PDEntry, ...] = tuple(computed_data["qhull_entries"]) + self._qhull_spaces: tuple = tuple(frozenset(e.elements) for e in self.qhull_entries) + self._stable_entries: tuple[PDEntry, ...] = tuple( + {self.qhull_entries[idx] for idx in set(itertools.chain(*self.facets))} + ) self._stable_spaces = tuple(frozenset(e.elements) for e in self._stable_entries) - def as_dict(self): + def __repr__(self) -> str: + symbols = [el.symbol for el in self.elements] + output = [ + f"{'-'.join(symbols)} phase diagram", + f"{len(self.stable_entries)} stable phases: ", + ", ".join(entry.name for entry in sorted(self.stable_entries, key=str)), + ] + return "\n".join(output) + + def as_dict(self) -> dict[str, Any]: """Get MSONable dict representation of PhaseDiagram.""" return { "@module": type(self).__module__, @@ -502,7 +519,7 @@ def pd_coords(self, comp: Composition) -> np.ndarray: return np.array([comp.get_atomic_fraction(el) for el in self.elements[1:]]) @property - def all_entries_hulldata(self): + def all_entries_hulldata(self) -> np.ndarray: """The ndarray used to construct the convex hull.""" data = [ [e.composition.get_atomic_fraction(el) for el in self.elements] + [e.energy_per_atom] @@ -527,7 +544,7 @@ def stable_entries(self) -> set[Entry]: return set(self._stable_entries) @lru_cache(1) # noqa: B019 - def _get_stable_entries_in_space(self, space) -> list[Entry]: + def _get_stable_entries_in_space(self, space: set[Element]) -> list[Entry]: """ Args: space (set[Element]): set of Element objects. @@ -584,15 +601,6 @@ def get_form_energy_per_atom(self, entry: PDEntry) -> float: """ return self.get_form_energy(entry) / entry.composition.num_atoms - def __repr__(self) -> str: - symbols = [el.symbol for el in self.elements] - output = [ - f"{'-'.join(symbols)} phase diagram", - f"{len(self.stable_entries)} stable phases: ", - ", ".join(entry.name for entry in sorted(self.stable_entries, key=str)), - ] - return "\n".join(output) - @lru_cache(1) # noqa: B019 def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]: """Get any facet that a composition falls into. Cached so successive @@ -608,7 +616,7 @@ def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]: raise RuntimeError(f"No facet found for {comp = }") - def _get_all_facets_and_simplexes(self, comp): + def _get_all_facets_and_simplexes(self, comp: Composition) -> list: """Get all facets that a composition falls into. Args: @@ -900,7 +908,7 @@ def get_decomp_and_phase_separation_energy( return self.get_decomp_and_e_above_hull(entry, allow_negative=True, **kwargs) # take entries with negative e_form and different compositions as competing entries - competing_entries = {c for c in compare_entries if id(c) not in same_comp_mem_ids} + competing_entries: set[PDEntry] = {c for c in compare_entries if id(c) not in same_comp_mem_ids} # NOTE SLSQP optimizer doesn't scale well for > 300 competing entries. if len(competing_entries) > space_limit and not stable_only: @@ -927,7 +935,7 @@ def get_decomp_and_phase_separation_energy( stacklevel=2, ) - decomp = _get_slsqp_decomp(entry.composition, competing_entries, tols, maxiter) + decomp = _get_slsqp_decomp(entry.composition, list(competing_entries), tols, maxiter) # find the minimum alternative formation energy for the decomposition decomp_enthalpy = np.sum([c.energy_per_atom * amt for c, amt in decomp.items()]) @@ -936,7 +944,7 @@ def get_decomp_and_phase_separation_energy( return decomp, decomp_enthalpy - def get_phase_separation_energy(self, entry, **kwargs): + def get_phase_separation_energy(self, entry: PDEntry, **kwargs): """ Provides the energy to the convex hull for the given entry. For stable entries already in the phase diagram the algorithm provides the phase separation energy @@ -965,7 +973,7 @@ def get_phase_separation_energy(self, entry, **kwargs): """ return self.get_decomp_and_phase_separation_energy(entry, **kwargs)[1] - def get_composition_chempots(self, comp): + def get_composition_chempots(self, comp: Composition) -> dict[Element, float]: """Get the chemical potentials for all elements at a given composition. Args: @@ -977,7 +985,7 @@ def get_composition_chempots(self, comp): facet = self._get_facet_and_simplex(comp)[0] return self._get_facet_chempots(facet) - def get_all_chempots(self, comp): + def get_all_chempots(self, comp: Composition) -> dict[str, dict[Element, float]]: """Get chemical potentials at a given composition. Args: @@ -995,7 +1003,7 @@ def get_all_chempots(self, comp): return chempots - def get_transition_chempots(self, element): + def get_transition_chempots(self, element: Element) -> tuple[float, ...]: """Get the critical chemical potentials for an element in the Phase Diagram. @@ -1014,7 +1022,7 @@ def get_transition_chempots(self, element): chempots = self._get_facet_chempots(facet) critical_chempots.append(chempots[element]) - clean_pots = [] + clean_pots: list[float] = [] for c in sorted(critical_chempots): if len(clean_pots) == 0 or not math.isclose( c, clean_pots[-1], abs_tol=PhaseDiagram.numerical_tol, rel_tol=0 @@ -1023,7 +1031,11 @@ def get_transition_chempots(self, element): clean_pots.reverse() return tuple(clean_pots) - def get_critical_compositions(self, comp1, comp2): + def get_critical_compositions( + self, + comp1: Composition, + comp2: Composition, + ) -> list[Composition]: """Get the critical compositions along the tieline between two compositions. I.e. where the decomposition products change. The endpoints are also returned. @@ -1083,7 +1095,12 @@ def get_critical_compositions(self, comp1, comp2): return [Composition((elem, val) for elem, val in zip(pd_els, m, strict=True)) for m in cs] - def get_element_profile(self, element, comp, comp_tol=1e-5): + def get_element_profile( + self, + element: Element, + comp: Composition, + comp_tol: float = 1e-5, + ) -> list[dict[str, Any]]: """ Provides the element evolution data for a composition. For example, can be used to analyze Li conversion voltages by varying mu_Li and looking at the phases @@ -1131,7 +1148,10 @@ def get_element_profile(self, element, comp, comp_tol=1e-5): return evolution def get_chempot_range_map( - self, elements: Sequence[Element], referenced: bool = True, joggle: bool = True + self, + elements: Sequence[Element], + referenced: bool = True, + joggle: bool = True, ) -> dict[Element, list[Simplex]]: """Get a chemical potential range map for each stable entry. @@ -1184,7 +1204,12 @@ def get_chempot_range_map( return chempot_ranges - def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): + def getmu_vertices_stability_phase( + self, + target_comp: Composition, + dep_elt: Element, + tol_en: float = 1e-2, + ) -> list[dict[Element, float]] | None: """Get a set of chemical potentials corresponding to the vertices of the simplex in the chemical potential phase diagram. The simplex is built using all elements in the target_composition @@ -1218,7 +1243,7 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): if elem.composition.reduced_composition == target_comp.reduced_composition: multiplier = elem.composition[dep_elt] / target_comp[dep_elt] ef = elem.energy / multiplier - all_coords = [] + all_coords: list[dict] = [] for simplex in chempots: for v in simplex._coords: elements = [elem for elem in self.elements if elem != dep_elt] @@ -1242,7 +1267,11 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2): return all_coords return None - def get_chempot_range_stability_phase(self, target_comp, open_elt): + def get_chempot_range_stability_phase( + self, + target_comp: Composition, + open_elt: Element, + ) -> dict[Element, tuple[float, float]]: """Get a set of chemical potentials corresponding to the max and min chemical potential of the open element for a given composition. It is quite common to have for instance a ternary oxide (e.g., ABO3) for @@ -1302,14 +1331,14 @@ def get_plot( ternary_style: Literal["2d", "3d"] = "2d", label_stable: bool = True, label_unstable: bool = True, - ordering: Sequence[str] | None = None, - energy_colormap=None, + ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None, + energy_colormap: str | Colormap | None = None, process_attributes: bool = False, - ax: plt.Axes = None, + ax: plt.Axes | None = None, label_uncertainties: bool = False, fill: bool = True, **kwargs, - ): + ) -> go.Figure | plt.Axes: """ Convenient wrapper for PDPlotter. Initializes a PDPlotter object and calls get_plot() with provided combined arguments. @@ -1464,7 +1493,7 @@ class CompoundPhaseDiagram(PhaseDiagram): """ # Tolerance for determining if amount of a composition is positive. - amount_tol = 1e-5 + amount_tol: ClassVar[float] = 1e-5 def __init__(self, entries, terminal_compositions, normalize_terminal_compositions=True): """Initialize a CompoundPhaseDiagram. @@ -1605,6 +1634,19 @@ class PatchedPhaseDiagram(PhaseDiagram): These are entries corresponding to the lowest energy element entries for simple compositional phase diagrams. elements (list[Element]): List of elements in the phase diagram. + + NOTE following methods are inherited unchanged from `PhaseDiagram`: + - __repr__ + - all_entries_hulldata + - unstable_entries + - stable_entries + - get_form_energy + - get_form_energy_per_atom + - get_hull_energy + - get_e_above_hull + - get_decomp_and_e_above_hull + - get_decomp_and_phase_separation_energy + - get_phase_separation_energ """ def __init__( @@ -1780,19 +1822,6 @@ def remove_redundant_spaces(spaces, keep_all_spaces=False): return result - # NOTE following methods are inherited unchanged from PhaseDiagram: - # __repr__, - # all_entries_hulldata, - # unstable_entries, - # stable_entries, - # get_form_energy(), - # get_form_energy_per_atom(), - # get_hull_energy(), - # get_e_above_hull(), - # get_decomp_and_e_above_hull(), - # get_decomp_and_phase_separation_energy(), - # get_phase_separation_energy() - def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram: """Get the possible phase diagrams for an entry. @@ -1937,7 +1966,14 @@ class ReactionDiagram: an electrolyte and an electrode. """ - def __init__(self, entry1, entry2, all_entries, tol: float = 1e-4, float_fmt="%.4f"): + def __init__( + self, + entry1: ComputedEntry, + entry2: ComputedEntry, + all_entries: list[ComputedEntry], + tol: float = 1e-4, + float_fmt: str = "%.4f", + ) -> None: """ Args: entry1 (ComputedEntry): Entry for 1st component. Note that @@ -2056,7 +2092,7 @@ def fmt(fl): self.all_entries = all_entries self.pd = pd - def get_compound_pd(self): + def get_compound_pd(self) -> CompoundPhaseDiagram: """Get the CompoundPhaseDiagram object, which can then be used for plotting. @@ -2102,11 +2138,11 @@ def get_facets(qhull_data: ArrayLike, joggle: bool = False) -> ConvexHull: def _get_slsqp_decomp( - comp, - competing_entries, - tols=(1e-8,), - maxiter=1000, -): + comp: Composition, + competing_entries: Sequence[PDEntry], + tols: Sequence[float] = (1e-8,), + maxiter: int = 1000, +) -> dict: """Find the amounts of competing compositions that minimize the energy of a given composition. @@ -2245,20 +2281,22 @@ def get_plot( self, label_stable: bool = True, label_unstable: bool = True, - ordering: Sequence[str] | None = None, - energy_colormap=None, + # `matplotlib` only + ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None, + energy_colormap: str | Colormap | None = None, process_attributes: bool = False, - ax: plt.Axes = None, + ax: plt.Axes | None = None, + # `plotly` only label_uncertainties: bool = False, fill: bool = True, highlight_entries: Collection[PDEntry] | None = None, - ) -> go.Figure | plt.Axes: + ) -> go.Figure | plt.Axes | None: """ Args: label_stable: Whether to label stable compounds. label_unstable: Whether to label unstable compounds. - ordering: Ordering of vertices, given as a list ['Up', - 'Left','Right'] (matplotlib only). + ordering: Ordering of vertices, given as a list ["Up", + "Left", "Right"] (matplotlib only). energy_colormap: Colormap for coloring energy (matplotlib only). process_attributes: Whether to process the attributes (matplotlib only). ax: Existing matplotlib Axes object if plotting multiple phase diagrams @@ -2358,7 +2396,13 @@ def write_image(self, stream: str | StringIO, image_format: str = "svg", **kwarg fig = self.get_plot(**kwargs) fig.write_image(stream, format=image_format) - def plot_element_profile(self, element, comp, show_label_index=None, xlim=5): + def plot_element_profile( + self, + element: Element, + comp: Composition, + show_label_index: list[int] | None = None, + xlim: float = 5, + ) -> plt.Axes: """ Draw the element profile plot for a composition varying different chemical potential of an element. @@ -2423,7 +2467,7 @@ def plot_element_profile(self, element, comp, show_label_index=None, xlim=5): return ax - def plot_chempot_range_map(self, elements, referenced=True) -> None: + def plot_chempot_range_map(self, elements: Sequence[Element], referenced: bool = True) -> None: """ Plot the chemical potential range _map using matplotlib. Currently works only for 3-component PDs. This shows the plot but does not return it. @@ -2441,7 +2485,7 @@ class (pymatgen.analysis.chempot_diagram). """ self.get_chempot_range_map_plot(elements, referenced=referenced).show() - def get_chempot_range_map_plot(self, elements, referenced=True): + def get_chempot_range_map_plot(self, elements: Sequence[Element], referenced: bool = True) -> plt.Axes: """Get a plot of the chemical potential range _map. Currently works only for 3-component PDs. @@ -2467,7 +2511,7 @@ class (pymatgen.analysis.chempot_diagram). for entry, lines in chempot_ranges.items(): comp = entry.composition center_x = center_y = 0 - coords = [] + coords: list[list] = [] contain_zero = any(comp.get_atomic_fraction(el) == 0 for el in elements) is_boundary = (not contain_zero) and sum(comp.get_atomic_fraction(el) for el in elements) == 1 for line in lines: @@ -2545,7 +2589,7 @@ class (pymatgen.analysis.chempot_diagram). plt.tight_layout() return ax - def get_contour_pd_plot(self): + def get_contour_pd_plot(self) -> plt.Axes: """ Plot a contour phase diagram plot, where phase triangles are colored according to degree of instability by interpolation. Currently only @@ -2581,7 +2625,7 @@ def get_contour_pd_plot(self): @property @lru_cache(1) # noqa: B019 - def pd_plot_data(self): + def pd_plot_data(self) -> tuple[list, dict, dict]: """ Plotting data for phase diagram. Cached for repetitive calls. @@ -2590,12 +2634,12 @@ def pd_plot_data(self): Returns: A tuple containing three objects (lines, stable_entries, unstable_entries): - - lines is a list of list of coordinates for lines in the PD. - - stable_entries is a dict of {coordinates : entry} for each stable node - in the phase diagram. (Each coordinate can only have one - stable phase) - - unstable_entries is a dict of {entry: coordinates} for all unstable - nodes in the phase diagram. + - lines: a list of list of coordinates for lines in the PD. + - stable_entries: a dict of {coordinates: entry} for each stable node + in the phase diagram. (Each coordinate can only have one + stable phase) + - unstable_entries: a dict of {entry: coordinates} for all unstable + nodes in the phase diagram. """ pd = self._pd entries = pd.qhull_entries @@ -2645,7 +2689,7 @@ def pd_plot_data(self): return lines, stable_entries, unstable_entries - def _create_plotly_figure_layout(self, label_stable=True): + def _create_plotly_figure_layout(self, label_stable: bool = True) -> dict[str, Any]: """ Creates layout for plotly phase diagram figure and updates with figure annotations. @@ -2689,7 +2733,7 @@ def _create_plotly_figure_layout(self, label_stable=True): return layout - def _create_plotly_lines(self): + def _create_plotly_lines(self) -> go.Scatter | go.Scatterternary | go.Scatter3d | None: """ Create Plotly scatter plots containing line traces of phase diagram facets. @@ -2697,12 +2741,12 @@ def _create_plotly_lines(self): Either a go.Scatter (binary), go.Scatterternary (ternary_2d), or go.Scatter3d plot (ternary_3d, quaternary) """ - line_plot = None + x, y, z, energies = [], [], [], [] pd = self._pd - plot_args = { + plot_args: dict[str, Any] = { "mode": "lines", "hoverinfo": "none", "line": {"color": "black", "width": 4.0}, @@ -2740,18 +2784,21 @@ def _create_plotly_lines(self): z += [*line[2], None] if self._dim == 2: - line_plot = go.Scatter(x=x, y=y, **plot_args) - elif self._dim == 3 and self.ternary_style == "2d": - line_plot = go.Scatterternary(a=x, b=y, c=z, **plot_args) - elif self._dim == 3 and self.ternary_style == "3d": - line_plot = go.Scatter3d(x=y, y=x, z=z, **plot_args) - elif self._dim == 4: + return go.Scatter(x=x, y=y, **plot_args) + + if self._dim == 3: + if self.ternary_style == "2d": + return go.Scatterternary(a=x, b=y, c=z, **plot_args) + if self.ternary_style == "3d": + return go.Scatter3d(x=y, y=x, z=z, **plot_args) + + if self._dim == 4: plot_args["line"]["width"] = 1.5 - line_plot = go.Scatter3d(x=x, y=y, z=z, **plot_args) + return go.Scatter3d(x=x, y=y, z=z, **plot_args) - return line_plot + return None - def _create_plotly_fill(self): + def _create_plotly_fill(self) -> list[go.Mesh3d]: """ Creates shaded mesh traces for coloring the hull. @@ -2890,7 +2937,7 @@ def _create_plotly_fill(self): return traces - def _create_plotly_stable_labels(self, label_stable=True): + def _create_plotly_stable_labels(self, label_stable: bool = True) -> go.Scatter | go.Scatter3d: """ Creates a (hidable) scatter trace containing labels of stable phases. Contains some functionality for creating sensible label positions. This method @@ -2953,7 +3000,7 @@ def _create_plotly_stable_labels(self, label_stable=True): formula = comp.reduced_formula text.append(htmlify(formula)) - visible = True + visible: str | bool = True if not label_stable or self._dim == 4: visible = "legendonly" @@ -2977,7 +3024,7 @@ def _create_plotly_stable_labels(self, label_stable=True): return stable_labels_plot - def _create_plotly_element_annotations(self): + def _create_plotly_element_annotations(self) -> list[dict] | None: """ Creates terminal element annotations for Plotly phase diagrams. This method does not apply to ternary_2d plots. @@ -3054,20 +3101,29 @@ def _create_plotly_element_annotations(self): return annotations_list - def _create_plotly_markers(self, highlight_entries=None, label_uncertainties=False): + def _create_plotly_markers( + self, + highlight_entries: Collection[PDEntry] | None = None, + label_uncertainties: bool = False, + ) -> tuple: """ - Creates stable and unstable marker plots for overlaying on the phase diagram. + Creates stable, unstable and highlight marker plots for overlaying on the phase diagram. Returns: tuple[go.Scatter]: Plotly Scatter objects (unary, binary), go.Scatterternary(ternary_2d), - or go.Scatter3d (ternary_3d, quaternary) objects in order: (stable markers, unstable markers) + or go.Scatter3d (ternary_3d, quaternary) objects in order: (stable, unstable and highlight markers) """ - def get_marker_props(coords, entries): + def get_marker_props(coords, entries) -> dict[str, Any]: """Get marker locations, hovertext, and error bars from pd_plot_data.""" - x, y, z, texts, energies, uncertainties = [], [], [], [], [], [] - - is_stable = [entry in self._pd.stable_entries for entry in entries] + x: list[float] = [] + y: list[float] = [] + z: list[float] = [] + texts: list[str] = [] + energies: list[float] = [] + uncertainties: list[float] = [] + + is_stable: list[bool] = [entry in self._pd.stable_entries for entry in entries] for coord, entry, stable in zip(coords, entries, is_stable, strict=True): energy = round(self._pd.get_form_energy_per_atom(entry), 3) @@ -3082,10 +3138,12 @@ def get_marker_props(coords, entries): formula = comp.reduced_formula clean_formula = htmlify(formula) label = f"{clean_formula} ({entry_id})
Formation energy: {energy} eV/atom
" + if not stable: - e_above_hull = round(self._pd.get_e_above_hull(entry), 3) - if e_above_hull > self.show_unstable: + e_above_hull = self._pd.get_e_above_hull(entry) + if e_above_hull is None or e_above_hull > self.show_unstable: continue + e_above_hull = round(e_above_hull, 3) label += f" Energy Above Hull: ({e_above_hull:+} eV/atom)" energies.append(e_above_hull) else: @@ -3106,6 +3164,7 @@ def get_marker_props(coords, entries): _cartesian_positions = [x, y, z] _cartesian_positions[axis].append(entry.composition[el]) label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}" + elif self._dim == 3 and self.ternary_style == "3d": x.append(coord[0]) y.append(coord[1]) @@ -3117,6 +3176,7 @@ def get_marker_props(coords, entries): ) for el, _axis in zip(self._pd.elements, range(self._dim), strict=True): label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}" + elif self._dim == 4: x.append(coord[0]) y.append(coord[1]) @@ -3128,6 +3188,7 @@ def get_marker_props(coords, entries): ) for el, _axis in zip(self._pd.elements, range(self._dim), strict=True): label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}" + else: x.append(coord[0]) y.append(coord[1]) @@ -3146,10 +3207,15 @@ def get_marker_props(coords, entries): if highlight_entries is None: highlight_entries = [] - stable_coords, stable_entries = [], [] - unstable_coords, unstable_entries = [], [] - highlight_coords, highlight_ents = [], [] + stable_coords: list[Sequence[float]] = [] + unstable_coords: list[Sequence[float]] = [] + highlight_coords: list[Sequence[float]] = [] + stable_entries: list[PDEntry] = [] + unstable_entries: list[PDEntry] = [] + highlight_ents: list[PDEntry] = [] + + # Stable entries for coord, entry in zip(self.pd_plot_data[1], self.pd_plot_data[1].values(), strict=True): if entry in highlight_entries: highlight_coords.append(coord) @@ -3158,6 +3224,7 @@ def get_marker_props(coords, entries): stable_coords.append(coord) stable_entries.append(entry) + # Unstable entries for coord, entry in zip(self.pd_plot_data[2].values(), self.pd_plot_data[2], strict=True): if entry in highlight_entries: highlight_coords.append(coord) @@ -3487,29 +3554,31 @@ def get_marker_props(coords, entries): highlight_marker_plot = None - if self._dim in [1, 2]: + if self._dim in {1, 2}: stable_marker_plot, unstable_marker_plot = ( - go.Scatter(**markers) for markers in [stable_markers, unstable_markers] + go.Scatter(**markers) for markers in (stable_markers, unstable_markers) ) if highlight_entries: highlight_marker_plot = go.Scatter(**highlight_markers) + elif self._dim == 3 and self.ternary_style == "2d": stable_marker_plot, unstable_marker_plot = ( - go.Scatterternary(**markers) for markers in [stable_markers, unstable_markers] + go.Scatterternary(**markers) for markers in (stable_markers, unstable_markers) ) if highlight_entries: highlight_marker_plot = go.Scatterternary(**highlight_markers) + else: stable_marker_plot, unstable_marker_plot = ( - go.Scatter3d(**markers) for markers in [stable_markers, unstable_markers] + go.Scatter3d(**markers) for markers in (stable_markers, unstable_markers) ) if highlight_entries: highlight_marker_plot = go.Scatter3d(**highlight_markers) return stable_marker_plot, unstable_marker_plot, highlight_marker_plot - def _create_plotly_uncertainty_shading(self, stable_marker_plot): + def _create_plotly_uncertainty_shading(self, stable_marker_plot: go.Scatter) -> go.Scatter: """ Creates shaded uncertainty region for stable entries. Currently only works for binary (dim=2) phase diagrams. @@ -3540,7 +3609,7 @@ def _create_plotly_uncertainty_shading(self, stable_marker_plot): outline = points[:, :2].copy() outline[:, 1] += points[:, 2] - last = -1 + last: int | None = -1 if transformed: last = None # allows for uncertainty in terminal compounds @@ -3562,7 +3631,7 @@ def _create_plotly_uncertainty_shading(self, stable_marker_plot): return uncertainty_plot - def _create_plotly_ternary_support_lines(self): + def _create_plotly_ternary_support_lines(self) -> go.Scatter3d: """ Creates support lines which aid in seeing the ternary hull in three dimensions. @@ -3570,7 +3639,7 @@ def _create_plotly_ternary_support_lines(self): Returns: go.Scatter3d plot of support lines for ternary phase diagram. """ - stable_entry_coords = dict(map(reversed, self.pd_plot_data[1].items())) + stable_entry_coords: dict = {v: k for k, v in self.pd_plot_data[1].items()} elem_coords = [stable_entry_coords[entry] for entry in self._pd.el_refs.values()] @@ -3599,20 +3668,17 @@ def _create_plotly_ternary_support_lines(self): def _get_matplotlib_2d_plot( self, - label_stable=True, - label_unstable=True, - ordering=None, - energy_colormap=None, - vmin_mev=-60.0, - vmax_mev=60.0, - show_colorbar=True, - process_attributes=False, - ax: plt.Axes = None, - ): - """Show the plot using matplotlib. - - Imports are done within the function as matplotlib is no longer the default. - """ + label_stable: bool = True, + label_unstable: bool = True, + ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None, + energy_colormap: str | Colormap | None = None, + vmin_mev: float = -60.0, + vmax_mev: float = 60.0, + show_colorbar: bool = True, + process_attributes: bool = False, + ax: plt.Axes | None = None, + ) -> plt.Axes: + """Show the plot using matplotlib.""" ax = ax or pretty_plot(8, 6) if ordering is None: @@ -3808,7 +3874,11 @@ def _get_matplotlib_2d_plot( plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07) return ax - def _get_matplotlib_3d_plot(self, label_stable=True, ax: plt.Axes = None): + def _get_matplotlib_3d_plot( + self, + label_stable: bool = True, + ax: plt.Axes | None = None, + ) -> plt.Axes: """Show the plot using matplotlib. Args: @@ -3851,23 +3921,23 @@ def _get_matplotlib_3d_plot(self, label_stable=True, ax: plt.Axes = None): return ax -def uniquelines(q): +def uniquelines(q: list[NDArray[int]]) -> set[tuple[int, int]]: """ Given all the facets, convert it into a set of unique lines. Specifically used for converting convex hull facets into line pairs of coordinates. Args: q: A 2-dim sequence, where each row represents a facet. e.g. - [[1,2,3],[3,6,7],...] + [[1, 2, 3], [3, 6, 7], ...] Returns: setoflines: - A set of tuple of lines. e.g. ((1,2), (1,3), (2,3), ....) + A set of tuple of lines. e.g. ((1, 2), (1, 3), ...) """ return {tuple(sorted(line)) for facets in q for line in itertools.combinations(facets, 2)} -def triangular_coord(coord): +def triangular_coord(coord: ArrayLike) -> np.ndarray: """ Convert a 2D coordinate into a triangle-based coordinate system for a prettier phase diagram. @@ -3884,7 +3954,7 @@ def triangular_coord(coord): return result.transpose() -def tet_coord(coord): +def tet_coord(coord: ArrayLike) -> np.ndarray: """ Convert a 3D coordinate into a tetrahedron based coordinate system for a prettier phase diagram. @@ -3906,7 +3976,12 @@ def tet_coord(coord): return result.transpose() -def order_phase_diagram(lines, stable_entries, unstable_entries, ordering): +def order_phase_diagram( + lines: list, + stable_entries: dict[Any, PDEntry], + unstable_entries: dict[PDEntry, Any], + ordering: Sequence[Literal["Up", "Left", "Right"]], +) -> tuple[list, dict[Any, PDEntry], dict[PDEntry, Any]]: """ Orders the entries (their coordinates) in a phase diagram plot according to the user specified ordering. @@ -3925,12 +4000,11 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering): Returns: tuple[list, dict, dict]: - - new_lines is a list of list of coordinates for lines in the PD. - - new_stable_entries is a {coordinate: entry} for each stable node + - a list of list of coordinates for lines in the PD. + - a {coordinate: entry} for each stable node in the phase diagram. (Each coordinate can only have one stable phase) - - new_unstable_entries is a {entry: coordinates} for all unstable - nodes in the phase diagram. + - a {entry: coordinates} for all unstable nodes in the phase diagram. """ yup = -1000.0 xleft = 1000.0 @@ -3939,16 +4013,16 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering): nameup = "" nameleft = "" nameright = "" - for coord in stable_entries: + for coord, entry in stable_entries.items(): if coord[0] > xright: xright = coord[0] - nameright = stable_entries[coord].name + nameright = entry.name if coord[0] < xleft: xleft = coord[0] - nameleft = stable_entries[coord].name + nameleft = entry.name if coord[1] > yup: yup = coord[1] - nameup = stable_entries[coord].name + nameup = entry.name if (nameup not in ordering) or (nameright not in ordering) or (nameleft not in ordering): raise ValueError( From de71e2aa38620b2964c57bf1f9486a488e5a9b10 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Sat, 4 Oct 2025 12:53:15 +0200 Subject: [PATCH 4/5] more explicit handling of dimension --- src/pymatgen/analysis/phase_diagram.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 0afe6306cc8..312b07719d6 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -2313,10 +2313,14 @@ def get_plot( Returns: go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend. """ - fig = None - data = [] + if self._dim not in {1, 2, 3, 4}: + raise ValueError( + f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got {self._dim}D " + ) if self.backend == "plotly": + data: list = [] + if self._dim != 1: data.append(self._create_plotly_lines()) @@ -2334,7 +2338,7 @@ def get_plot( if self._dim != 1 and not (self._dim == 3 and self.ternary_style == "2d"): data.append(self._create_plotly_stable_labels(label_stable)) - if fill and self._dim in [3, 4]: + if fill and self._dim in {3, 4}: data.extend(self._create_plotly_fill()) data.extend([stable_marker_plot, unstable_marker_plot]) @@ -2346,9 +2350,11 @@ def get_plot( fig.layout = self._create_plotly_figure_layout() fig.update_layout(coloraxis_colorbar={"yanchor": "top", "y": 0.05, "x": 1}) - elif self.backend == "matplotlib": - if self._dim <= 3: - fig = self._get_matplotlib_2d_plot( + return fig + + if self.backend == "matplotlib": + if self._dim in {1, 2, 3}: + return self._get_matplotlib_2d_plot( label_stable, label_unstable, ordering, @@ -2356,10 +2362,10 @@ def get_plot( ax=ax, process_attributes=process_attributes, ) - elif self._dim == 4: - fig = self._get_matplotlib_3d_plot(label_stable, ax=ax) + if self._dim == 4: + return self._get_matplotlib_3d_plot(label_stable, ax=ax) - return fig + return None def show(self, *args, **kwargs) -> None: """ From c9e15e4f44d6697037c97836552a9d3674abb583 Mon Sep 17 00:00:00 2001 From: "Haoyu (Daniel)" Date: Fri, 17 Oct 2025 10:53:55 +0200 Subject: [PATCH 5/5] FIX: PD Plotter only show lowest energy for unstable composition --- src/pymatgen/analysis/phase_diagram.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index 312b07719d6..83edba1fcbf 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -3214,11 +3214,9 @@ def get_marker_props(coords, entries) -> dict[str, Any]: highlight_entries = [] stable_coords: list[Sequence[float]] = [] - unstable_coords: list[Sequence[float]] = [] highlight_coords: list[Sequence[float]] = [] stable_entries: list[PDEntry] = [] - unstable_entries: list[PDEntry] = [] highlight_ents: list[PDEntry] = [] # Stable entries @@ -3230,14 +3228,23 @@ def get_marker_props(coords, entries) -> dict[str, Any]: stable_coords.append(coord) stable_entries.append(entry) - # Unstable entries + # Unstable entries (lowest energy only per composition) + min_unstable: dict[str, tuple[Sequence[float], PDEntry]] = {} + for coord, entry in zip(self.pd_plot_data[2].values(), self.pd_plot_data[2], strict=True): if entry in highlight_entries: highlight_coords.append(coord) highlight_ents.append(entry) - else: - unstable_coords.append(coord) - unstable_entries.append(entry) + continue + + formula = entry.composition.reduced_formula + e_above_hull = self._pd.get_e_above_hull(entry) + + if formula not in min_unstable or e_above_hull < self._pd.get_e_above_hull(min_unstable[formula][1]): + min_unstable[formula] = (coord, entry) + + unstable_coords = [coord for coord, _ in min_unstable.values()] + unstable_entries = [entry for _, entry in min_unstable.values()] stable_props = get_marker_props(stable_coords, stable_entries) unstable_props = get_marker_props(unstable_coords, unstable_entries)