diff --git a/src/pymatgen/analysis/structure_prediction/dopant_predictor.py b/src/pymatgen/analysis/structure_prediction/dopant_predictor.py index 760b6e9e268..8252adab5be 100644 --- a/src/pymatgen/analysis/structure_prediction/dopant_predictor.py +++ b/src/pymatgen/analysis/structure_prediction/dopant_predictor.py @@ -128,7 +128,7 @@ def get_dopants_from_shannon_radii(bonded_structure, num_dopants=5, match_oxi_si def _get_dopants(substitutions, num_dopants, match_oxi_sign) -> dict: """Utility method to get n- and p-type dopants from a list of substitutions.""" dopants = {k: [] for k in ("n_type", "p_type")} - for k in dopants: # noqa: PLC0206 + for k, dop in dopants.items(): for pred in substitutions: if ( pred["dopant_species"].oxi_state > pred["original_species"].oxi_state @@ -138,8 +138,8 @@ def _get_dopants(substitutions, num_dopants, match_oxi_sign) -> dict: not match_oxi_sign or np.sign(pred["dopant_species"].oxi_state) == np.sign(pred["original_species"].oxi_state) ): - dopants[k].append(pred) - if len(dopants[k]) == num_dopants: + dop.append(pred) + if len(dop) == num_dopants: break return dopants diff --git a/src/pymatgen/cli/pmg.py b/src/pymatgen/cli/pmg.py index df397b1c079..e859045f8f4 100755 --- a/src/pymatgen/cli/pmg.py +++ b/src/pymatgen/cli/pmg.py @@ -63,11 +63,11 @@ def format_lists(v): ["---------------", "", ""], ] output += [ - ( # type: ignore[misc] + [ k, format_lists(diff["Same"][k]), format_lists(diff["Same"][k]), - ) + ] for k in sorted(diff["Same"]) if k != "SYSTEM" ] @@ -77,7 +77,7 @@ def format_lists(v): ["----------------", "", ""], ] output += [ - [ # type: ignore[misc] + [ k, format_lists(diff["Different"][k]["INCAR1"]), format_lists(diff["Different"][k]["INCAR2"]), diff --git a/src/pymatgen/io/vasp/inputs.py b/src/pymatgen/io/vasp/inputs.py index 086c75a80cb..3c4d910f8bc 100644 --- a/src/pymatgen/io/vasp/inputs.py +++ b/src/pymatgen/io/vasp/inputs.py @@ -961,13 +961,39 @@ def from_str(cls, string: str) -> Self: Returns: Incar object 
""" + string = "\n".join([ln.split("#", 1)[0].split("!", 1)[0].rstrip() for ln in string.splitlines()]) + params: dict[str, Any] = {} - for line in clean_lines(string.splitlines()): - for sline in line.split(";"): - if match := re.match(r"(\w+)\s*=\s*(.*)", sline.strip()): - key: str = match[1].strip() - val: str = match[2].strip() - params[key] = cls.proc_val(key, val) + + # Handle line continuations (\) + string = re.sub(r"\\\s*\n", " ", string) + + # Regex pattern to find all valid "key = value" assignments at once + pattern = re.compile( + r""" + (?P<key>\w+) # Key (e.g. ENCUT) + \s*=\s* # Equals sign and optional spaces + (?: # Non-capturing group for the value + " # Opening quote + (?P<qval>.*?) # Capture everything inside (non-greedy) + [ \t]*" # Allow trailing spaces/tabs before closing quote + | # OR + (?P<val>[^#!;\n]*) # Unquoted value (stops before comment/separator) + ) + """, + re.VERBOSE | re.DOTALL, + ) + + # Find all matches in the entire string + for match in pattern.finditer(string): + key = match.group("key") + val = match.group("qval") if match.group("qval") is not None else (match.group("val") or "").strip() + + if not val: + continue + + params[key] = cls.proc_val(key, val) + return cls(params) @staticmethod @@ -1038,7 +1064,7 @@ def proc_val(key: str, val: str) -> list | bool | float | int | str: ) lower_str_keys = ("ML_MODE",) # String keywords to read "as is" (no case transformation, only stripped) - as_is_str_keys = ("SYSTEM",) + as_is_str_keys = ("SYSTEM", "WANNIER90_WIN") def smart_int_or_float_bool(str_: str) -> float | int | bool: """Determine whether a string represents an integer or a float.""" @@ -1117,7 +1143,7 @@ def diff(self, other: Self) -> dict[str, dict[str, Any]]: {"Same" : parameters_that_are_the_same, "Different": parameters_that_are_different} Note that the parameters are return as full dictionaries of values. E.g. 
{"ISIF":3} """ - similar_params = {} + same_params = {} different_params = {} for k1, v1 in self.items(): if k1 not in other: @@ -1125,13 +1151,13 @@ def diff(self, other: Self) -> dict[str, dict[str, Any]]: elif v1 != other[k1]: different_params[k1] = {"INCAR1": v1, "INCAR2": other[k1]} else: - similar_params[k1] = v1 + same_params[k1] = v1 for k2, v2 in other.items(): - if k2 not in similar_params and k2 not in different_params and k2 not in self: + if k2 not in same_params and k2 not in different_params and k2 not in self: different_params[k2] = {"INCAR1": None, "INCAR2": v2} - return {"Same": similar_params, "Different": different_params} + return {"Same": same_params, "Different": different_params} def check_params(self) -> None: """Check INCAR for invalid tags or values. diff --git a/tests/io/vasp/test_inputs.py b/tests/io/vasp/test_inputs.py index 3d2f73bbb98..f1ecf816c37 100644 --- a/tests/io/vasp/test_inputs.py +++ b/tests/io/vasp/test_inputs.py @@ -886,6 +886,92 @@ def test_write(self): incar = Incar.from_file(tmp_file) assert incar == self.incar + def test_from_str_comment_handling(self): + incar_str = r""" + # A = 0 + ! B=1 + SIGMA = 0.05 # random comment (known float tag) + EDIFF = 1e-6 ! another comment (known float tag) + ALGO = Normal # comment (unknown tag -> inferred as str) + GGA = PE ! comment (unknown tag -> inferred as str) + """ + incar = Incar.from_str(incar_str) + + assert set(incar.keys()) == {"SIGMA", "EDIFF", "ALGO", "GGA"} + assert incar["SIGMA"] == approx(0.05) + assert incar["EDIFF"] == approx(1e-6) + assert incar["ALGO"] == "Normal" + assert incar["GGA"] == "Pe" + + def test_from_str_semicolon_separated_statements(self): + # Test interaction between semicolon and comment + incar_str = r""" + ENMAX = 400; ALGO = Fast ! A = 0 + ENCUT = 500; ISMEAR = 0 # B=1 + PREC = Accurate ; LREAL = Auto ! 
precision and projection scheme + IBRION = 2; ISIF = 3; NSW = 100 # three statements in one line + """ + incar = Incar.from_str(incar_str) + + assert set(incar.keys()) == { + "ENMAX", + "ALGO", + "ENCUT", + "ISMEAR", + "PREC", + "LREAL", + "IBRION", + "ISIF", + "NSW", + } + + assert incar["ENMAX"] == 400 + assert incar["ALGO"] == "Fast" + assert incar["ENCUT"] == 500 + assert incar["ISMEAR"] == 0 + assert incar["PREC"] == "Accurate" + assert incar["LREAL"] == "Auto" + assert incar["IBRION"] == 2 + assert incar["ISIF"] == 3 + assert incar["NSW"] == 100 + + def test_from_str_line_continuation_with_backslash(self): + # Test line continuation with backslash + incar_str = r""" + ALGO = Normal # \ This backslash should be ignored + ENMAX = 200 ! \ This backslash should be ignored + MAGMOM = 0 0 1.0 0 0 -1.0 \ + 0 0 1.0 0 0 -1.0 \ + 6*0 + """ + incar = Incar.from_str(incar_str) + + assert set(incar.keys()) == {"ALGO", "ENMAX", "MAGMOM"} + assert incar["ALGO"] == "Normal" + assert incar["ENMAX"] == 200 + + assert incar["MAGMOM"] == [0, 0, 1.0, 0, 0, -1.0, 0, 0, 1.0, 0, 0, -1.0] + [0.0] * 6 + + def test_from_str_multiline_string(self): + incar_str = r""" + # Multi-line string with embedded comments + WANNIER90_WIN = "begin Projections # should NOT be capitalized + Fe:d ; Fe:p # comment inside string + End Projections ! 
random comment + " # comment after closing quote + """ + incar = Incar.from_str(incar_str) + + assert set(incar.keys()) == {"WANNIER90_WIN"} + + # Comments inside the string would be lost + assert ( + incar["WANNIER90_WIN"] + == """begin Projections + Fe:d ; Fe:p + End Projections""" + ) + def test_get_str(self): incar_str = self.incar.get_str(pretty=True, sort_keys=True) expected = """ALGO = Damped @@ -1003,6 +1089,7 @@ def test_types(self): def test_proc_types(self): assert Incar.proc_val("HELLO", "-0.85 0.85") == "-0.85 0.85" + # `ML_MODE` should always be lower case assert Incar.proc_val("ML_MODE", "train") == "train" assert Incar.proc_val("ML_MODE", "RUN") == "run" assert Incar.proc_val("ALGO", "fast") == "Fast"