Skip to content

Commit 636898f

Browse files
Gareth Aneurin TribelloGareth Aneurin Tribello
Gareth Aneurin Tribello
authored and
Gareth Aneurin Tribello
committed
Reduced the amount of memory that is used in Q6 so you can now run 4 threads
1 parent 8628615 commit 636898f

File tree

3 files changed

+42
-12
lines changed

3 files changed

+42
-12
lines changed

src/core/ActionWithVector.cpp

+13-9
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ bool ActionWithVector::checkChainForNonScalarForces() const {
332332
return false;
333333
}
334334

335+
void ActionWithVector::getNumberOfForceDerivatives( unsigned& nforces, unsigned& nderiv ) const {
336+
nforces=0; unsigned nargs = getNumberOfArguments(); int nmasks = getNumberOfMasks();
337+
if( nargs>=nmasks && nmasks>0 ) nargs = nargs - nmasks;
338+
if( getNumberOfAtoms()>0 ) nforces += 3*getNumberOfAtoms() + 9;
339+
for(unsigned i=0; i<nargs; ++i) {
340+
nforces += getPntrToArgument(i)->getNumberOfStoredValues();
341+
}
342+
nderiv = nforces;
343+
}
344+
335345
bool ActionWithVector::checkForForces() {
336346
if( getPntrToComponent(0)->getRank()==0 ) return ActionWithValue::checkForForces();
337347

@@ -357,13 +367,8 @@ bool ActionWithVector::checkForForces() {
357367
if( omp_forces.size()!=nt ) omp_forces.resize(nt);
358368

359369
// Recover the number of derivatives we require (this should be equal to the number of forces)
360-
unsigned nderiv=0, nargs = getNumberOfArguments(); int nmasks = getNumberOfMasks();
361-
if( nargs>=nmasks && nmasks>0 ) nargs = nargs - nmasks;
362-
if( getNumberOfAtoms()>0 ) nderiv += 3*getNumberOfAtoms() + 9;
363-
for(unsigned i=0; i<nargs; ++i) {
364-
nderiv += getPntrToArgument(i)->getNumberOfStoredValues();
365-
}
366-
if( forcesForApply.size()!=nderiv ) forcesForApply.resize( nderiv );
370+
unsigned nderiv, nforces; getNumberOfForceDerivatives( nforces, nderiv );
371+
if( forcesForApply.size()!=nforces ) forcesForApply.resize( nforces );
367372
// Clear force buffer
368373
forcesForApply.assign( forcesForApply.size(), 0.0 );
369374

@@ -408,8 +413,7 @@ bool ActionWithVector::checkForTaskForce( const unsigned& itask, const Value* my
408413
}
409414

410415
void ActionWithVector::gatherForcesOnStoredValue( const unsigned& ival, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
411-
const Value* myval = getConstPntrToComponent(ival);
412-
double fforce = myval->getForce(itask);
416+
const Value* myval = getConstPntrToComponent(ival); double fforce = myval->getForce(itask);
413417
for(unsigned j=0; j<myvals.getNumberActive(ival); ++j) {
414418
unsigned jder=myvals.getActiveIndex(ival, j); plumed_dbg_assert( jder<forces.size() );
415419
forces[jder] += fforce*myvals.getDerivative( ival, jder );

src/core/ActionWithVector.h

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ class ActionWithVector:
110110
virtual void gatherForces( const unsigned& i, const MultiValue& myvals, std::vector<double>& forces ) const ;
111111
/// This is to transfer data from the buffer to the final value
112112
void finishComputations( const std::vector<double>& buf );
113+
/// Get the number of forces to use
114+
virtual void getNumberOfForceDerivatives( unsigned& nforces, unsigned& nderiv ) const ;
113115
/// Apply the forces on this data
114116
virtual void apply();
115117
};

src/matrixtools/MatrixTimesVector.cpp

+27-3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ class MatrixTimesVector : public ActionWithVector {
4646
void calculate() override ;
4747
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
4848
int checkTaskIsActive( const unsigned& itask ) const override ;
49+
void getNumberOfForceDerivatives( unsigned& nforces, unsigned& nderiv ) const override ;
50+
void gatherForces( const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override ;
4951
};
5052

5153
PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
@@ -173,7 +175,7 @@ int MatrixTimesVector::checkTaskIsActive( const unsigned& itask ) const {
173175

174176
void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const {
175177
if( sumrows ) {
176-
unsigned base=0, n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);
178+
unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);
177179
for(unsigned i=0; i<n; ++i) {
178180
Value* mymat = getPntrToArgument(i);
179181
unsigned ncol = mymat->getNumberOfColumns();
@@ -184,11 +186,10 @@ void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myv
184186
// And the derivatives
185187
if( doNotCalculateDerivatives() ) continue;
186188

187-
unsigned dloc = base + task_index*ncol;
189+
unsigned dloc = task_index*ncol;
188190
for(unsigned j=0; j<nmat; ++j) {
189191
myvals.addDerivative( i, dloc + j, 1.0 ); myvals.updateIndex( i, dloc + j );
190192
}
191-
base += mymat->getNumberOfStoredValues();
192193
}
193194
} else if( getPntrToArgument(1)->getRank()==1 ) {
194195
Value* mymat = getPntrToArgument(0);
@@ -238,5 +239,28 @@ void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myv
238239
}
239240
}
240241

242+
void MatrixTimesVector::getNumberOfForceDerivatives( unsigned& nforces, unsigned& nderiv ) const {
243+
ActionWithVector::getNumberOfForceDerivatives( nforces, nderiv );
244+
if( sumrows ) nderiv = getPntrToArgument(0)->getNumberOfStoredValues() + getPntrToArgument(getNumberOfArguments()-1)->getNumberOfStoredValues();
245+
}
246+
247+
void MatrixTimesVector::gatherForces( const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
248+
if( !sumrows ) { ActionWithVector::gatherForces( itask, myvals, forces ); return; }
249+
if( checkComponentsForForce() ) {
250+
unsigned base = 0;
251+
for(unsigned ival=0; ival<getNumberOfComponents(); ++ival) {
252+
const Value* myval=getConstPntrToComponent(ival);
253+
if( myval->forcesWereAdded() ) {
254+
double fforce = myval->getForce(itask);
255+
for(unsigned j=0; j<myvals.getNumberActive(ival); ++j) {
256+
unsigned jder=myvals.getActiveIndex(ival, j); plumed_dbg_assert( jder<forces.size() );
257+
forces[base+jder] += fforce*myvals.getDerivative( ival, jder );
258+
}
259+
}
260+
base += getPntrToArgument(ival)->getNumberOfStoredValues();
261+
}
262+
}
263+
}
264+
241265
}
242266
}

0 commit comments

Comments
 (0)