Skip to content

Commit 1bfaec9

Browse files
phanimotamarri authored and dsambit committed
Merged in nonLocalOperatorBugFix (pull request #593)
Bug fix in nonLocalOperator when nonLocalEntries = 0. Approved-by: Sambit Das
2 parents 538a294 + 5c74937 commit 1bfaec9

File tree

1 file changed

+141
-127
lines changed

1 file changed

+141
-127
lines changed

src/atom/AtomicCenteredNonLocalOperator.t.cc

+141-127
Original file line numberDiff line numberDiff line change
@@ -1642,88 +1642,96 @@ namespace dftfe
16421642
& sphericalFunctionKetTimesVectorParFlattened,
16431643
const bool flagCopyResultsToMatrix)
16441644
{
1645-
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
1645+
if (d_totalNonLocalEntries > 0)
16461646
{
1647-
const std::vector<unsigned int> &atomicNumber =
1648-
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
1649-
const std::vector<unsigned int> atomIdsInProc =
1650-
d_atomCenteredSphericalFunctionContainer
1651-
->getAtomIdsInCurrentProcess();
1652-
if (couplingtype == CouplingStructure::diagonal)
1647+
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
16531648
{
1654-
unsigned int startIndex = 0;
1655-
const unsigned int inc = 1;
1656-
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1649+
const std::vector<unsigned int> &atomicNumber =
1650+
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
1651+
const std::vector<unsigned int> atomIdsInProc =
1652+
d_atomCenteredSphericalFunctionContainer
1653+
->getAtomIdsInCurrentProcess();
1654+
if (couplingtype == CouplingStructure::diagonal)
16571655
{
1658-
const unsigned int atomId = atomIdsInProc[iAtom];
1659-
const unsigned int Znum = atomicNumber[atomId];
1660-
const unsigned int numberSphericalFunctions =
1661-
d_atomCenteredSphericalFunctionContainer
1662-
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
1656+
unsigned int startIndex = 0;
1657+
const unsigned int inc = 1;
1658+
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1659+
{
1660+
const unsigned int atomId = atomIdsInProc[iAtom];
1661+
const unsigned int Znum = atomicNumber[atomId];
1662+
const unsigned int numberSphericalFunctions =
1663+
d_atomCenteredSphericalFunctionContainer
1664+
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
16631665

16641666

1665-
for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
1666-
alpha++)
1667-
{
1668-
ValueType nonlocalConstantV = couplingMatrix[startIndex++];
1669-
const unsigned int localId =
1670-
sphericalFunctionKetTimesVectorParFlattened
1671-
.getMPIPatternP2P()
1672-
->globalToLocal(
1673-
d_sphericalFunctionIdsNumberingMapCurrentProcess
1674-
.find(std::make_pair(atomId, alpha))
1675-
->second);
1676-
if (flagCopyResultsToMatrix)
1677-
{
1678-
std::transform(
1679-
sphericalFunctionKetTimesVectorParFlattened.begin() +
1680-
localId * d_numberWaveFunctions,
1681-
sphericalFunctionKetTimesVectorParFlattened.begin() +
1682-
localId * d_numberWaveFunctions +
1683-
d_numberWaveFunctions,
1684-
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
1685-
d_numberWaveFunctions * alpha,
1686-
[&nonlocalConstantV](auto &a) {
1687-
return nonlocalConstantV * a;
1688-
});
1689-
}
1690-
else
1667+
for (unsigned int alpha = 0;
1668+
alpha < numberSphericalFunctions;
1669+
alpha++)
16911670
{
1692-
d_BLASWrapperPtr->xscal(
1693-
sphericalFunctionKetTimesVectorParFlattened.begin() +
1694-
localId * d_numberWaveFunctions,
1695-
nonlocalConstantV,
1696-
d_numberWaveFunctions);
1671+
ValueType nonlocalConstantV =
1672+
couplingMatrix[startIndex++];
1673+
const unsigned int localId =
1674+
sphericalFunctionKetTimesVectorParFlattened
1675+
.getMPIPatternP2P()
1676+
->globalToLocal(
1677+
d_sphericalFunctionIdsNumberingMapCurrentProcess
1678+
.find(std::make_pair(atomId, alpha))
1679+
->second);
1680+
if (flagCopyResultsToMatrix)
1681+
{
1682+
std::transform(
1683+
sphericalFunctionKetTimesVectorParFlattened
1684+
.begin() +
1685+
localId * d_numberWaveFunctions,
1686+
sphericalFunctionKetTimesVectorParFlattened
1687+
.begin() +
1688+
localId * d_numberWaveFunctions +
1689+
d_numberWaveFunctions,
1690+
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
1691+
d_numberWaveFunctions * alpha,
1692+
[&nonlocalConstantV](auto &a) {
1693+
return nonlocalConstantV * a;
1694+
});
1695+
}
1696+
else
1697+
{
1698+
d_BLASWrapperPtr->xscal(
1699+
sphericalFunctionKetTimesVectorParFlattened
1700+
.begin() +
1701+
localId * d_numberWaveFunctions,
1702+
nonlocalConstantV,
1703+
d_numberWaveFunctions);
1704+
}
16971705
}
16981706
}
16991707
}
17001708
}
1701-
}
17021709
#if defined(DFTFE_WITH_DEVICE)
1703-
else
1704-
{
1705-
if (couplingtype == CouplingStructure::diagonal)
1710+
else
17061711
{
1707-
d_BLASWrapperPtr->stridedBlockScale(
1708-
d_numberWaveFunctions,
1709-
d_totalNonLocalEntries,
1710-
ValueType(1.0),
1711-
couplingMatrix.begin(),
1712-
sphericalFunctionKetTimesVectorParFlattened.begin());
1713-
}
1712+
if (couplingtype == CouplingStructure::diagonal)
1713+
{
1714+
d_BLASWrapperPtr->stridedBlockScale(
1715+
d_numberWaveFunctions,
1716+
d_totalNonLocalEntries,
1717+
ValueType(1.0),
1718+
couplingMatrix.begin(),
1719+
sphericalFunctionKetTimesVectorParFlattened.begin());
1720+
}
17141721

1715-
if (flagCopyResultsToMatrix)
1716-
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1717-
copyFromParallelNonLocalVecToAllCellsVec(
1718-
d_numberWaveFunctions,
1719-
d_totalNonlocalElems,
1720-
d_maxSingleAtomContribution,
1721-
sphericalFunctionKetTimesVectorParFlattened.begin(),
1722-
d_sphericalFnTimesVectorAllCellsDevice.begin(),
1723-
d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
1724-
.begin());
1725-
}
1722+
if (flagCopyResultsToMatrix)
1723+
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1724+
copyFromParallelNonLocalVecToAllCellsVec(
1725+
d_numberWaveFunctions,
1726+
d_totalNonlocalElems,
1727+
d_maxSingleAtomContribution,
1728+
sphericalFunctionKetTimesVectorParFlattened.begin(),
1729+
d_sphericalFnTimesVectorAllCellsDevice.begin(),
1730+
d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
1731+
.begin());
1732+
}
17261733
#endif
1734+
}
17271735
}
17281736

17291737
template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
@@ -1734,75 +1742,81 @@ namespace dftfe
17341742
& sphericalFunctionKetTimesVectorParFlattened,
17351743
const bool skipComm)
17361744
{
1737-
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
1745+
if (d_totalNonLocalEntries > 0)
17381746
{
1739-
const std::vector<unsigned int> atomIdsInProc =
1740-
d_atomCenteredSphericalFunctionContainer
1741-
->getAtomIdsInCurrentProcess();
1742-
const std::vector<unsigned int> &atomicNumber =
1743-
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
1744-
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1747+
if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
17451748
{
1746-
const unsigned int atomId = atomIdsInProc[iAtom];
1747-
unsigned int Znum = atomicNumber[atomId];
1748-
const unsigned int numberSphericalFunctions =
1749+
const std::vector<unsigned int> atomIdsInProc =
17491750
d_atomCenteredSphericalFunctionContainer
1750-
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
1751-
for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
1752-
alpha++)
1751+
->getAtomIdsInCurrentProcess();
1752+
const std::vector<unsigned int> &atomicNumber =
1753+
d_atomCenteredSphericalFunctionContainer->getAtomicNumbers();
1754+
for (int iAtom = 0; iAtom < d_totalAtomsInCurrentProc; iAtom++)
17531755
{
1754-
const unsigned int id =
1755-
d_sphericalFunctionIdsNumberingMapCurrentProcess
1756-
.find(std::make_pair(atomId, alpha))
1757-
->second;
1758-
std::memcpy(sphericalFunctionKetTimesVectorParFlattened.data() +
1759-
sphericalFunctionKetTimesVectorParFlattened
1760-
.getMPIPatternP2P()
1761-
->globalToLocal(id) *
1762-
d_numberWaveFunctions,
1763-
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
1764-
d_numberWaveFunctions * alpha,
1765-
d_numberWaveFunctions * sizeof(ValueType));
1766-
1767-
1768-
// d_BLASWrapperPtr->xcopy(
1769-
// d_numberWaveFunctions,
1770-
// &d_sphericalFnTimesWavefunMatrix[atomId]
1771-
// [d_numberWaveFunctions *
1772-
// alpha],
1773-
// inc,
1774-
// sphericalFunctionKetTimesVectorParFlattened.data() +
1775-
// sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
1776-
// ->globalToLocal(id) *d_numberWaveFunctions,
1777-
// inc);
1756+
const unsigned int atomId = atomIdsInProc[iAtom];
1757+
unsigned int Znum = atomicNumber[atomId];
1758+
const unsigned int numberSphericalFunctions =
1759+
d_atomCenteredSphericalFunctionContainer
1760+
->getTotalNumberOfSphericalFunctionsPerAtom(Znum);
1761+
for (unsigned int alpha = 0; alpha < numberSphericalFunctions;
1762+
alpha++)
1763+
{
1764+
const unsigned int id =
1765+
d_sphericalFunctionIdsNumberingMapCurrentProcess
1766+
.find(std::make_pair(atomId, alpha))
1767+
->second;
1768+
std::memcpy(
1769+
sphericalFunctionKetTimesVectorParFlattened.data() +
1770+
sphericalFunctionKetTimesVectorParFlattened
1771+
.getMPIPatternP2P()
1772+
->globalToLocal(id) *
1773+
d_numberWaveFunctions,
1774+
d_sphericalFnTimesWavefunMatrix[atomId].begin() +
1775+
d_numberWaveFunctions * alpha,
1776+
d_numberWaveFunctions * sizeof(ValueType));
1777+
1778+
1779+
// d_BLASWrapperPtr->xcopy(
1780+
// d_numberWaveFunctions,
1781+
// &d_sphericalFnTimesWavefunMatrix[atomId]
1782+
// [d_numberWaveFunctions *
1783+
// alpha],
1784+
// inc,
1785+
// sphericalFunctionKetTimesVectorParFlattened.data() +
1786+
// sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
1787+
// ->globalToLocal(id) *d_numberWaveFunctions,
1788+
// inc);
1789+
}
1790+
}
1791+
if (!skipComm)
1792+
{
1793+
sphericalFunctionKetTimesVectorParFlattened
1794+
.accumulateAddLocallyOwned(1);
1795+
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(
1796+
1);
17781797
}
17791798
}
1780-
if (!skipComm)
1781-
{
1782-
sphericalFunctionKetTimesVectorParFlattened
1783-
.accumulateAddLocallyOwned(1);
1784-
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1);
1785-
}
1786-
}
17871799
#if defined(DFTFE_WITH_DEVICE)
1788-
else
1789-
{
1790-
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1791-
copyToDealiiParallelNonLocalVec(
1792-
d_numberWaveFunctions,
1793-
d_totalNonLocalEntries,
1794-
d_sphericalFnTimesWavefunctionMatrix.begin(),
1795-
sphericalFunctionKetTimesVectorParFlattened.begin(),
1796-
d_sphericalFnIdsParallelNumberingMapDevice.begin());
1797-
1798-
if (!skipComm)
1800+
else
17991801
{
1800-
sphericalFunctionKetTimesVectorParFlattened
1801-
.accumulateAddLocallyOwned(1);
1802-
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(1);
1802+
dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1803+
copyToDealiiParallelNonLocalVec(
1804+
d_numberWaveFunctions,
1805+
d_totalNonLocalEntries,
1806+
d_sphericalFnTimesWavefunctionMatrix.begin(),
1807+
sphericalFunctionKetTimesVectorParFlattened.begin(),
1808+
d_sphericalFnIdsParallelNumberingMapDevice.begin());
1809+
1810+
if (!skipComm)
1811+
{
1812+
sphericalFunctionKetTimesVectorParFlattened
1813+
.accumulateAddLocallyOwned(1);
1814+
sphericalFunctionKetTimesVectorParFlattened.updateGhostValues(
1815+
1);
1816+
}
18031817
}
1804-
}
18051818
#endif
1819+
}
18061820
}
18071821
template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
18081822
void

0 commit comments

Comments
 (0)