diff --git a/bindings/python/libmsym/libmsym.py b/bindings/python/libmsym/libmsym.py index e70156f..4712c11 100644 --- a/bindings/python/libmsym/libmsym.py +++ b/bindings/python/libmsym/libmsym.py @@ -16,6 +16,7 @@ _lib = None +atomic_orbital_symbols = "spdfghik" @export class Error(Exception): def __init__(self, value, details=""): @@ -29,8 +30,10 @@ def __repr__(self): try: import numpy as np + import scipy.sparse as ssparse except ImportError: np = None + ssparse = None @export class SymmetryOperation(Structure): @@ -78,7 +81,8 @@ def __str__(self): power = "^" + str(self.power) axis = " around " + "[ {: >.3f}, {: >.3f}, {: >.3f}]".format(self.vector[0],self.vector[1],self.vector[2]) - return __name__ + "." + self.__class__.__name__ + "( " + self._names[self.type] + order + orientation + power + axis + ", conjugacy class: " + str(self.conjugacy_class) + " )" + return f"<{self.__class__.__name__}> ({self._names[self.type]}{order+orientation+power+axis}, conjugacy class: {self.conjugacy_class})" + def __repr__(self): return self.__str__() @@ -89,10 +93,11 @@ class Element(Structure): ("_v", c_double*3), ("charge", c_int), ("_name",c_char*4)] + equiv = None @property def coordinates(self): - return self._v[0:3] + return np.array(self._v[0:3]) @coordinates.setter def coordinates(self, coordinates): self._v = (c_double*3)(*coordinates) @@ -102,6 +107,115 @@ def name(self): @name.setter def name(self, name): self._name = name.encode('ascii') + + def __str__(self): + res = "{}[{}] {:.6f} {:.6f} {:.6f}".format(self.name, self.index, *self.coordinates) + if self.equiv is None: + return res + else: + return f"{res} from {self.equiv}" + + def __repr__(self): + return self.__str__() + + def __eq__(s, o): + if not isinstance(o, Element): + return False + return np.allclose(s.coordinates, o.coordinates) and s.name == o.name +@export +class EquivalenceSets(Structure): + _fields_ = [("_elements", POINTER(POINTER(Element))), + ("error", c_double), + ("length", c_int)] + + elements = [] + def _update_elements(self, elements): + addresses = [addressof(ele) for ele in elements] + self.elements = [] + for ele in self._elements[0:self.length]: + + pyEle = elements[addresses.index(addressof(ele.contents))] + pyEle.equiv = self + self.elements.append(pyEle) + + def __str__(self): + return f"<{self.__class__.__name__}> {self.elements[0].name} x {self.length}" + + def __repr__(self): + return self.__str__() + +class SALC_wf(object): + def __init__(self, coeff, partners, spec): + self.coeff = coeff + self.partners = partners + self.sparse = ssparse.csr_array(coeff) + self.spec = spec # symmetry species + self._nonzeros = self.sparse.count_nonzero() + + def __str__(self): + return f"<{self.__class__.__name__}> {self.spec.name}{self.coeff.shape} with {self._nonzeros} nonzeros and {len(self.partners)} parter functions" + + def __repr__(self): + return self.__str__() + + def print_coeff(self, basis_functions): + if self._nonzeros == 0: + print(self.spec, end="") + return + else: + print(self.spec) + + for k, vec in enumerate(self.sparse): + idx = vec.nonzero()[1] + print(f"{k+1:>3d} ", end ="") + for i in idx: + print(f" {basis_functions[i].comment:15s}", end="") + print("\n ", end ="") + for i in idx: + print(f"{vec[0, i]:>14.11f} ", end="") + print("") +@export +class Subgroup(Structure): + TYPE_Kh = 0, + TYPE_K = 1, + TYPE_Ci = 2, + TYPE_Cs = 3, + TYPE_Cn = 4, + TYPE_Cnh = 5, + TYPE_Cnv = 6, + TYPE_Dn = 7, + TYPE_Dnh = 8, + TYPE_Dnd = 9, + TYPE_Sn = 10, + TYPE_T = 11, + TYPE_Td = 12, + TYPE_Th = 13, + TYPE_O = 14, + TYPE_Oh = 15, + TYPE_I = 16, + TYPE_Ih = 17 + + primary_operations = None + symmetry_operations = [] + + @property + def name(self): + return self._name.decode() + + def _update_symmetry_operations(self, symmetry_operations): + addresses = [addressof(sop) for sop in symmetry_operations] + self.symmetry_operations = [symmetry_operations[addresses.index(addressof(sop.contents))] for sop in self._sops[0:self.n]] + if self._primary: + index = addresses.index(addressof(self._primary.contents)) + self.primary_operations = symmetry_operations[index] + +Subgroup._fields_ = [("type", c_int), + ("n", c_int), + ("order", c_int), + ("_primary", POINTER(SymmetryOperation)), + ("_sops", POINTER(POINTER(SymmetryOperation))), + ("generators", POINTER(Subgroup)*2), + ("_name",c_char*8)] class _RealSphericalHarmonic(Structure): _fields_ = [("n", c_int), @@ -118,12 +232,14 @@ class BasisFunction(Structure): ("_element", POINTER(Element)), ("_f", _BasisFunctionUnion), ("_name",c_char*8)] - - def __init__(self, element=None): + + _comment = None + def __init__(self, element=None, comment = None): if element == None: raise Error("Basis function requires an element") super().__init__() self.element = element + self._comment = comment def _set_element_pointer(self, element): self._element = pointer(element) @@ -134,11 +250,18 @@ def name(self): @name.setter def name(self, name): self._name = name.encode('ascii') + + @property + def comment(self): + if self._comment is None: + return self.name + else: + return self._comment @export class RealSphericalHarmonic(BasisFunction): - def __init__(self, element=None, n=0, l=0, m=0, name=""): - super().__init__(element=element) + def __init__(self, element=None, n=0, l=0, m=0, name="", comment = None): + super().__init__(element=element, comment = comment) self._type = 0 self._f._rsh.n = n self._f._rsh.l = l @@ -156,16 +279,24 @@ def n(self, n): def l(self): return self._f._rsh.l @l.setter - def l(self, n): + def l(self, l): self._f._rsh.n = l @property def m(self): return self._f._rsh.m @m.setter - def m(self, n): + def m(self, m): self._f._rsh.n = m + + def __str__(self): + name = self.comment.strip() + return f" {name} {self.n}{atomic_orbital_symbols[self.l]},m={self.m}" + + def __repr__(self): + return self.__str__() + class SALC(Structure): _fields_ = [("_d", c_int), ("_fl", c_int), @@ -177,7 +308,7 @@ class SALC(Structure): def _update_basis_functions(self, basis_function_addresses, basis): self.basis_functions = [basis[basis_function_addresses.index(addressof(p.contents))] for p in self._f[0:self._fl]] - + #@property #def partner_functions(self): # if self._pf_array is None: @@ -196,7 +327,16 @@ def partner_functions(self): return self._pf_array - + def __str__(self): + basis = self.basis_functions[0] + if isinstance(basis, RealSphericalHarmonic): + type_name = "realSHs" + else: + type_name = "orbitals" + return f"<{self.__class__.__name__}> from {self._fl} {type_name} with n={basis.n} and l={basis.l}" + + def __repr__(self): + return self.__str__() @export class SubrepresentationSpace(Structure): _fields_ = [("symmetry_species", c_int), @@ -210,7 +350,13 @@ def salcs(self): if self._salcarray is None: self._salcarray = self._salcs[0:self._salc_length] return self._salcarray + + def __str__(self): + return f"<{self.__class__.__name__}> for {self.symmetry_species}th symmetry species with {self._salc_length} SALCs" + def __repr__(self): + return self.__str__() + @export class PartnerFunction(Structure): _fields_ = [("index", c_int), @@ -234,6 +380,13 @@ def reducible(self): def name(self): return self._name.decode() + def __str__(self): + return f"<{self.__class__.__name__}> {self.name} ({self._d} partner functions each)" + + def __repr__(self): + return self.__str__() + + class _Thresholds(Structure): _fields_ = [("zero", c_double), ("geometry", c_double), @@ -281,7 +434,12 @@ def symmetry_species(self): self._symmetry_species = self._s[0:self._d] return self._symmetry_species - + + def __str__(self): + return f"<{self.__class__.__name__}> with {self._d} symmetry species" + + def __repr__(self): + return self.__str__() class _ReturnCode(c_int): SUCCESS = 0 @@ -381,6 +539,24 @@ def init(library_location=None): _lib.msymGetCharacterTable.restype = _ReturnCode _lib.msymGetCharacterTable.argtypes = [_Context, POINTER(POINTER(CharacterTable))] + _lib.msymGetCenterOfMass.restype = _ReturnCode + _lib.msymGetCenterOfMass.argtypes = [_Context, POINTER(c_double)] + + _lib.msymGetRadius.restype = _ReturnCode + _lib.msymGetRadius.argtypes = [_Context, POINTER(c_double)] + + _lib.msymGetSubgroups.restype = _ReturnCode + _lib.msymGetSubgroups.argtypes = [_Context, POINTER(c_int), POINTER(POINTER(Subgroup))] + + _lib.msymSelectSubgroup.restype = _ReturnCode + _lib.msymSelectSubgroup.argtypes = [_Context, POINTER(Subgroup)] + + + + _lib.msymGetEquivalenceSets.restype = _ReturnCode + _lib.msymGetEquivalenceSets.argtypes = [_Context, POINTER(c_int), POINTER(POINTER(EquivalenceSets))] + + if np is None: _SALCsMatrix = c_void_p _SALCsSpecies = POINTER(c_int) @@ -416,12 +592,15 @@ class Context(object): def __init__(self, elements=[], basis_functions=[], point_group=""): if(_lib is None): raise Error("Shared library not loaded") - + self._elements = [] self._basis_functions = [] + self._salc_wf = None self._point_group = None self._subrepresentation_spaces = None self._character_table = None + self._subgroups = None + self._equivalence_sets = None self._ctx = _lib.msymCreateContext() if not self._ctx: raise RuntimeError('Failed to create libmsym context') @@ -512,6 +691,8 @@ def _update_elements(self): self._assert_success(_lib.msymGetElements(self._ctx,byref(csize),byref(celements))) self._elements_array = celements self._elements = celements[0:csize.value] + for i, ele in enumerate(self._elements): + ele.index = i+1 def _update_symmetry_operations(self): if not self._ctx: @@ -569,7 +750,6 @@ def set_thresholds(self, **kwargs): setattr(self._thresholds, key, kwargs[key]) self._assert_success(_lib.msymSetThresholds(self._ctx, pointer(self._thresholds))) - @property def elements(self): return self._elements @@ -605,7 +785,130 @@ def find_symmetry(self): self._update_point_group() self._update_symmetry_operations() return self._point_group + + def _update_equivalence_sets(self): + num = c_int(0) + sets_ptr = POINTER(EquivalenceSets)() + self._assert_success(_lib.msymGetEquivalenceSets(self._ctx, byref(num), byref(sets_ptr))) + sets = sets_ptr[0:num.value] + for s in sets: + s._update_elements(self._elements) + self._equivalence_sets = sets + + def _update_subgroups(self, group_sel = None): + num_subgroups = c_int(0) + subgroups = POINTER(Subgroup)() + self._assert_success(_lib.msymGetSubgroups(self._ctx, byref(num_subgroups), byref(subgroups))) + sgs = subgroups[0:num_subgroups.value] + for s in sgs: + s._update_symmetry_operations(self._symmetry_operations) + self._subgroups = sgs + return self._select_subgroup(group_sel) + + def _select_subgroup(self, group_sel): + if group_sel is None: + return 0 + if isinstance(group_sel, int): + if group_sel == 0: + return 0 + else: + subgroup = self._subgroups[group_sel-1] + elif isinstance(group_sel, str): + for i, subgroup in enumerate(self._subgroups): + if subgroup.name == group_sel: + group_sel = i + 1 + break + self._assert_success(_lib.msymSelectSubgroup(self._ctx, byref(subgroup))) + self._update_point_group() + self._update_symmetry_operations() + return group_sel + + + + def _update_SALC_wavefunctions(self): + salcs, species, partner_functions = self.salcs + self._salc_wf = [] + for i, spec in enumerate(self.character_table.symmetry_species): + sel = species == i + data = salcs[sel, :] + partners = [partner_functions[i] for i, v in enumerate(sel) if v == True] + self._salc_wf.append(SALC_wf(data, partners, spec)) + + def _update_basis_function_element(self): + for bf in self.basis_functions: + element = bf.element + bf.element = self.elements[element.index-1] + def find_misc(self, group_sel = None, to_print = True): + if not self._ctx: + raise RuntimeError + result = (c_double * 3)() + self._assert_success(_lib.msymGetCenterOfMass(self._ctx, result)) + self.COM = np.ctypeslib.as_array(result) + + radius = c_double() + self._assert_success(_lib.msymGetRadius(self._ctx, byref(radius))) + self.radius = float(radius.value) + subgroup_index = self._update_subgroups(group_sel) + self.symmetrize_elements() + + self._update_equivalence_sets() + self._update_basis_function_element() + + if to_print: + print(f"Molecule COM: {self.COM} and radius: {self.radius:.6f}") + print(f"Found point group [0] {self._point_group} with {len(self.subgroups)} subgroups:") + print(f"\t[{subgroup_index}] {self._point_group}\n\t-------") + n = len(self.subgroups) + m = n // 4 + for i in range(m): + for j in range(4): + k = j * m + i + if k < n: + print(f"\t[{k+1:3}] {self.subgroups[k].name:8}", end = "") + print("") + + print(f"There are {len(self.equivalence_sets)} symmetrically equivalent sets:"); + for i, s in enumerate(self.equivalence_sets): + print(f"\t[{i}] {s.length} {s.elements[0].name} atoms:") + print(f"\nThere are {len(self.symmetry_operations)} symmetry operations:"); + for i, s in enumerate(self.symmetry_operations): + print(f"\t[{i}] {s}") + + symmetry_species = self.character_table.symmetry_species + print(f"\nGenerated SALCs from {len(self.basis_functions)} basis functions of {len(symmetry_species)} symmetry species.") + for i, spec in enumerate(symmetry_species): + print(f"\t[{i}] {spec.name} " + f"({self.subrepresentation_spaces[i]._salc_length} SALCs with {spec._d} partner functions each)") + + for i, spec in enumerate(symmetry_species): + print(f"\n---- {spec.name}({i}) ----") + for j, salc in enumerate(self.subrepresentation_spaces[i].salcs): + bf = salc.basis_functions[0] + element = bf.element + eset = bf.element.equiv + eset_index = self.equivalence_sets.index(eset) + print(f"\t[{j:3d}] {element.name:>3}({eset_index}) {eset.length} x {bf.l*2+1}{atomic_orbital_symbols[bf.l]} = {salc._fl} orbitals") + + calc_cbr = 0 + n_wf = 0 + for i, wf in enumerate(self.salc_wf): + print(f"\n({i:d}) ", end= "") + wf.print_coeff(self.basis_functions) + n = wf.coeff.shape[0] + if n != 0: + n_wf += 1 + calc_cbr += n**3 + n_bf_cbr = len(self.basis_functions)**3 + print(f"non-empty SALCs: {n_wf}, Cost: {calc_cbr} vs {n_bf_cbr}, Acc.: {n_bf_cbr/calc_cbr:.6f}") + + + def _find_equivalent_set(self, element): + for equivalent_set in self.equivalence_sets: + if element in equivalent_set.elements: + return equivalent_set + return None + def symmetrize_elements(self): if not self._ctx: raise RuntimeError @@ -617,7 +920,7 @@ def symmetrize_elements(self): def align_axes(self): if not self._ctx: raise RuntimeError - cerror = c_double(0) + #cerror = c_double(0) self._assert_success(_lib.msymAlignAxes(self._ctx)) self._update_elements() return self._elements @@ -637,6 +940,20 @@ def character_table(self): return self._character_table + @property + def equivalence_sets(self): + if self._equivalence_sets is None: + self._update_equivalence_sets() + + return self._equivalence_sets + + @property + def subgroups(self): + if self._subgroups is None: + self._update_subgroups() + + return self._subgroups + @property def salcs(self): if self._salcs is None: @@ -644,6 +961,14 @@ def salcs(self): return self._salcs + @property + def salc_wf(self): + if self._salc_wf is None: + self._update_SALC_wavefunctions() + + return self._salc_wf + + def symmetrize_wavefunctions(self,m): if not self._ctx: raise RuntimeError diff --git a/src/elements.c b/src/elements.c index 0d591d2..f23e356 100644 --- a/src/elements.c +++ b/src/elements.c @@ -15,6 +15,11 @@ #include "debug.h" +#define EXTRA_NUM 1024 + +static char extra_elements[EXTRA_NUM][4]; +static int n_extra = 0; + const struct _periodic_table { int n; char *name; @@ -204,11 +209,32 @@ msym_error_t complementElementData(msym_element_t *element){ } if(fi == fil){ + for (fi =0; fi < n_extra; fi++) { + if (0 == strncmp(element->name, extra_elements[fi], sizeof(extra_elements[fi]))) { + int n = fi + 1000; + if(element->m <= 0.0) element->m = (double) n; + if(element->n <= 0) element->n = n; + } + } + if (fi == n_extra) { + if (n_extra == EXTRA_NUM) { + msymSetErrorDetails("Cannot set extra elements more than %d", EXTRA_NUM); + ret = MSYM_INVALID_ELEMENTS; + goto err; + } + strncpy(extra_elements[n_extra], element->name, sizeof(extra_elements[n_extra])); + int n = n_extra + 1000; + if(element->m <= 0.0) element->m = (double) n; + if(element->n <= 0) element->n = n; + n_extra++; + } + /* char buf[sizeof(element->name)]; snprintf(buf, sizeof(element->name), "%s",element->name); //in case someone forgets to null terminate msymSetErrorDetails("Unknown element with name %s",buf); ret = MSYM_INVALID_ELEMENTS; goto err; + */ } } else if(element->m > 0.0 && (strl <= 0 || element->n <= 0)){ int fim = 0, fil = sizeof(periodic_table)/sizeof(periodic_table[0]);