Skip to content

Commit b9b8969

Browse files
author
Gareth Aneurin Tribello
committed
Now turning off derivatives on forward pass through task loop when storing values as they are only needed on the backward pass to calculate forces
1 parent 2982b55 commit b9b8969

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/core/ActionWithVector.cpp

+15-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ Option interpretEnvString(const char* env,const char* str) {
3535
if(!std::strcmp(str,"no"))return Option::no;
3636
plumed_error()<<"Cannot understand env var "<<env<<"\nPossible values: yes/no\nActual value: "<<str;
3737
}
38-
38+
3939
/// Switch on/off chains of actions using PLUMED environment variable
4040
/// export PLUMED_FORBID_CHAINS=yes # forbid the use of chains in this run
4141
/// export PLUMED_FORBID_CHAINS=no # allow chains to be used in the run
@@ -44,7 +44,7 @@ Option getenvChainForbidden() {
4444
static const char* name="PLUMED_FORBID_CHAINS";
4545
static const auto opt = interpretEnvString(name,std::getenv(name));
4646
return opt;
47-
}
47+
}
4848

4949
void ActionWithVector::registerKeywords( Keywords& keys ) {
5050
Action::registerKeywords( keys );
@@ -60,6 +60,7 @@ ActionWithVector::ActionWithVector(const ActionOptions&ao):
6060
ActionWithValue(ao),
6161
ActionWithArguments(ao),
6262
serial(false),
63+
forwardPass(false),
6364
action_to_do_before(NULL),
6465
action_to_do_after(NULL),
6566
never_reduce_tasks(false),
@@ -451,6 +452,11 @@ std::vector<unsigned>& ActionWithVector::getListOfActiveTasks( ActionWithVector*
451452
return active_tasks;
452453
}
453454

455+
bool ActionWithVector::doNotCalculateDerivatives() const {
456+
if( forwardPass ) return true;
457+
return ActionWithValue::doNotCalculateDerivatives();
458+
}
459+
454460
void ActionWithVector::runAllTasks() {
455461
// Skip this if this is done elsewhere
456462
if( action_to_do_before ) return;
@@ -471,6 +477,12 @@ void ActionWithVector::runAllTasks() {
471477
// Now do all preparations required to run all the tasks
472478
// prepareForTaskLoop();
473479

480+
if( !action_to_do_after ) {
481+
forwardPass=true;
482+
for(unsigned i=0; i<getNumberOfComponents(); ++i) {
483+
if( getConstPntrToComponent(i)->getRank()==0 ) { forwardPass=false; break; }
484+
}
485+
}
474486
// Get the total number of streamed quantities that we need
475487
unsigned nquants=0, nmatrices=0, maxcol=0, nbooks=0;
476488
getNumberOfStreamedQuantities( getLabel(), nquants, nmatrices, maxcol, nbooks );
@@ -509,7 +521,7 @@ void ActionWithVector::runAllTasks() {
509521

510522
// MPI Gather everything
511523
if( !serial && buffer.size()>0 ) gatherProcesses( buffer );
512-
finishComputations( buffer );
524+
finishComputations( buffer ); forwardPass=false;
513525
}
514526

515527
void ActionWithVector::gatherThreads( const unsigned& nt, const unsigned& bufsize, const std::vector<double>& omp_buffer, std::vector<double>& buffer, MultiValue& myvals ) {

src/core/ActionWithVector.h

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class ActionWithVector:
3939
private:
4040
/// Is the calculation to be done in serial
4141
bool serial;
42+
/// Are we in the forward pass through the calculation
43+
bool forwardPass;
4244
/// The buffer that we use (we keep a copy here to avoid resizing)
4345
std::vector<double> buffer;
4446
/// The list of active tasks
@@ -119,6 +121,8 @@ class ActionWithVector:
119121
virtual void prepare() override;
120122
void retrieveAtoms( const bool& force=false ) override;
121123
void calculateNumericalDerivatives(ActionWithValue* av) override;
124+
/// Turn off the calculation of the derivatives during the forward pass through a calculation
125+
bool doNotCalculateDerivatives() const override ;
122126
/// Are we running this command in a chain
123127
bool actionInChain() const ;
124128
/// This is overwritten within ActionWithMatrix and is used to build the chain of just matrix actions

0 commit comments

Comments
 (0)