@@ -112,45 +112,6 @@ def compute_final_statistics(self):
112
112
return average_squared_distance
113
113
114
114
115
def compute_average_squared_distance_from_data(
    datamodule: pl.LightningDataModule,
    cutoff: float,
    trainer_cfg: Dict[str, Any],
    num_estimation_graphs: int = 5000,
    verbose: bool = False
):
    """Compute the normalization statistic by running a Lightning trainer.

    A ``ComputeNormalizationModule`` is driven over ``datamodule`` by a trainer
    instantiated from ``trainer_cfg``; the module's final statistics are then
    returned.

    Args:
        datamodule: The Lightning datamodule handed to ``trainer.fit``.
        cutoff (float): The radius cutoff for distance calculations, forwarded
            to ``ComputeNormalizationModule``.
        trainer_cfg: Hydra configuration from which the Lightning trainer is
            instantiated.
        num_estimation_graphs (int): Maximum number of graphs to process,
            forwarded to ``ComputeNormalizationModule``.
        verbose (bool): Whether to print detailed statistics, forwarded to
            ``ComputeNormalizationModule``.

    Returns:
        The value of ``ComputeNormalizationModule.compute_final_statistics()``
        (the computed average squared distance).
    """
    # Module that accumulates the statistics while the trainer iterates.
    norm_module = ComputeNormalizationModule(
        cutoff=cutoff,
        num_estimation_graphs=num_estimation_graphs,
        verbose=verbose
    )

    # The trainer is built entirely from the supplied config.
    # NOTE(review): whether callbacks/loggers are active depends on
    # ``trainer_cfg``; nothing here disables them explicitly.
    trainer = hydra.utils.instantiate(trainer_cfg)

    # Run the fit loop so the module sees the data.
    trainer.fit(norm_module, datamodule=datamodule)

    # Compute and return the final statistics.
    return norm_module.compute_final_statistics()
152
-
153
-
154
115
def compute_distance_matrix (x : np .ndarray , cutoff : Optional [float ] = None ) -> np .ndarray :
155
116
"""Computes the distance matrix between points in x, ignoring self-distances."""
156
117
if x .shape [- 1 ] != 3 :
@@ -177,42 +138,36 @@ def compute_average_squared_distance(x: np.ndarray, cutoff: Optional[float] = No
177
138
return np .mean (dist_x ** 2 )
178
139
179
140
180
- # def compute_average_squared_distance_from_data(
181
- # dataloader: torch.utils.data.DataLoader,
182
- # cutoff: float,
183
- # num_estimation_graphs: int = 5000,
184
- # verbose: bool = False,
185
- # ) -> float:
186
- # """Computes the average squared distance for normalization."""
187
- # avg_sq_dists = collections.defaultdict(list)
188
- # num_graphs = 0
189
- # for batch in dataloader:
190
- # for graph in batch.to_data_list():
191
- # pos = np.asarray(graph.pos)
192
- # avg_sq_dist = compute_average_squared_distance(pos, cutoff=cutoff)
193
- # avg_sq_dists[graph.dataset_label].append(avg_sq_dist)
194
- # num_graphs += 1
195
-
196
- # if num_graphs >= num_estimation_graphs:
197
- # break
198
-
199
- # mean_avg_sq_dist = sum(np.sum(avg_sq_dists[label]) for label in avg_sq_dists) / num_graphs
200
- # utils.dist_log(f"Mean average squared distance = {mean_avg_sq_dist:0.3f} nm^2")
201
-
202
- # if verbose:
203
- # utils.dist_log(f"For cutoff {cutoff} nm:")
204
- # for label in sorted(avg_sq_dists):
205
- # utils.dist_log(
206
- # f"- Dataset {label}: Average squared distance = {np.mean(avg_sq_dists[label]):0.3f} +- {np.std(avg_sq_dists[label]):0.3f} nm^2"
207
- # )
208
-
209
- # # Average across all processes, if distributed.
210
- # print("torch.distributed.is_initialized():", torch.distributed.is_initialized())
211
- # mean_avg_sq_dist = torch.tensor(mean_avg_sq_dist, device="cuda")
212
-
213
- # print("mean_avg_sq_dist bef:", mean_avg_sq_dist)
214
- # torch.distributed.all_reduce(mean_avg_sq_dist, op=torch.distributed.ReduceOp.AVG)
215
- # mean_avg_sq_dist = mean_avg_sq_dist.item()
216
- # print("mean_avg_sq_dist aft:", mean_avg_sq_dist)
217
-
218
- # return mean_avg_sq_dist
141
def compute_average_squared_distance_from_datasets(
    datasets: Sequence[torch.utils.data.Dataset],
    cutoff: float,
    num_estimation_datasets: int = 50,
    num_estimation_graphs_per_dataset: int = 100,
    verbose: bool = False,
) -> float:
    """Computes the average squared distance for normalization.

    At most ``num_estimation_datasets`` datasets are read (taken from the front
    of ``datasets``), and iteration over each dataset stops once
    ``num_estimation_graphs_per_dataset`` graphs have been processed. For every
    graph, ``compute_average_squared_distance`` is evaluated on ``graph.pos``
    and the result is grouped by ``graph.dataset_label``.

    Args:
        datasets: Datasets whose items expose ``pos`` and ``dataset_label``.
        cutoff: Cutoff forwarded to ``compute_average_squared_distance``.
        num_estimation_datasets: Maximum number of datasets to read.
        num_estimation_graphs_per_dataset: Number of graphs after which
            iteration over a single dataset stops.
        verbose: If True, also log the per-label mean and standard deviation.

    Returns:
        The mean of the per-graph average squared distances over all processed
        graphs.

    Raises:
        ValueError: If no graph was processed (no datasets, or all empty).
    """
    avg_sq_dists = collections.defaultdict(list)
    # Count graphs across ALL datasets: the final mean must be divided by this
    # total, not by the per-dataset counter that is reset on every iteration.
    total_num_graphs = 0

    for dataset in datasets[:num_estimation_datasets]:
        num_graphs = 0

        for graph in dataset:
            pos = np.asarray(graph.pos)
            avg_sq_dist = compute_average_squared_distance(pos, cutoff=cutoff)
            avg_sq_dists[graph.dataset_label].append(avg_sq_dist)
            num_graphs += 1

            # Check after processing so no extra graph is loaded from the
            # dataset once the per-dataset budget has been reached.
            if num_graphs >= num_estimation_graphs_per_dataset:
                break

        total_num_graphs += num_graphs

    if total_num_graphs == 0:
        raise ValueError(
            "Cannot compute the average squared distance: no graphs were found in the given datasets."
        )

    mean_avg_sq_dist = sum(np.sum(values) for values in avg_sq_dists.values()) / total_num_graphs
    utils.dist_log(f"Mean average squared distance = {mean_avg_sq_dist:0.3f} nm^2")

    if verbose:
        utils.dist_log(f"For cutoff {cutoff} nm:")
        for label in sorted(avg_sq_dists):
            utils.dist_log(
                f"- Dataset {label}: Average squared distance = {np.mean(avg_sq_dists[label]):0.3f} +- {np.std(avg_sq_dists[label]):0.3f} nm^2"
            )

    return float(mean_avg_sq_dist)
0 commit comments