@@ -46,6 +46,7 @@ class MatrixTimesVector : public ActionWithMatrix {
46
46
unsigned getNumberOfColumns () const override { plumed_error (); }
47
47
unsigned getNumberOfDerivatives ();
48
48
void prepare () override ;
49
+ void performTask ( const unsigned & task_index, MultiValue& myvals ) const override ;
49
50
bool isInSubChain ( unsigned & nder ) override { nder = arg_deriv_starts[0 ]; return true ; }
50
51
void setupForTask ( const unsigned & task_index, std::vector<unsigned >& indices, MultiValue& myvals ) const ;
51
52
void performTask ( const std::string& controller, const unsigned & index1, const unsigned & index2, MultiValue& myvals ) const override ;
@@ -161,6 +162,74 @@ void MatrixTimesVector::prepare() {
161
162
std::vector<unsigned > shape (1 ); shape[0 ] = getPntrToArgument (0 )->getShape ()[0 ]; myval->setShape (shape);
162
163
}
163
164
165
+ void MatrixTimesVector::performTask ( const unsigned & task_index, MultiValue& myvals ) const {
166
+ if ( actionInChain () ) { ActionWithMatrix::performTask ( task_index, myvals ); return ; }
167
+
168
+ if ( sumrows ) {
169
+ unsigned n=getNumberOfArguments ()-1 ; Value* myvec = getPntrToArgument (n);
170
+ for (unsigned i=0 ; i<n; ++i) {
171
+ Value* mymat = getPntrToArgument (i);
172
+ unsigned ncol = mymat->getNumberOfColumns ();
173
+ unsigned nmat = mymat->getRowLength (task_index);
174
+ double val=0 ; for (unsigned j=0 ; j<nmat; ++j) val += mymat->get ( task_index*ncol + j, false );
175
+ unsigned ostrn = getConstPntrToComponent (i)->getPositionInStream ();
176
+ myvals.setValue ( ostrn, val );
177
+
178
+ // And the derivatives
179
+ if ( doNotCalculateDerivatives () ) continue ;
180
+
181
+ unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
182
+ for (unsigned j=0 ; j<nmat; ++j) {
183
+ myvals.addDerivative ( ostrn, dloc + j, 1.0 ); myvals.updateIndex ( ostrn, dloc + j );
184
+ }
185
+ }
186
+ } else if ( getPntrToArgument (1 )->getRank ()==1 ) {
187
+ Value* mymat = getPntrToArgument (0 );
188
+ unsigned ncol = mymat->getNumberOfColumns ();
189
+ unsigned nmat = mymat->getRowLength (task_index);
190
+ unsigned dloc = arg_deriv_starts[0 ] + task_index*ncol;
191
+ for (unsigned i=0 ; i<getNumberOfArguments ()-1 ; ++i) {
192
+ Value* myvec = getPntrToArgument (i+1 );
193
+ double val=0 ; for (unsigned j=0 ; j<nmat; ++j) val += mymat->get ( task_index*ncol + j, false )*myvec->get ( mymat->getRowIndex ( task_index, j ) );
194
+ unsigned ostrn = getConstPntrToComponent (i)->getPositionInStream ();
195
+ myvals.setValue ( ostrn, val );
196
+
197
+ // And the derivatives
198
+ if ( doNotCalculateDerivatives () ) continue ;
199
+
200
+ for (unsigned j=0 ; j<nmat; ++j) {
201
+ unsigned kind = mymat->getRowIndex ( task_index, j );
202
+ double vecval = myvec->get ( kind );
203
+ double matval = mymat->get ( task_index*ncol + j, false );
204
+ myvals.addDerivative ( ostrn, dloc + j, vecval ); myvals.updateIndex ( ostrn, dloc + j );
205
+ myvals.addDerivative ( ostrn, arg_deriv_starts[i+1 ] + kind, matval ); myvals.updateIndex ( ostrn, arg_deriv_starts[i+1 ] + kind );
206
+ }
207
+ }
208
+ } else {
209
+ unsigned n=getNumberOfArguments ()-1 ; Value* myvec = getPntrToArgument (n);
210
+ for (unsigned i=0 ; i<n; ++i) {
211
+ Value* mymat = getPntrToArgument (i);
212
+ unsigned ncol = mymat->getNumberOfColumns ();
213
+ unsigned nmat = mymat->getRowLength (task_index);
214
+ double val=0 ; for (unsigned j=0 ; j<nmat; ++j) val += mymat->get ( task_index*ncol + j, false )*myvec->get ( mymat->getRowIndex ( task_index, j ) );
215
+ unsigned ostrn = getConstPntrToComponent (i)->getPositionInStream ();
216
+ myvals.setValue ( ostrn, val );
217
+
218
+ // And the derivatives
219
+ if ( doNotCalculateDerivatives () ) continue ;
220
+
221
+ unsigned dloc = arg_deriv_starts[i] + task_index*ncol;
222
+ for (unsigned j=0 ; j<nmat; ++j) {
223
+ unsigned kind = mymat->getRowIndex ( task_index, j );
224
+ double vecval = myvec->get ( kind );
225
+ double matval = mymat->get ( task_index*ncol + j, false );
226
+ myvals.addDerivative ( ostrn, dloc + j, vecval ); myvals.updateIndex ( ostrn, dloc + j );
227
+ myvals.addDerivative ( ostrn, arg_deriv_starts[n] + kind, matval ); myvals.updateIndex ( ostrn, arg_deriv_starts[n] + kind );
228
+ }
229
+ }
230
+ }
231
+ }
232
+
164
233
void MatrixTimesVector::setupForTask ( const unsigned & task_index, std::vector<unsigned >& indices, MultiValue& myvals ) const {
165
234
unsigned start_n = getPntrToArgument (0 )->getShape ()[0 ], size_v = getPntrToArgument (0 )->getRowLength (task_index);
166
235
if ( indices.size ()!=size_v+1 ) indices.resize ( size_v + 1 );
0 commit comments