Skip to content

Commit 10da9a4

Browse files
author
Gareth Aneurin Tribello
committed
Merge branch 'master' into derivatives-from-backpropegation
2 parents d15c12f + 2a1e7d3 commit 10da9a4

6 files changed

+51
-22
lines changed

src/adjmat/AdjacencyMatrixBase.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ AdjacencyMatrixBase::AdjacencyMatrixBase(const ActionOptions& ao):
5252
neighbour_list_updated(false),
5353
linkcells(comm),
5454
threecells(comm),
55+
maxcol(0),
5556
natoms_per_list(0)
5657
{
5758
std::vector<unsigned> shape(2); std::vector<AtomNumber> t; parseAtomList("GROUP", t );
@@ -291,9 +292,9 @@ void AdjacencyMatrixBase::setupForTask( const unsigned& current, std::vector<uns
291292
void AdjacencyMatrixBase::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
292293
Vector zero; zero.zero(); plumed_dbg_assert( index2<myvals.getAtomVector().size() );
293294
double weight = calculateWeight( zero, myvals.getAtomVector()[index2], myvals.getNumberOfIndices()-myvals.getSplitIndex(), myvals );
294-
if( fabs(weight)<epsilon ) return;
295-
296295
unsigned w_ind = getConstPntrToComponent(0)->getPositionInStream(); myvals.setValue( w_ind, weight );
296+
if( fabs(weight)<epsilon ) { myvals.setValue( w_ind, 0 ); return; }
297+
297298
if( !doNotCalculateDerivatives() ) {
298299
// Update dynamic list indices for central atom
299300
myvals.updateIndex( w_ind, 3*index1+0 ); myvals.updateIndex( w_ind, 3*index1+1 ); myvals.updateIndex( w_ind, 3*index1+2 );

src/adjmat/ContactMatrix.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ double ContactMatrix::calculateWeight( const Vector& pos1, const Vector& pos2, c
113113
if( mod2<epsilon ) return 0.0; // Atoms can't be bonded to themselves
114114
double dfunc, val = switchingFunction.calculateSqr( mod2, dfunc );
115115
if( val<epsilon ) return 0.0;
116+
if( doNotCalculateDerivatives() ) return val;
116117
addAtomDerivatives( 0, (-dfunc)*distance, myvals );
117118
addAtomDerivatives( 1, (+dfunc)*distance, myvals );
118119
addBoxDerivatives( (-dfunc)*Tensor(distance,distance), myvals );

src/cltools/Benchmark.cpp

+14-4
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ int Benchmark::main(FILE* in, FILE*out,Communicator& pc) {
484484
log.link(log_dev_null.get());
485485
}
486486
log.setLinePrefix("BENCH: ");
487-
487+
log <<"Welcome to PLUMED benchmark\n";
488488
std::vector<Kernel> kernels;
489489

490490
// perform comparative analysis
@@ -589,13 +589,15 @@ int Benchmark::main(FILE* in, FILE*out,Communicator& pc) {
589589
{
590590
std::string paths;
591591
parse("--kernel",paths);
592+
log <<"Using --kernel=" << paths << "\n";
592593
allpaths=Tools::getWords(paths,":");
593594
}
594595

595596
std::vector<std::string> allplumed;
596597
{
597598
std::string paths;
598599
parse("--plumed",paths);
600+
log <<"Using --plumed=" << paths << "\n";
599601
allplumed=Tools::getWords(paths,":");
600602
}
601603

@@ -628,27 +630,36 @@ int Benchmark::main(FILE* in, FILE*out,Communicator& pc) {
628630
// read other flags:
629631
bool shuffled=false;
630632
parseFlag("--shuffled",shuffled);
633+
if (shuffled)
634+
log << "Using --shuffled\n";
631635
int nf; parse("--nsteps",nf);
636+
log << "Using --nsteps=" << nf << "\n";
632637
unsigned natoms; parse("--natoms",natoms);
633-
638+
log << "Using --natoms=" << natoms << "\n";
634639
double maxtime; parse("--maxtime",maxtime);
640+
log << "Using --maxtime=" << maxtime << "\n";
635641

636642
bool domain_decomposition=false;
637643
parseFlag("--domain-decomposition",domain_decomposition);
644+
if (domain_decomposition)
645+
log << "Using --domain-decomposition\n";
646+
638647
if(pc.Get_size()>1) domain_decomposition=true;
639648
if(domain_decomposition) shuffled=true;
640649

641650
double timeToSleep;
642651
parse("--sleep",timeToSleep);
652+
log << "Using --sleep=" << timeToSleep << "\n";
643653

644654
std::vector<int> shuffled_indexes;
645655

646656
{
647657
std::string atomicDistr;
648658
parse("--atom-distribution",atomicDistr);
649659
distribution = getAtomDistribution(atomicDistr);
660+
log << "Using --atom-distribution=" << atomicDistr << "\n";
650661
}
651-
662+
log <<"Initializing the setup of the kernel(s)\n";
652663
const auto initial_time=std::chrono::high_resolution_clock::now();
653664

654665
for(auto & k : kernels) {
@@ -690,7 +701,6 @@ int Benchmark::main(FILE* in, FILE*out,Communicator& pc) {
690701
// trap signals:
691702
SignalHandlerGuard sigIntGuard(SIGINT, signalHandler);
692703

693-
694704
for(int step=0; nf<0 || step<nf; ++step) {
695705
std::shuffle(kernels_ptr.begin(),kernels_ptr.end(),rng);
696706
distribution->positions(pos,step,atomicGenerator);

src/core/ActionWithMatrix.cpp

+11-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ ActionWithMatrix::ActionWithMatrix(const ActionOptions&ao):
3333
ActionWithVector(ao),
3434
next_action_in_chain(NULL),
3535
matrix_to_do_before(NULL),
36-
matrix_to_do_after(NULL)
36+
matrix_to_do_after(NULL),
37+
clearOnEachCycle(true)
3738
{
3839
}
3940

@@ -62,6 +63,14 @@ void ActionWithMatrix::setupStreamedComponents( const std::string& headstr, unsi
6263
myval->setMatrixBookeepingStart(nbookeeping);
6364
nbookeeping += myval->getShape()[0]*( 1 + myval->getNumberOfColumns() );
6465
}
66+
// Turn off clearning of derivatives after each matrix run if there are no matrices in the output of this action
67+
clearOnEachCycle = false;
68+
for(int i=0; i<getNumberOfComponents(); ++i) {
69+
const Value* myval=getConstPntrToComponent(i);
70+
if( myval->getRank()==2 && !myval->hasDerivatives() ) { clearOnEachCycle = true; break; }
71+
}
72+
// Turn off clearing of derivatives if we have only the values of adjacency matrices
73+
if( doNotCalculateDerivatives() && isAdjacencyMatrix() ) clearOnEachCycle = false;
6574
}
6675

6776
void ActionWithMatrix::finishChainBuild( ActionWithVector* act ) {
@@ -222,7 +231,7 @@ void ActionWithMatrix::gatherForcesOnStoredValue( const Value* myval, const unsi
222231
}
223232

224233
void ActionWithMatrix::clearMatrixElements( MultiValue& myvals ) const {
225-
if( isActive() ) {
234+
if( isActive() && clearOnEachCycle ) {
226235
for(int i=0; i<getNumberOfComponents(); ++i) {
227236
const Value* myval=getConstPntrToComponent(i);
228237
if( myval->getRank()==2 && !myval->hasDerivatives() ) myvals.clearDerivatives( myval->getPositionInStream() );

src/core/ActionWithMatrix.h

+4-8
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ class ActionWithMatrix : public ActionWithVector {
4444
/// This does the calculation of a particular matrix element
4545
void runTask( const std::string& controller, const unsigned& current, const unsigned colno, MultiValue& myvals ) const ;
4646
protected:
47+
/// This turns off derivative clearing for contact matrix if we are not storing derivatives
48+
bool clearOnEachCycle;
4749
/// Does the matrix chain continue on from this action
4850
bool matrixChainContinues() const ;
4951
/// This returns the jelem th element of argument ic
@@ -132,14 +134,8 @@ inline
132134
void ActionWithMatrix::addDerivativeOnMatrixArgument( const bool& inchain, const unsigned& ival, const unsigned& jarg, const unsigned& irow, const unsigned& jcol, const double& der, MultiValue& myvals ) const {
133135
plumed_dbg_assert( jarg<getNumberOfArguments() && getPntrToArgument(jarg)->getRank()==2 && !getPntrToArgument(jarg)->hasDerivatives() );
134136
unsigned ostrn = getConstPntrToComponent(ival)->getPositionInStream(), vstart=arg_deriv_starts[jarg];
135-
if( !inchain && getPntrToArgument(jarg)->getNumberOfColumns()<getPntrToArgument(jarg)->getShape()[1] ) {
136-
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns(); Value* myarg=getPntrToArgument(jarg);
137-
for(unsigned i=0; i<myarg->getRowLength(irow); ++i) {
138-
if( myarg->getRowIndex(irow,i)==jcol ) { myvals.addDerivative( ostrn, dloc+i, der ); myvals.updateIndex( ostrn, dloc+i ); return; }
139-
}
140-
plumed_merror("could not find element of sparse matrix to add derivative to");
141-
} else if( !inchain ) {
142-
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getShape()[1] + jcol;
137+
if( !inchain ) {
138+
unsigned dloc = vstart + irow*getPntrToArgument(jarg)->getNumberOfColumns() + jcol;
143139
myvals.addDerivative( ostrn, dloc, der ); myvals.updateIndex( ostrn, dloc );
144140
} else {
145141
unsigned istrn = getPntrToArgument(jarg)->getPositionInStream();

src/matrixtools/MatrixTimesVector.cpp

+18-6
Original file line numberDiff line numberDiff line change
@@ -152,32 +152,44 @@ void MatrixTimesVector::prepare() {
152152
void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
153153
unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
154154
if( indices.size()!=size_v+1 ) indices.resize( size_v + 1 );
155-
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + getPntrToArgument(0)->getRowIndex( task_index, i );
155+
for(unsigned i=0; i<size_v; ++i) indices[i+1] = start_n + i;
156156
myvals.setSplitIndex( size_v + 1 );
157157
}
158158

159159
void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
160160
unsigned ind2 = index2; if( index2>=getPntrToArgument(0)->getShape()[0] ) ind2 = index2 - getPntrToArgument(0)->getShape()[0];
161161
if( getPntrToArgument(1)->getRank()==1 ) {
162+
double matval = 0; Value* myarg = getPntrToArgument(0); unsigned vcol = ind2;
163+
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
164+
else {
165+
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
166+
vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
167+
}
162168
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
163169
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
164-
double matval = getElementOfMatrixArgument( 0, index1, ind2, myvals ), vecval=getArgumentElement( i+1, ind2, myvals );
170+
double vecval=getArgumentElement( i+1, vcol, myvals );
165171
// And add this part of the product
166172
myvals.addValue( ostrn, matval*vecval );
167173
// Now lets work out the derivatives
168174
if( doNotCalculateDerivatives() ) continue;
169-
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, ind2, matval, myvals );
175+
addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
170176
}
171177
} else {
172-
unsigned n=getNumberOfArguments()-1;
178+
unsigned n=getNumberOfArguments()-1; double matval = 0; unsigned vcol = ind2;
173179
for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
174180
unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
175-
double matval = getElementOfMatrixArgument( i, index1, ind2, myvals ), vecval=getArgumentElement( n, ind2, myvals );
181+
Value* myarg = getPntrToArgument(i);
182+
if( !myarg->valueHasBeenSet() ) matval = myvals.get( myarg->getPositionInStream() );
183+
else {
184+
matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
185+
vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
186+
}
187+
double vecval=getArgumentElement( n, vcol, myvals );
176188
// And add this part of the product
177189
myvals.addValue( ostrn, matval*vecval );
178190
// Now lets work out the derivatives
179191
if( doNotCalculateDerivatives() ) continue;
180-
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, ind2, matval, myvals );
192+
addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals ); addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
181193
}
182194
}
183195
}

0 commit comments

Comments
 (0)