Skip to content

Commit cd6f553

Browse files
author
Gareth Aneurin Tribello
committed
Added faster version of matrix vector multiply that is used when not employing the chain
1 parent ead56cd commit cd6f553

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

src/core/ActionWithMatrix.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class ActionWithMatrix : public ActionWithVector {
7575
//// This does some setup before we run over the row of the matrix
7676
virtual void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const = 0;
7777
/// Run over one row of the matrix
78-
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
78+
virtual void performTask( const unsigned& task_index, MultiValue& myvals ) const ;
7979
/// Gather a row of the matrix
8080
void gatherStoredValue( const unsigned& valindex, const unsigned& code, const MultiValue& myvals, const unsigned& bufstart, std::vector<double>& buffer ) const override;
8181
/// Gather all the data from the threads

src/matrixtools/MatrixTimesVector.cpp

+69
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class MatrixTimesVector : public ActionWithMatrix {
4646
unsigned getNumberOfColumns() const override { plumed_error(); }
4747
unsigned getNumberOfDerivatives();
4848
void prepare() override ;
49+
void performTask( const unsigned& task_index, MultiValue& myvals ) const override ;
4950
bool isInSubChain( unsigned& nder ) override { nder = arg_deriv_starts[0]; return true; }
5051
void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
5152
void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
@@ -161,6 +162,74 @@ void MatrixTimesVector::prepare() {
161162
std::vector<unsigned> shape(1); shape[0] = getPntrToArgument(0)->getShape()[0]; myval->setShape(shape);
162163
}
163164

165+
void MatrixTimesVector::performTask( const unsigned& task_index, MultiValue& myvals ) const {
166+
if( actionInChain() ) { ActionWithMatrix::performTask( task_index, myvals ); return; }
167+
168+
if( sumrows ) {
169+
unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);
170+
for(unsigned i=0; i<n; ++i) {
171+
Value* mymat = getPntrToArgument(i);
172+
unsigned ncol = mymat->getNumberOfColumns();
173+
unsigned nmat = mymat->getRowLength(task_index);
174+
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false );
175+
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
176+
myvals.setValue( ostrn, val );
177+
178+
// And the derivatives
179+
if( doNotCalculateDerivatives() ) continue;
180+
181+
unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
182+
for(unsigned j=0; j<nmat; ++j) {
183+
myvals.addDerivative( ostrn, dloc + j, 1.0 ); myvals.updateIndex( ostrn, dloc + j );
184+
}
185+
}
186+
} else if( getPntrToArgument(1)->getRank()==1 ) {
187+
Value* mymat = getPntrToArgument(0);
188+
unsigned ncol = mymat->getNumberOfColumns();
189+
unsigned nmat = mymat->getRowLength(task_index);
190+
unsigned dloc = arg_deriv_starts[0] + task_index*ncol;
191+
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
192+
Value* myvec = getPntrToArgument(i+1);
193+
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) );
194+
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
195+
myvals.setValue( ostrn, val );
196+
197+
// And the derivatives
198+
if( doNotCalculateDerivatives() ) continue;
199+
200+
for(unsigned j=0; j<nmat; ++j) {
201+
unsigned kind = mymat->getRowIndex( task_index, j );
202+
double vecval = myvec->get( kind );
203+
double matval = mymat->get( task_index*ncol + j, false );
204+
myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j );
205+
myvals.addDerivative( ostrn, arg_deriv_starts[i+1] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[i+1] + kind );
206+
}
207+
}
208+
} else {
209+
unsigned n=getNumberOfArguments()-1; Value* myvec = getPntrToArgument(n);
210+
for(unsigned i=0; i<n; ++i) {
211+
Value* mymat = getPntrToArgument(i);
212+
unsigned ncol = mymat->getNumberOfColumns();
213+
unsigned nmat = mymat->getRowLength(task_index);
214+
double val=0; for(unsigned j=0; j<nmat; ++j) val += mymat->get( task_index*ncol + j, false )*myvec->get( mymat->getRowIndex( task_index, j ) );
215+
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
216+
myvals.setValue( ostrn, val );
217+
218+
// And the derivatives
219+
if( doNotCalculateDerivatives() ) continue;
220+
221+
unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
222+
for(unsigned j=0; j<nmat; ++j) {
223+
unsigned kind = mymat->getRowIndex( task_index, j );
224+
double vecval = myvec->get( kind );
225+
double matval = mymat->get( task_index*ncol + j, false );
226+
myvals.addDerivative( ostrn, dloc + j, vecval ); myvals.updateIndex( ostrn, dloc + j );
227+
myvals.addDerivative( ostrn, arg_deriv_starts[n] + kind, matval ); myvals.updateIndex( ostrn, arg_deriv_starts[n] + kind );
228+
}
229+
}
230+
}
231+
}
232+
164233
void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
165234
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
166235
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );

0 commit comments

Comments
 (0)