Skip to content

Commit 2115528

Browse files
Gareth Aneurin TribelloGareth Aneurin Tribello
Gareth Aneurin Tribello
authored and
Gareth Aneurin Tribello
committed
Optimisations of derivatives by back propegation
1 parent 033d957 commit 2115528

9 files changed

+30
-27
lines changed

src/adjmat/AdjacencyMatrixBase.cpp

+10-14
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,10 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
308308
// Update dynamic list indices for virial
309309
unsigned base = 3*getNumberOfAtoms(); for(unsigned j=0; j<9; ++j) myvals.updateIndex( w_ind, base+j );
310310
// And the indices for the derivatives of the row of the matrix
311-
if( chainContinuesAfterThisAction() ) {
312-
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
313-
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
314-
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
315-
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
316-
}
311+
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
312+
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
313+
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
314+
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
317315
}
318316

319317
// Calculate the components if we need them
@@ -354,20 +352,18 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
354352
myvals.addDerivative( z_index, base+1, 0 ); myvals.addDerivative( z_index, base+4, 0 ); myvals.addDerivative( z_index, base+7, 0 );
355353
myvals.addDerivative( z_index, base+2, -atom[0] ); myvals.addDerivative( z_index, base+5, -atom[1] ); myvals.addDerivative( z_index, base+8, -atom[2] );
356354
for(unsigned k=0; k<9; ++k) { myvals.updateIndex( x_index, base+k ); myvals.updateIndex( y_index, base+k ); myvals.updateIndex( z_index, base+k ); }
357-
if( chainContinuesAfterThisAction() ) {
358-
for(unsigned k=1; k<4; ++k) {
359-
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
360-
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
361-
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
362-
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
363-
}
355+
for(unsigned k=1; k<4; ++k) {
356+
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
357+
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
358+
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
359+
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
364360
}
365361
}
366362
}
367363
}
368364

369365
void AdjacencyMatrixBase::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
370-
if( doNotCalculateDerivatives() || !chainContinuesAfterThisAction() ) return;
366+
if( doNotCalculateDerivatives() ) return;
371367

372368
for(int k=0; k<getNumberOfComponents(); ++k) {
373369
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );

src/adjmat/TorsionsMatrix.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ void TorsionsMatrix::performTask( const std::string& controller, const unsigned&
138138
}
139139

140140
void TorsionsMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
141-
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
141+
if( doNotCalculateDerivatives() ) return ;
142142

143143
unsigned mat1s = 3*ival, ss = getPntrToArgument(1)->getShape()[1];
144144
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );

src/core/ActionWithMatrix.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ bool ActionWithMatrix::checkForTaskForce( const unsigned& itask, const Value* my
226226

227227
void ActionWithMatrix::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
228228
if( myval->getRank()==1 ) { ActionWithVector::gatherForcesOnStoredValue( myval, itask, myvals, forces ); return; }
229-
unsigned matind = myval->getPositionInMatrixStash();
230-
for(unsigned j=0; j<forces.size(); ++j) forces[j] += myvals.getStashedMatrixForce( matind, j );
229+
unsigned matind = myval->getPositionInMatrixStash(); const std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( matind ) );
230+
for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(matind); ++i) { unsigned kind = mat_indices[i]; forces[kind] += myvals.getStashedMatrixForce( matind, kind ); }
231231
}
232232

233233
void ActionWithMatrix::clearMatrixElements( MultiValue& myvals ) const {

src/core/ActionWithMatrix.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class ActionWithMatrix : public ActionWithVector {
9393
/// Check if there are forces we need to account for on this task
9494
bool checkForTaskForce( const unsigned& itask, const Value* myval ) const override ;
9595
/// This gathers the force on a particular value
96-
void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override;
96+
virtual void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const;
9797
};
9898

9999
inline

src/core/ActionWithVector.h

-6
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ class ActionWithVector:
125125
bool doNotCalculateDerivatives() const override ;
126126
/// Are we running this command in a chain
127127
bool actionInChain() const ;
128-
bool chainContinuesAfterThisAction() const ;
129128
/// This is overwritten within ActionWithMatrix and is used to build the chain of just matrix actions
130129
virtual void finishChainBuild( ActionWithVector* act );
131130
/// Check if there are any stored values in arguments
@@ -186,11 +185,6 @@ bool ActionWithVector::actionInChain() const {
186185
return (action_to_do_before!=NULL);
187186
}
188187

189-
inline
190-
bool ActionWithVector::chainContinuesAfterThisAction() const {
191-
return (action_to_do_after!=NULL);
192-
}
193-
194188
inline
195189
bool ActionWithVector::runInSerial() const {
196190
return serial;

src/matrixtools/MatrixTimesMatrix.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ void MatrixTimesMatrix::performTask( const std::string& controller, const unsign
149149
}
150150

151151
void MatrixTimesMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
152-
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
152+
if( doNotCalculateDerivatives() ) return ;
153153

154154
unsigned mat1s = ival*getPntrToArgument(0)->getShape()[1];
155155
unsigned nmult = getPntrToArgument(0)->getShape()[1], ss = getPntrToArgument(1)->getShape()[1];

src/matrixtools/OuterProduct.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,13 @@ void OuterProduct::performTask( const std::string& controller, const unsigned& i
146146
addDerivativeOnVectorArgument( stored_vector1, 0, 0, index1, function.evaluateDeriv( 0, args ), myvals );
147147
addDerivativeOnVectorArgument( stored_vector2, 0, 1, ind2, function.evaluateDeriv( 1, args ), myvals );
148148
}
149-
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
149+
if( doNotCalculateDerivatives() ) return ;
150150
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
151151
myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = arg_deriv_starts[1] + ind2; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
152152
}
153153

154154
void OuterProduct::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
155-
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
155+
if( doNotCalculateDerivatives() ) return ;
156156
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
157157
myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = ival; myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
158158
}

src/tools/MultiValue.h

+6
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class MultiValue {
144144
void setNumberOfMatrixRowDerivatives( const unsigned& nmat, const unsigned& nind );
145145
unsigned getNumberOfMatrixRowDerivatives( const unsigned& nmat ) const ;
146146
std::vector<unsigned>& getMatrixRowDerivativeIndices( const unsigned& nmat );
147+
const std::vector<unsigned>& getMatrixRowDerivativeIndices( const unsigned& nmat ) const ;
147148
/// Stash the forces on the matrix
148149
void addMatrixForce( const unsigned& imat, const unsigned& jind, const double& f );
149150
double getStashedMatrixForce( const unsigned& imat, const unsigned& jind ) const ;
@@ -335,6 +336,11 @@ std::vector<unsigned>& MultiValue::getMatrixRowDerivativeIndices( const unsigned
335336
plumed_dbg_assert( nmat<matrix_row_nderivatives.size() ); return matrix_row_derivative_indices[nmat];
336337
}
337338

339+
inline
340+
const std::vector<unsigned>& MultiValue::getMatrixRowDerivativeIndices( const unsigned& nmat ) const {
341+
plumed_dbg_assert( nmat<matrix_row_nderivatives.size() ); return matrix_row_derivative_indices[nmat];
342+
}
343+
338344
inline
339345
void MultiValue::addMatrixForce( const unsigned& imat, const unsigned& jind, const double& f ) {
340346
matrix_force_stash[imat*nderivatives + jind]+=f;

src/valtools/VStack.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ class VStack : public ActionWithMatrix {
5555
void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
5656
///
5757
void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ;
58+
///
59+
void gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const override ;
5860
};
5961

6062
PLUMED_REGISTER_ACTION(VStack,"VSTACK")
@@ -137,6 +139,11 @@ void VStack::performTask( const std::string& controller, const unsigned& index1,
137139
addDerivativeOnVectorArgument( stored[ind2], 0, ind2, index1, 1.0, myvals );
138140
}
139141

142+
void VStack::gatherForcesOnStoredValue( const Value* myval, const unsigned& itask, const MultiValue& myvals, std::vector<double>& forces ) const {
143+
unsigned matind = myval->getPositionInMatrixStash(); const std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( matind ) );
144+
for(unsigned i=0; i<forces.size(); ++i) forces[i] += myvals.getStashedMatrixForce( matind, i );
145+
}
146+
140147
void VStack::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
141148
if( doNotCalculateDerivatives() || !matrixChainContinues() ) return ;
142149

0 commit comments

Comments
 (0)