Skip to content

Commit 033d957

Browse files
Gareth Aneurin TribelloGareth Aneurin Tribello
Gareth Aneurin Tribello
authored and
Gareth Aneurin Tribello
committed
Added optimisation for matrix times vector to avoid allocating large arrays.
1 parent 94d33a9 commit 033d957

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

src/adjmat/AdjacencyMatrixBase.cpp

+10-10
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,10 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
309309
unsigned base = 3*getNumberOfAtoms(); for(unsigned j=0; j<9; ++j) myvals.updateIndex( w_ind, base+j );
310310
// And the indices for the derivatives of the row of the matrix
311311
if( chainContinuesAfterThisAction() ) {
312-
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
313-
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
314-
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
315-
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
312+
unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
313+
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
314+
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
315+
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
316316
}
317317
}
318318

@@ -355,12 +355,12 @@ void AdjacencyMatrixBase::performTask( const std::string& controller, const unsi
355355
myvals.addDerivative( z_index, base+2, -atom[0] ); myvals.addDerivative( z_index, base+5, -atom[1] ); myvals.addDerivative( z_index, base+8, -atom[2] );
356356
for(unsigned k=0; k<9; ++k) { myvals.updateIndex( x_index, base+k ); myvals.updateIndex( y_index, base+k ); myvals.updateIndex( z_index, base+k ); }
357357
if( chainContinuesAfterThisAction() ) {
358-
for(unsigned k=1; k<4; ++k) {
359-
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
360-
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
361-
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
362-
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
363-
}
358+
for(unsigned k=1; k<4; ++k) {
359+
unsigned nmat = getConstPntrToComponent(k)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
360+
std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
361+
matrix_indices[nmat_ind+0]=3*index2+0; matrix_indices[nmat_ind+1]=3*index2+1; matrix_indices[nmat_ind+2]=3*index2+2;
362+
myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+3 );
363+
}
364364
}
365365
}
366366
}

src/core/ActionWithVector.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class ActionWithVector:
136136
/// This is overridden in ActionWithMatrix
137137
virtual void getAllActionLabelsInMatrixChain( std::vector<std::string>& matchain ) const {}
138138
/// Get the number of derivatives in the stream
139-
void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat );
139+
virtual void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat );
140140
/// Get every the label of every value that is calculated in this chain
141141
void getAllActionLabelsInChain( std::vector<std::string>& mylabels ) const ;
142142
/// We override clearInputForces here to ensure that forces are deleted from all values
@@ -186,7 +186,7 @@ bool ActionWithVector::actionInChain() const {
186186
return (action_to_do_before!=NULL);
187187
}
188188

189-
inline
189+
inline
190190
bool ActionWithVector::chainContinuesAfterThisAction() const {
191191
return (action_to_do_after!=NULL);
192192
}

src/matrixtools/MatrixTimesVector.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ class MatrixTimesVector : public ActionWithMatrix {
4444
explicit MatrixTimesVector(const ActionOptions&);
4545
std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
4646
unsigned getNumberOfColumns() const override { plumed_error(); }
47-
unsigned getNumberOfDerivatives();
47+
void getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) override ;
48+
unsigned getNumberOfDerivatives() override ;
4849
void prepare() override ;
4950
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
5051
bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; }
@@ -162,6 +163,16 @@ void MatrixTimesVector::prepare() {
162163
std::vector<unsigned> shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape);
163164
}
164165

166+
void MatrixTimesVector::getNumberOfStreamedDerivatives( unsigned& nderivatives, Value* stopat ) {
167+
if( actionInChain() ) { ActionWithVector::getNumberOfStreamedDerivatives( nderivatives, stopat ); return; }
168+
169+
nderivatives = 0;
170+
for(unsigned i=0; i<getNumberOfArguments(); ++i) {
171+
arg_deriv_starts[i] = nderivatives;
172+
nderivatives += getPntrToArgument(i)->getNumberOfStoredValues();
173+
}
174+
}
175+
165176
void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const {
166177
if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; }
167178

0 commit comments

Comments
 (0)