Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions nvmolkit/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def fused_butina(
n_start = x.shape[0]
device = x.device
indices = torch.arange(n_start, dtype=torch.int32, device=device)
# CPU mirror of indices avoids a D2H sync to record each centroid.
indices_host = list(range(n_start))
cluster_count = torch.zeros(2, dtype=torch.int32, device=device)
cluster_count[1] = n_start - 1
cluster_indices = torch.zeros(n_start, dtype=torch.int32, device=device)
Expand All @@ -149,33 +151,55 @@ def fused_butina(
threshold = float(1 - cutoff)
y = x
first_run = True
while cluster_count[0].item() <= cluster_count[1].item() and x.shape[0] > 0:
# cc[0] = next cluster start index, cc[1] = last valid index.
# Initialized to [0, n_start-1] matching cluster_count, so the initial
# while condition is satisfied for n_start > 0 without a D2H sync.
cc = [0, n_start - 1]
while cc[0] <= cc[1] and x.shape[0] > 0:
update_neighbor_counts(x, y, neigh, threshold, subtract=not first_run, metric=metric)
first_run = False

max_val = neigh.max().item()
# Batch max and last-argmax into one D2H transfer (sync 1 of 2).
neigh_flipped = neigh.flip(0)
batch_ma = torch.stack(
[neigh.max().to(torch.int64), neigh_flipped.argmax().to(torch.int64)]
).tolist()
max_val = int(batch_ma[0])
if max_val == 0:
break
id_max = neigh.shape[0] - 1 - neigh.flip(0).contiguous().argmax().item()
centroids.append(indices[id_max].item())
id_max = neigh.shape[0] - 1 - int(batch_ma[1])
centroids.append(indices_host[id_max]) # CPU mirror, no sync

extract_cluster_and_singletons(
x, id_max, is_free, neigh, cluster_count, cluster_indices, threshold, indices, metric=metric
)
cluster_sizes.append(cluster_count[0].item())
x, y = x[is_free, :].contiguous(), x[~is_free, :].contiguous()

# Batch cluster_count and is_free into one D2H transfer (sync 2 of 2).
# combined[:2] = cluster_count, combined[2:] = is_free as int32.
combined = torch.cat([cluster_count, is_free.to(torch.int32)]).tolist()
Comment thread
mooreneural marked this conversation as resolved.
cc = [int(combined[0]), int(combined[1])]
is_free_host = [bool(v) for v in combined[2:]]

cluster_sizes.append(cc[0])

# is_free is already updated in-place on GPU by extract_cluster_and_singletons;
# use it directly to avoid a H2D→CPU→H2D roundtrip.
# is_free_host (from the combined download above) is still needed for indices_host.
y = x[~is_free, :].contiguous()
x = x[is_free, :].contiguous()
indices = indices[is_free].contiguous()
neigh = neigh[is_free].contiguous()
is_free = torch.ones(x.shape[0], dtype=torch.bool, device=x.device)
indices_host = [idx for idx, keep in zip(indices_host, is_free_host) if keep]

cluster_indices_cpu = cluster_indices.cpu()
indices_cpu = cluster_indices_cpu.numpy()
for i in range(n_start - cluster_sizes[-1]):
item = cluster_sizes[-1]
cluster_sizes.append(cluster_sizes[-1] + 1)
centroids.append(cluster_indices_cpu[item].item())
centroids.append(int(indices_cpu[item]))

clusters = []
indices_cpu = cluster_indices.cpu().numpy()
for i in range(len(cluster_sizes) - 1):
start_idx = cluster_sizes[i]
end_idx = cluster_sizes[i + 1]
Expand Down
163 changes: 90 additions & 73 deletions src/conformer_rmsd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,29 @@ __device__ __forceinline__ double det3x3(const double* H) {
// ---------------------------------------------------------------------------

constexpr int kRmsdBlockSize = 128;
using RmsdBlockReduceT = cub::BlockReduce<double, kRmsdBlockSize>;
constexpr int kRmsdWarps = kRmsdBlockSize / 32; // 4 warps per block

using RmsdWarpReduce = cub::WarpReduce<double>;

__device__ __forceinline__ void computePairRmsd(const double* __restrict__ coordI,
const double* __restrict__ coordJ,
const int numAtoms,
const bool prealigned,
double* outRmsd) {
const int tid = threadIdx.x;
__shared__ RmsdBlockReduceT::TempStorage reduceTmp;
const int tid = threadIdx.x;
const int warpId = tid / 32;
const int laneId = tid % 32;

// Shared buffers for warp→block reduction.
// warpReduceTemp[w]: CUB WarpReduce scratch for warp w (empty for full 32-thread warps).
// warpBuf[w][f]: field f's warp-partial sum for warp w.
// sCent[6]: broadcast centroid buffer written by thread 0 after sync 1.
// prealigned path uses only warpBuf[w][0].
// Alignment path phase-1 uses fields 0-5 (centroid sums);
// phase-2 uses fields 0-10 (Sp, Sq, H[0..8]).
__shared__ typename RmsdWarpReduce::TempStorage warpReduceTemp[kRmsdWarps];
__shared__ double warpBuf[kRmsdWarps][11];
__shared__ double sCent[6];

if (prealigned) {
// ---- Simple RMSD without alignment (no centering, matches RDKit behavior) ----
Expand All @@ -118,16 +132,22 @@ __device__ __forceinline__ void computePairRmsd(const double* __restrict__ coord
const double dz = coordI[a * 3 + 2] - coordJ[a * 3 + 2];
sumSqDiff += dx * dx + dy * dy + dz * dz;
}
const double total = RmsdBlockReduceT(reduceTmp).Sum(sumSqDiff);
if (tid == 0)
sumSqDiff = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumSqDiff);
if (laneId == 0)
warpBuf[warpId][0] = sumSqDiff;
__syncthreads();
if (tid == 0) {
double total = 0.0;
for (int w = 0; w < kRmsdWarps; ++w)
total += warpBuf[w][0];
*outRmsd = sqrt(total / static_cast<double>(numAtoms));
}
return;
}

// ---- Kabsch alignment path: compute centroids ----
__shared__ double sCentI[3];
__shared__ double sCentJ[3];

// ---- Kabsch alignment path ----
// Phase 1: accumulate centroid sums, warp-reduce into warpBuf (sync 1),
// thread 0 computes and broadcasts centroids via sCent (sync 2).
double sumIx = 0.0, sumIy = 0.0, sumIz = 0.0;
double sumJx = 0.0, sumJy = 0.0, sumJz = 0.0;
for (int a = tid; a < numAtoms; a += kRmsdBlockSize) {
Expand All @@ -138,43 +158,49 @@ __device__ __forceinline__ void computePairRmsd(const double* __restrict__ coord
sumJy += coordJ[a * 3 + 1];
sumJz += coordJ[a * 3 + 2];
}
sumIx = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumIx);
sumIy = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumIy);
sumIz = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumIz);
sumJx = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumJx);
sumJy = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumJy);
sumJz = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(sumJz);

if (laneId == 0) {
warpBuf[warpId][0] = sumIx;
warpBuf[warpId][1] = sumIy;
warpBuf[warpId][2] = sumIz;
warpBuf[warpId][3] = sumJx;
warpBuf[warpId][4] = sumJy;
warpBuf[warpId][5] = sumJz;
}
__syncthreads(); // sync 1: warpBuf[*][0..5] visible; thread 0 reduces below.

// Reduce centroid components; write results to shared memory for broadcast.
// Each __syncthreads() both allows TempStorage reuse and makes the previous
// shared-memory write visible to all threads before the next reduction.
const double invN = 1.0 / static_cast<double>(numAtoms);
sumIx = RmsdBlockReduceT(reduceTmp).Sum(sumIx);
__syncthreads();
if (tid == 0)
sCentI[0] = sumIx * invN;
sumIy = RmsdBlockReduceT(reduceTmp).Sum(sumIy);
__syncthreads();
if (tid == 0)
sCentI[1] = sumIy * invN;
sumIz = RmsdBlockReduceT(reduceTmp).Sum(sumIz);
__syncthreads();
if (tid == 0)
sCentI[2] = sumIz * invN;
sumJx = RmsdBlockReduceT(reduceTmp).Sum(sumJx);
__syncthreads();
if (tid == 0)
sCentJ[0] = sumJx * invN;
sumJy = RmsdBlockReduceT(reduceTmp).Sum(sumJy);
__syncthreads();
if (tid == 0)
sCentJ[1] = sumJy * invN;
sumJz = RmsdBlockReduceT(reduceTmp).Sum(sumJz);
__syncthreads();
if (tid == 0)
sCentJ[2] = sumJz * invN;
__syncthreads(); // broadcast sCentJ[2] and ensure all centroid writes are visible

const double cIx = sCentI[0], cIy = sCentI[1], cIz = sCentI[2];
const double cJx = sCentJ[0], cJy = sCentJ[1], cJz = sCentJ[2];

// ---- Compute Sp, Sq, and cross-covariance H (Kabsch alignment) ----
if (tid == 0) {
double sIx = 0.0, sIy = 0.0, sIz = 0.0, sJx = 0.0, sJy = 0.0, sJz = 0.0;
for (int w = 0; w < kRmsdWarps; ++w) {
sIx += warpBuf[w][0];
sIy += warpBuf[w][1];
sIz += warpBuf[w][2];
sJx += warpBuf[w][3];
sJy += warpBuf[w][4];
sJz += warpBuf[w][5];
}
sCent[0] = sIx * invN;
sCent[1] = sIy * invN;
sCent[2] = sIz * invN;
sCent[3] = sJx * invN;
sCent[4] = sJy * invN;
sCent[5] = sJz * invN;
}
__syncthreads(); // sync 2: sCent visible to all threads.

const double cIx = sCent[0], cIy = sCent[1], cIz = sCent[2];
const double cJx = sCent[3], cJy = sCent[4], cJz = sCent[5];

// Phase 2: accumulate Sp, Sq, and H = P^T Q (11 values), warp-reduce into
// warpBuf (sync 3), thread 0 sums and computes RMSD.
// Sp = sum ||pi - centI||^2, Sq = sum ||qj - centJ||^2
// H[r][c] = sum (pi[r] - centI[r]) * (qj[c] - centJ[c])
double localSp = 0.0, localSq = 0.0;
double localH[9] = {0.0};

Expand All @@ -189,7 +215,6 @@ __device__ __forceinline__ void computePairRmsd(const double* __restrict__ coord
localSp += px * px + py * py + pz * pz;
localSq += qx * qx + qy * qy + qz * qz;

// H = P^T Q (sum of outer products)
localH[0] += px * qx;
localH[1] += px * qy;
localH[2] += px * qz;
Expand All @@ -201,35 +226,27 @@ __device__ __forceinline__ void computePairRmsd(const double* __restrict__ coord
localH[8] += pz * qz;
}

// Reduce all 11 values into thread 0. Results in non-zero threads are
// undefined and unused; only thread 0 performs the RMSD computation below.
// __syncthreads() between calls allows TempStorage reuse.
localSp = RmsdBlockReduceT(reduceTmp).Sum(localSp);
__syncthreads();
localSq = RmsdBlockReduceT(reduceTmp).Sum(localSq);
__syncthreads();
localH[0] = RmsdBlockReduceT(reduceTmp).Sum(localH[0]);
__syncthreads();
localH[1] = RmsdBlockReduceT(reduceTmp).Sum(localH[1]);
__syncthreads();
localH[2] = RmsdBlockReduceT(reduceTmp).Sum(localH[2]);
__syncthreads();
localH[3] = RmsdBlockReduceT(reduceTmp).Sum(localH[3]);
__syncthreads();
localH[4] = RmsdBlockReduceT(reduceTmp).Sum(localH[4]);
__syncthreads();
localH[5] = RmsdBlockReduceT(reduceTmp).Sum(localH[5]);
__syncthreads();
localH[6] = RmsdBlockReduceT(reduceTmp).Sum(localH[6]);
__syncthreads();
localH[7] = RmsdBlockReduceT(reduceTmp).Sum(localH[7]);
__syncthreads();
localH[8] = RmsdBlockReduceT(reduceTmp).Sum(localH[8]);
// No final sync: only thread 0 reads localSp, localSq, localH below.

// ---- Thread 0: compute RMSD from Sp, Sq, singular values of H ----
localSp = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(localSp);
localSq = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(localSq);
for (int i = 0; i < 9; ++i)
localH[i] = RmsdWarpReduce(warpReduceTemp[warpId]).Sum(localH[i]);

if (laneId == 0) {
warpBuf[warpId][0] = localSp;
warpBuf[warpId][1] = localSq;
for (int i = 0; i < 9; ++i)
warpBuf[warpId][2 + i] = localH[i];
}
__syncthreads(); // sync 3: warpBuf[*][0..10] visible; thread 0 computes RMSD.

if (tid == 0) {
const double* H = localH;
double Sp = 0.0, Sq = 0.0, H[9] = {0.0};
for (int w = 0; w < kRmsdWarps; ++w) {
Sp += warpBuf[w][0];
Sq += warpBuf[w][1];
for (int i = 0; i < 9; ++i)
H[i] += warpBuf[w][2 + i];
}

// G = H^T H (3x3 symmetric positive semi-definite)
const double g00 = H[0] * H[0] + H[3] * H[3] + H[6] * H[6];
Expand All @@ -253,7 +270,7 @@ __device__ __forceinline__ void computePairRmsd(const double* __restrict__ coord
s2 = -s2;

// RMSD^2 = (Sp + Sq - 2*(s0 + s1 + s2)) / N
const double rmsdSq = fmax((localSp + localSq - 2.0 * (s0 + s1 + s2)) * invN, 0.0);
const double rmsdSq = fmax((Sp + Sq - 2.0 * (s0 + s1 + s2)) * invN, 0.0);
*outRmsd = sqrt(rmsdSq);
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/minimizer/bfgs_minimize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,10 @@ __global__ void updateDGradKernel(const double gradTol,
double blockMax = cub::BlockReduce<double, 128>(tempStorage).Reduce(localMax, cubMax());

if (idxWithinSystem == 0) {
// Matches RDKit's signed-energy convergence denominator in ForceField::minimize.
// Negative-energy geometries can clamp this to 1, artificially tightening the
// check — but fixing it unconditionally diverges from reference behavior.
// TODO: file upstream RDKit bug; gate fix on RDKit version once merged there.
const double term = max(energies[sysIdx] * gradScales[sysIdx], 1.0);
Comment thread
mooreneural marked this conversation as resolved.
blockMax /= term;
if (blockMax < gradTol) {
Expand Down
Loading
Loading