@@ -1642,88 +1642,96 @@ namespace dftfe
1642
1642
& sphericalFunctionKetTimesVectorParFlattened,
1643
1643
const bool flagCopyResultsToMatrix)
1644
1644
{
1645
- if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace )
1645
+ if (d_totalNonLocalEntries > 0 )
1646
1646
{
1647
- const std::vector<unsigned int > &atomicNumber =
1648
- d_atomCenteredSphericalFunctionContainer->getAtomicNumbers ();
1649
- const std::vector<unsigned int > atomIdsInProc =
1650
- d_atomCenteredSphericalFunctionContainer
1651
- ->getAtomIdsInCurrentProcess ();
1652
- if (couplingtype == CouplingStructure::diagonal)
1647
+ if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
1653
1648
{
1654
- unsigned int startIndex = 0 ;
1655
- const unsigned int inc = 1 ;
1656
- for (int iAtom = 0 ; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1649
+ const std::vector<unsigned int > &atomicNumber =
1650
+ d_atomCenteredSphericalFunctionContainer->getAtomicNumbers ();
1651
+ const std::vector<unsigned int > atomIdsInProc =
1652
+ d_atomCenteredSphericalFunctionContainer
1653
+ ->getAtomIdsInCurrentProcess ();
1654
+ if (couplingtype == CouplingStructure::diagonal)
1657
1655
{
1658
- const unsigned int atomId = atomIdsInProc[iAtom];
1659
- const unsigned int Znum = atomicNumber[atomId];
1660
- const unsigned int numberSphericalFunctions =
1661
- d_atomCenteredSphericalFunctionContainer
1662
- ->getTotalNumberOfSphericalFunctionsPerAtom (Znum);
1656
+ unsigned int startIndex = 0 ;
1657
+ const unsigned int inc = 1 ;
1658
+ for (int iAtom = 0 ; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1659
+ {
1660
+ const unsigned int atomId = atomIdsInProc[iAtom];
1661
+ const unsigned int Znum = atomicNumber[atomId];
1662
+ const unsigned int numberSphericalFunctions =
1663
+ d_atomCenteredSphericalFunctionContainer
1664
+ ->getTotalNumberOfSphericalFunctionsPerAtom (Znum);
1663
1665
1664
1666
1665
- for (unsigned int alpha = 0 ; alpha < numberSphericalFunctions;
1666
- alpha++)
1667
- {
1668
- ValueType nonlocalConstantV = couplingMatrix[startIndex++];
1669
- const unsigned int localId =
1670
- sphericalFunctionKetTimesVectorParFlattened
1671
- .getMPIPatternP2P ()
1672
- ->globalToLocal (
1673
- d_sphericalFunctionIdsNumberingMapCurrentProcess
1674
- .find (std::make_pair (atomId, alpha))
1675
- ->second );
1676
- if (flagCopyResultsToMatrix)
1677
- {
1678
- std::transform (
1679
- sphericalFunctionKetTimesVectorParFlattened.begin () +
1680
- localId * d_numberWaveFunctions,
1681
- sphericalFunctionKetTimesVectorParFlattened.begin () +
1682
- localId * d_numberWaveFunctions +
1683
- d_numberWaveFunctions,
1684
- d_sphericalFnTimesWavefunMatrix[atomId].begin () +
1685
- d_numberWaveFunctions * alpha,
1686
- [&nonlocalConstantV](auto &a) {
1687
- return nonlocalConstantV * a;
1688
- });
1689
- }
1690
- else
1667
+ for (unsigned int alpha = 0 ;
1668
+ alpha < numberSphericalFunctions;
1669
+ alpha++)
1691
1670
{
1692
- d_BLASWrapperPtr->xscal (
1693
- sphericalFunctionKetTimesVectorParFlattened.begin () +
1694
- localId * d_numberWaveFunctions,
1695
- nonlocalConstantV,
1696
- d_numberWaveFunctions);
1671
+ ValueType nonlocalConstantV =
1672
+ couplingMatrix[startIndex++];
1673
+ const unsigned int localId =
1674
+ sphericalFunctionKetTimesVectorParFlattened
1675
+ .getMPIPatternP2P ()
1676
+ ->globalToLocal (
1677
+ d_sphericalFunctionIdsNumberingMapCurrentProcess
1678
+ .find (std::make_pair (atomId, alpha))
1679
+ ->second );
1680
+ if (flagCopyResultsToMatrix)
1681
+ {
1682
+ std::transform (
1683
+ sphericalFunctionKetTimesVectorParFlattened
1684
+ .begin () +
1685
+ localId * d_numberWaveFunctions,
1686
+ sphericalFunctionKetTimesVectorParFlattened
1687
+ .begin () +
1688
+ localId * d_numberWaveFunctions +
1689
+ d_numberWaveFunctions,
1690
+ d_sphericalFnTimesWavefunMatrix[atomId].begin () +
1691
+ d_numberWaveFunctions * alpha,
1692
+ [&nonlocalConstantV](auto &a) {
1693
+ return nonlocalConstantV * a;
1694
+ });
1695
+ }
1696
+ else
1697
+ {
1698
+ d_BLASWrapperPtr->xscal (
1699
+ sphericalFunctionKetTimesVectorParFlattened
1700
+ .begin () +
1701
+ localId * d_numberWaveFunctions,
1702
+ nonlocalConstantV,
1703
+ d_numberWaveFunctions);
1704
+ }
1697
1705
}
1698
1706
}
1699
1707
}
1700
1708
}
1701
- }
1702
1709
#if defined(DFTFE_WITH_DEVICE)
1703
- else
1704
- {
1705
- if (couplingtype == CouplingStructure::diagonal)
1710
+ else
1706
1711
{
1707
- d_BLASWrapperPtr->stridedBlockScale (
1708
- d_numberWaveFunctions,
1709
- d_totalNonLocalEntries,
1710
- ValueType (1.0 ),
1711
- couplingMatrix.begin (),
1712
- sphericalFunctionKetTimesVectorParFlattened.begin ());
1713
- }
1712
+ if (couplingtype == CouplingStructure::diagonal)
1713
+ {
1714
+ d_BLASWrapperPtr->stridedBlockScale (
1715
+ d_numberWaveFunctions,
1716
+ d_totalNonLocalEntries,
1717
+ ValueType (1.0 ),
1718
+ couplingMatrix.begin (),
1719
+ sphericalFunctionKetTimesVectorParFlattened.begin ());
1720
+ }
1714
1721
1715
- if (flagCopyResultsToMatrix)
1716
- dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1717
- copyFromParallelNonLocalVecToAllCellsVec (
1718
- d_numberWaveFunctions,
1719
- d_totalNonlocalElems,
1720
- d_maxSingleAtomContribution,
1721
- sphericalFunctionKetTimesVectorParFlattened.begin (),
1722
- d_sphericalFnTimesVectorAllCellsDevice.begin (),
1723
- d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
1724
- .begin ());
1725
- }
1722
+ if (flagCopyResultsToMatrix)
1723
+ dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1724
+ copyFromParallelNonLocalVecToAllCellsVec (
1725
+ d_numberWaveFunctions,
1726
+ d_totalNonlocalElems,
1727
+ d_maxSingleAtomContribution,
1728
+ sphericalFunctionKetTimesVectorParFlattened.begin (),
1729
+ d_sphericalFnTimesVectorAllCellsDevice.begin (),
1730
+ d_indexMapFromPaddedNonLocalVecToParallelNonLocalVecDevice
1731
+ .begin ());
1732
+ }
1726
1733
#endif
1734
+ }
1727
1735
}
1728
1736
1729
1737
template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
@@ -1734,75 +1742,81 @@ namespace dftfe
1734
1742
& sphericalFunctionKetTimesVectorParFlattened,
1735
1743
const bool skipComm)
1736
1744
{
1737
- if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace )
1745
+ if (d_totalNonLocalEntries > 0 )
1738
1746
{
1739
- const std::vector<unsigned int > atomIdsInProc =
1740
- d_atomCenteredSphericalFunctionContainer
1741
- ->getAtomIdsInCurrentProcess ();
1742
- const std::vector<unsigned int > &atomicNumber =
1743
- d_atomCenteredSphericalFunctionContainer->getAtomicNumbers ();
1744
- for (int iAtom = 0 ; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1747
+ if constexpr (dftfe::utils::MemorySpace::HOST == memorySpace)
1745
1748
{
1746
- const unsigned int atomId = atomIdsInProc[iAtom];
1747
- unsigned int Znum = atomicNumber[atomId];
1748
- const unsigned int numberSphericalFunctions =
1749
+ const std::vector<unsigned int > atomIdsInProc =
1749
1750
d_atomCenteredSphericalFunctionContainer
1750
- ->getTotalNumberOfSphericalFunctionsPerAtom (Znum);
1751
- for (unsigned int alpha = 0 ; alpha < numberSphericalFunctions;
1752
- alpha++)
1751
+ ->getAtomIdsInCurrentProcess ();
1752
+ const std::vector<unsigned int > &atomicNumber =
1753
+ d_atomCenteredSphericalFunctionContainer->getAtomicNumbers ();
1754
+ for (int iAtom = 0 ; iAtom < d_totalAtomsInCurrentProc; iAtom++)
1753
1755
{
1754
- const unsigned int id =
1755
- d_sphericalFunctionIdsNumberingMapCurrentProcess
1756
- .find (std::make_pair (atomId, alpha))
1757
- ->second ;
1758
- std::memcpy (sphericalFunctionKetTimesVectorParFlattened.data () +
1759
- sphericalFunctionKetTimesVectorParFlattened
1760
- .getMPIPatternP2P ()
1761
- ->globalToLocal (id) *
1762
- d_numberWaveFunctions,
1763
- d_sphericalFnTimesWavefunMatrix[atomId].begin () +
1764
- d_numberWaveFunctions * alpha,
1765
- d_numberWaveFunctions * sizeof (ValueType));
1766
-
1767
-
1768
- // d_BLASWrapperPtr->xcopy(
1769
- // d_numberWaveFunctions,
1770
- // &d_sphericalFnTimesWavefunMatrix[atomId]
1771
- // [d_numberWaveFunctions *
1772
- // alpha],
1773
- // inc,
1774
- // sphericalFunctionKetTimesVectorParFlattened.data() +
1775
- // sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
1776
- // ->globalToLocal(id) *d_numberWaveFunctions,
1777
- // inc);
1756
+ const unsigned int atomId = atomIdsInProc[iAtom];
1757
+ unsigned int Znum = atomicNumber[atomId];
1758
+ const unsigned int numberSphericalFunctions =
1759
+ d_atomCenteredSphericalFunctionContainer
1760
+ ->getTotalNumberOfSphericalFunctionsPerAtom (Znum);
1761
+ for (unsigned int alpha = 0 ; alpha < numberSphericalFunctions;
1762
+ alpha++)
1763
+ {
1764
+ const unsigned int id =
1765
+ d_sphericalFunctionIdsNumberingMapCurrentProcess
1766
+ .find (std::make_pair (atomId, alpha))
1767
+ ->second ;
1768
+ std::memcpy (
1769
+ sphericalFunctionKetTimesVectorParFlattened.data () +
1770
+ sphericalFunctionKetTimesVectorParFlattened
1771
+ .getMPIPatternP2P ()
1772
+ ->globalToLocal (id) *
1773
+ d_numberWaveFunctions,
1774
+ d_sphericalFnTimesWavefunMatrix[atomId].begin () +
1775
+ d_numberWaveFunctions * alpha,
1776
+ d_numberWaveFunctions * sizeof (ValueType));
1777
+
1778
+
1779
+ // d_BLASWrapperPtr->xcopy(
1780
+ // d_numberWaveFunctions,
1781
+ // &d_sphericalFnTimesWavefunMatrix[atomId]
1782
+ // [d_numberWaveFunctions *
1783
+ // alpha],
1784
+ // inc,
1785
+ // sphericalFunctionKetTimesVectorParFlattened.data() +
1786
+ // sphericalFunctionKetTimesVectorParFlattened.getMPIPatternP2P()
1787
+ // ->globalToLocal(id) *d_numberWaveFunctions,
1788
+ // inc);
1789
+ }
1790
+ }
1791
+ if (!skipComm)
1792
+ {
1793
+ sphericalFunctionKetTimesVectorParFlattened
1794
+ .accumulateAddLocallyOwned (1 );
1795
+ sphericalFunctionKetTimesVectorParFlattened.updateGhostValues (
1796
+ 1 );
1778
1797
}
1779
1798
}
1780
- if (!skipComm)
1781
- {
1782
- sphericalFunctionKetTimesVectorParFlattened
1783
- .accumulateAddLocallyOwned (1 );
1784
- sphericalFunctionKetTimesVectorParFlattened.updateGhostValues (1 );
1785
- }
1786
- }
1787
1799
#if defined(DFTFE_WITH_DEVICE)
1788
- else
1789
- {
1790
- dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1791
- copyToDealiiParallelNonLocalVec (
1792
- d_numberWaveFunctions,
1793
- d_totalNonLocalEntries,
1794
- d_sphericalFnTimesWavefunctionMatrix.begin (),
1795
- sphericalFunctionKetTimesVectorParFlattened.begin (),
1796
- d_sphericalFnIdsParallelNumberingMapDevice.begin ());
1797
-
1798
- if (!skipComm)
1800
+ else
1799
1801
{
1800
- sphericalFunctionKetTimesVectorParFlattened
1801
- .accumulateAddLocallyOwned (1 );
1802
- sphericalFunctionKetTimesVectorParFlattened.updateGhostValues (1 );
1802
+ dftfe::AtomicCenteredNonLocalOperatorKernelsDevice::
1803
+ copyToDealiiParallelNonLocalVec (
1804
+ d_numberWaveFunctions,
1805
+ d_totalNonLocalEntries,
1806
+ d_sphericalFnTimesWavefunctionMatrix.begin (),
1807
+ sphericalFunctionKetTimesVectorParFlattened.begin (),
1808
+ d_sphericalFnIdsParallelNumberingMapDevice.begin ());
1809
+
1810
+ if (!skipComm)
1811
+ {
1812
+ sphericalFunctionKetTimesVectorParFlattened
1813
+ .accumulateAddLocallyOwned (1 );
1814
+ sphericalFunctionKetTimesVectorParFlattened.updateGhostValues (
1815
+ 1 );
1816
+ }
1803
1817
}
1804
- }
1805
1818
#endif
1819
+ }
1806
1820
}
1807
1821
template <typename ValueType, dftfe::utils::MemorySpace memorySpace>
1808
1822
void
0 commit comments