diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 6f35f931..a909b346 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -39,6 +39,7 @@ dependencies: - pytest-forked - pytest-xdist - pytorch>=2.3 +- pytorch>=2.3,<2.6a0 - pytorch_geometric>=2.5,<2.7 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - rmm==25.2.*,>=0.0.0a0 diff --git a/conda/environments/all_cuda-121_arch-x86_64.yaml b/conda/environments/all_cuda-121_arch-x86_64.yaml index 51efa1cd..9725d76c 100644 --- a/conda/environments/all_cuda-121_arch-x86_64.yaml +++ b/conda/environments/all_cuda-121_arch-x86_64.yaml @@ -45,6 +45,7 @@ dependencies: - pytest-forked - pytest-xdist - pytorch>=2.3 +- pytorch>=2.3,<2.6a0 - pytorch_geometric>=2.5,<2.7 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - rmm==25.2.*,>=0.0.0a0 diff --git a/conda/environments/all_cuda-124_arch-x86_64.yaml b/conda/environments/all_cuda-124_arch-x86_64.yaml index 3a306e48..860dd11e 100644 --- a/conda/environments/all_cuda-124_arch-x86_64.yaml +++ b/conda/environments/all_cuda-124_arch-x86_64.yaml @@ -45,6 +45,7 @@ dependencies: - pytest-forked - pytest-xdist - pytorch>=2.3 +- pytorch>=2.3,<2.6a0 - pytorch_geometric>=2.5,<2.7 - rapids-build-backend>=0.3.0,<0.4.0.dev0 - rmm==25.2.*,>=0.0.0a0 diff --git a/dependencies.yaml b/dependencies.yaml index 66419b15..6414f698 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -20,6 +20,7 @@ files: - depends_on_dask_cudf - depends_on_cupy - depends_on_pytorch + - depends_on_ogb - depends_on_dgl - depends_on_pyg - python_run_cugraph_dgl @@ -52,6 +53,7 @@ files: includes: - cuda_version - depends_on_pytorch + - depends_on_ogb - depends_on_cugraph_dgl - py_version - test_notebook @@ -63,6 +65,7 @@ files: - depends_on_cudf - depends_on_dgl - depends_on_pytorch + - depends_on_ogb - py_version - test_python_common @@ -123,6 +126,7 @@ files: includes: - depends_on_pylibwholegraph - depends_on_pytorch + - depends_on_ogb - test_python_common py_build_cugraph_pyg: output: pyproject @@ -150,6 +154,7 @@ files: - depends_on_pyg - depends_on_pylibwholegraph - depends_on_pytorch + - depends_on_ogb - test_python_common - test_python_cugraph_pyg @@ -163,6 +168,7 @@ files: - depends_on_cugraph - depends_on_dgl - depends_on_pytorch + - depends_on_ogb - cugraph_dgl_dev - test_python_common cugraph_pyg_dev: @@ -175,6 +181,7 @@ files: - depends_on_cugraph - depends_on_pyg - depends_on_pytorch + - depends_on_ogb - cugraph_pyg_dev - test_python_common channels: @@ -428,6 +435,34 @@ dependencies: - *tensordict - {matrix: null, packages: [*pytorch_pip, *tensordict]} + # Will remove this after snap-stanford/ogb#497 is resolved. + # Temporarily sets the max pytorch version to 2.5 for compatibility + # with ogb. 
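+  # The cap is written as <2.6a0 rather than <2.6 so that 2.6
+  # pre-release builds are excluded as well: in conda version
+  # ordering, 2.6.0a0 sorts before 2.6.0 and would otherwise
+  # satisfy a plain <2.6 bound.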
+ depends_on_ogb: + common: + - output_types: [conda] + packages: + - pytorch>=2.3,<2.6a0 + specific: + - output_types: [requirements] + matrices: + - matrix: {cuda: "12.*"} + packages: + - --extra-index-url=https://download.pytorch.org/whl/cu121 + - matrix: {cuda: "11.*"} + packages: + - --extra-index-url=https://download.pytorch.org/whl/cu118 + - {matrix: null, packages: null} + - output_types: [requirements, pyproject] + matrices: + - matrix: {cuda: "12.*"} + packages: + - torch>=2.3,<2.6a0 + - matrix: {cuda: "11.*"} + packages: + - torch>=2.3,<2.6a0 + - {matrix: null, packages: [*pytorch_pip]} + depends_on_dgl: specific: - output_types: [conda] diff --git a/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml b/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml index 424852ba..d69b03d7 100644 --- a/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml +++ b/python/cugraph-dgl/conda/cugraph_dgl_dev_cuda-118.yaml @@ -17,6 +17,7 @@ dependencies: - pytest-cov - pytest-xdist - pytorch>=2.3 +- pytorch>=2.3,<2.6a0 - tensordict>=0.1.2 - torchdata name: cugraph_dgl_dev_cuda-118 diff --git a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml index f8e140ac..f4563517 100644 --- a/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml +++ b/python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml @@ -15,6 +15,7 @@ dependencies: - pytest-cov - pytest-xdist - pytorch>=2.3 +- pytorch>=2.3,<2.6a0 - pytorch_geometric>=2.5,<2.7 - tensordict>=0.1.2 - torchdata diff --git a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py index e14e6ca6..589a2306 100644 --- a/python/cugraph-pyg/cugraph_pyg/data/graph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/data/graph_store.py @@ -293,7 +293,9 @@ def __get_weight_tensor( return torch.concat(weights) @property - def _numeric_edge_types(self) -> Tuple[List, "torch.Tensor", "torch.Tensor"]: + def _numeric_edge_types( + self, + ) -> Tuple[List[Tuple[str, str, str]], "torch.Tensor", "torch.Tensor"]: """ Returns the canonical edge types in order (the 0th canonical type corresponds to numeric edge type 0, etc.), along with the numeric source and destination diff --git a/python/cugraph-pyg/cugraph_pyg/examples/taobao_mnmg.py b/python/cugraph-pyg/cugraph_pyg/examples/taobao_mnmg.py new file mode 100644 index 00000000..368c2fa8 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/examples/taobao_mnmg.py @@ -0,0 +1,541 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
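+
+# Multi-node, multi-GPU (MNMG) link-prediction example on the Taobao
+# user/item dataset with cugraph-pyg.  Rank 0 (optionally) partitions
+# the edge lists across ranks; every rank then loads its partition into
+# a GraphStore/WholeFeatureStore pair and trains a heterogeneous
+# GraphSAGE encoder/decoder under DistributedDataParallel, with batches
+# drawn from a LinkNeighborLoader.  Launch with torchrun, e.g.:
+#   torchrun --nproc-per-node=<num GPUs> taobao_mnmg.py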
+ +import argparse +import os +import json +import warnings + +import gc + +from datetime import timedelta + +import torch +import torch.nn.functional as F +from torch.nn import Embedding, Linear +from torch.nn.parallel import DistributedDataParallel + +import torch_geometric.transforms as T +from torch_geometric.datasets import Taobao +from torch_geometric.nn import SAGEConv +from torch_geometric.utils.convert import to_scipy_sparse_matrix +from torch_geometric.data import HeteroData + +from pylibwholegraph.torch.initialize import ( + init as wm_init, + finalize as wm_finalize, +) + +from sklearn.metrics import roc_auc_score + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ["CUDF_SPILL"] = "1" + +# Ensures that a CUDA context is not created on import of rapids. +# Allows pytorch to create the context instead +os.environ["RAPIDS_NO_INITIALIZE"] = "1" + + +def init_pytorch_worker(global_rank, local_rank, world_size, cugraph_id): + import rmm + + rmm.reinitialize( + devices=local_rank, + managed_memory=True, + pool_allocator=True, + ) + + import cupy + + cupy.cuda.Device(local_rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + + enable_spilling() + + torch.cuda.set_device(local_rank) + + from cugraph.gnn import cugraph_comms_init + + cugraph_comms_init( + rank=global_rank, world_size=world_size, uid=cugraph_id, device=local_rank + ) + + wm_init(global_rank, world_size, local_rank, torch.cuda.device_count()) + + +class ItemGNNEncoder(torch.nn.Module): + def __init__(self, hidden_channels, out_channels): + super().__init__() + self.conv1 = SAGEConv(-1, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, hidden_channels) + self.lin = Linear(hidden_channels, out_channels) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index).relu() + x = self.conv2(x, edge_index).relu() + return self.lin(x) + + +class UserGNNEncoder(torch.nn.Module): + def __init__(self, hidden_channels, out_channels): + super().__init__() + self.conv1 = SAGEConv((-1, -1), hidden_channels) + self.conv2 = SAGEConv((-1, -1), hidden_channels) + self.conv3 = SAGEConv((-1, -1), hidden_channels) + self.lin = Linear(hidden_channels, out_channels) + + def forward(self, x_dict, edge_index_dict): + item_x = self.conv1( + x_dict["item"], + edge_index_dict[("item", "to", "item")], + ).relu() + + user_x = self.conv2( + (x_dict["item"], x_dict["user"]), + edge_index_dict[("item", "rev_to", "user")], + ).relu() + + user_x = self.conv3( + (item_x, user_x), + edge_index_dict[("item", "rev_to", "user")], + ).relu() + + return self.lin(user_x) + + +class EdgeDecoder(torch.nn.Module): + def __init__(self, hidden_channels): + super().__init__() + self.lin1 = Linear(2 * hidden_channels, hidden_channels) + self.lin2 = Linear(hidden_channels, 1) + + def forward(self, z_src, z_dst, edge_label_index): + row, col = edge_label_index + z = torch.cat([z_src[row], z_dst[col]], dim=-1) + + z = self.lin1(z).relu() + z = self.lin2(z) + return z.view(-1) + + +class Model(torch.nn.Module): + def __init__(self, num_users, num_items, hidden_channels, out_channels): + super().__init__() + self.user_emb = Embedding(num_users, hidden_channels) + self.item_emb = Embedding(num_items, hidden_channels) + self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels) + self.user_encoder = 
UserGNNEncoder(hidden_channels, out_channels) + self.decoder = EdgeDecoder(out_channels) + + def forward(self, x_dict, edge_index_dict, edge_label_index): + z_dict = {} + x_dict["user"] = self.user_emb(x_dict["user"]) + x_dict["item"] = self.item_emb(x_dict["item"]) + z_dict["item"] = self.item_encoder( + x_dict["item"], + edge_index_dict[("item", "to", "item")], + ) + z_dict["user"] = self.user_encoder(x_dict, edge_index_dict) + + return self.decoder(z_dict["user"], z_dict["item"], edge_label_index) + + +def write_edges(edge_index, path): + world_size = torch.distributed.get_world_size() + + os.makedirs(path, exist_ok=True) + for (r, e) in enumerate(torch.tensor_split(edge_index, world_size, dim=1)): + rank_path = os.path.join(path, f"rank={r}.pt") + torch.save( + e.clone(), + rank_path, + ) + + +def preprocess_and_partition(data, edge_path, meta_path): + # Only interested in user/item edges + del data["category"] + del data["item", "category"] + del data["user", "item"].time + del data["user", "item"].behavior + + print("Writing item->item edge partitions...") + item_item_edge_path = os.path.join(edge_path, "item_item") + write_edges(data["item", "item"].edge_index, item_item_edge_path) + + print("Writing user->item edge partitions...") + user_item_edge_path = os.path.join(edge_path, "user_item") + write_edges(data["user", "item"].edge_index, user_item_edge_path) + + print("Writing metadata...") + meta = { + "num_nodes": { + "item": data["item"].num_nodes, + "user": data["user"].num_nodes, + } + } + with open(meta_path, "w") as f: + json.dump(meta, f) + + +def pre_transform(data): + # Compute item->item relationships: + print("Computing item->item relationships (this may take a very long time)...") + mat = to_scipy_sparse_matrix(data["user", "item"].edge_index).tocsr() + mat = mat[: data["user"].num_nodes, : data["item"].num_nodes] + comat = mat.T @ mat + comat.setdiag(0) + comat = comat >= 3.0 + comat = comat.tocoo() + row = torch.from_numpy(comat.row).to(torch.long) + col = torch.from_numpy(comat.col).to(torch.long) + data["item", "item"].edge_index = torch.stack([row, col], dim=0) + return data + + +def cugraph_pyg_from_heterodata(data, wg_mem_type, return_edge_label=True): + from cugraph_pyg.data import GraphStore, WholeFeatureStore + + graph_store = GraphStore(is_multi_gpu=True) + feature_store = WholeFeatureStore(memory_type=wg_mem_type) + + graph_store[ + ("user", "to", "item"), + "coo", + False, + (data["user"].num_nodes, data["item"].num_nodes), + ] = data["user", "to", "item"].edge_index + graph_store[ + ("item", "rev_to", "user"), + "coo", + False, + (data["item"].num_nodes, data["user"].num_nodes), + ] = data["item", "rev_to", "user"].edge_index + + graph_store[ + ("item", "to", "item"), + "coo", + False, + (data["item"].num_nodes, data["item"].num_nodes), + ] = data["item", "to", "item"].edge_index + graph_store[ + ("item", "rev_to", "item"), + "coo", + False, + (data["item"].num_nodes, data["item"].num_nodes), + ] = data["item", "rev_to", "item"].edge_index + + feature_store["item", "x", None] = data["item"].x + feature_store["user", "x", None] = data["user"].x + + out = ( + (feature_store, graph_store), + data["user", "to", "item"].edge_label_index, + (data["user", "to", "item"].edge_label if return_edge_label else None), + ) + + return out + + +def load_partitions(edge_path, meta_path, wg_mem_type): + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + data = HeteroData() + + # Load metadata + print("Loading metadata...") + with 
open(meta_path, "r") as f:
+        meta = json.load(f)
+
+    data["user"].num_nodes = meta["num_nodes"]["user"]
+    data["item"].num_nodes = meta["num_nodes"]["item"]
+
+    data["user"].x = torch.tensor_split(
+        torch.arange(data["user"].num_nodes), world_size
+    )[rank]
+
+    data["item"].x = torch.tensor_split(
+        torch.arange(data["item"].num_nodes), world_size
+    )[rank]
+
+    # T.ToUndirected() will not work here because we are working with
+    # partitioned data. The number of nodes will not match.
+
+    print("Loading item->item edge index...")
+    data["item", "to", "item"].edge_index = torch.load(
+        os.path.join(edge_path, "item_item", f"rank={rank}.pt"),
+        weights_only=True,
+    )
+    data["item", "rev_to", "item"].edge_index = torch.stack(
+        [
+            data["item", "to", "item"].edge_index[1],
+            data["item", "to", "item"].edge_index[0],
+        ]
+    )
+
+    print("Loading user->item edge index...")
+    data["user", "to", "item"].edge_index = torch.load(
+        os.path.join(edge_path, "user_item", f"rank={rank}.pt"),
+        weights_only=True,
+    )
+    data["item", "rev_to", "user"].edge_index = torch.stack(
+        [
+            data["user", "to", "item"].edge_index[1],
+            data["user", "to", "item"].edge_index[0],
+        ]
+    )
+
+    # Generate data splits here
+    print("Splitting data...")
+    train_data, val_data, test_data = T.RandomLinkSplit(
+        num_val=0.1,
+        num_test=0.1,
+        neg_sampling_ratio=1.0,
+        add_negative_train_samples=False,
+        edge_types=[("user", "to", "item")],
+        rev_edge_types=[("item", "rev_to", "user")],
+    )(data)
+
+    print(train_data, test_data, val_data)
+
+    print(f"Finished loading graph data on rank {rank}")
+    return {
+        "train": cugraph_pyg_from_heterodata(
+            train_data, wg_mem_type, return_edge_label=False
+        ),
+        "test": cugraph_pyg_from_heterodata(test_data, wg_mem_type),
+        "val": cugraph_pyg_from_heterodata(val_data, wg_mem_type),
+    }, meta
+
+
+def train(model, optimizer, loader):
+    rank = torch.distributed.get_rank()
+    model.train()
+
+    total_loss = total_examples = 0
+    for i, batch in enumerate(loader):
+        batch = batch.to(rank)
+        optimizer.zero_grad()
+
+        if i % 10 == 0 and rank == 0:
+            print(f"iter {i}")
+
+        pred = model(
+            batch.x_dict,
+            batch.edge_index_dict,
+            batch["user", "item"].edge_label_index,
+        )
+        loss = F.binary_cross_entropy_with_logits(
+            pred, batch["user", "item"].edge_label
+        )
+
+        loss.backward()
+        optimizer.step()
+        total_loss += float(loss)
+        total_examples += pred.numel()
+
+    return total_loss / total_examples
+
+
+@torch.no_grad()
+def test(model, loader):
+    rank = torch.distributed.get_rank()
+
+    model.eval()
+    preds, targets = [], []
+    for i, batch in enumerate(loader):
+        batch = batch.to(rank)
+
+        pred = (
+            model(
+                batch.x_dict,
+                batch.edge_index_dict,
+                batch["user", "item"].edge_label_index,
+            )
+            .sigmoid()
+            .view(-1)
+            .cpu()
+        )
+        target = batch["user", "item"].edge_label.long().cpu()
+
+        preds.append(pred)
+        targets.append(target)
+
+    pred = torch.cat(preds, dim=0).numpy()
+    target = torch.cat(targets, dim=0).numpy()
+
+    return roc_auc_score(target, pred)
+
+
+if __name__ == "__main__":
+    if "LOCAL_RANK" not in os.environ:
+        warnings.warn("This script should be run with 'torchrun'. 
Exiting.") + exit() + + parser = argparse.ArgumentParser() + parser.add_argument("--lr", type=float, default=0.001) + parser.add_argument("--epochs", type=int, default=21) + parser.add_argument("--batch_size", type=int, default=2048) + parser.add_argument("--dataset_root", type=str, default="datasets") + parser.add_argument("--skip_partition", action="store_true") + parser.add_argument("--wg_mem_type", type=str, default="distributed") + args = parser.parse_args() + + dataset_name = "taobao" + + torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=3600)) + world_size = torch.distributed.get_world_size() + global_rank = torch.distributed.get_rank() + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(local_rank) + + if global_rank == 0: + from rmm.allocators.torch import rmm_torch_allocator + + torch.cuda.change_current_allocator(rmm_torch_allocator) + + # Create the uid needed for cuGraph comms + if global_rank == 0: + from cugraph.gnn import ( + cugraph_comms_create_unique_id, + ) + + cugraph_id = [cugraph_comms_create_unique_id()] + else: + cugraph_id = [None] + torch.distributed.broadcast_object_list(cugraph_id, src=0, device=device) + cugraph_id = cugraph_id[0] + + init_pytorch_worker(global_rank, local_rank, world_size, cugraph_id) + + # Split the data + edge_path = os.path.join(args.dataset_root, dataset_name + "_eix_part") + feature_path = os.path.join(args.dataset_root, dataset_name + "_fea_part") + label_path = os.path.join(args.dataset_root, dataset_name + "_label_part") + meta_path = os.path.join(args.dataset_root, dataset_name + "_meta.json") + + if not args.skip_partition and global_rank == 0: + print("Partitioning data...") + + dataset = Taobao(args.dataset_root, pre_transform=pre_transform) + data = dataset[0] + + preprocess_and_partition( + data, + edge_path=edge_path, + meta_path=meta_path, + ) + + print("Data partitioning complete!") + + torch.distributed.barrier() + data_dict, meta = load_partitions(edge_path, meta_path, args.wg_mem_type) + torch.distributed.barrier() + + from cugraph_pyg.loader import LinkNeighborLoader + + def create_loader(data_l): + return LinkNeighborLoader( + data=data_l[0], + edge_label_index=data_l[1], + edge_label=data_l[2], + neg_sampling="binary" if data_l[2] is None else None, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_neighbors={ + ("user", "to", "item"): [8, 4], + ("item", "rev_to", "user"): [8, 4], + ("item", "to", "item"): [8, 4], + ("item", "rev_to", "item"): [8, 4], + }, + local_seeds_per_call=16384, + ) + + print("Creating train loader...") + train_loader = create_loader( + data_dict["train"], + ) + print(f"Created train loader on rank {global_rank}") + + torch.distributed.barrier() + + print("Creating validation loader...") + val_loader = create_loader( + data_dict["val"], + ) + print(f"Created validation loader on rank {global_rank}") + + torch.distributed.barrier() + + model = Model( + num_users=meta["num_nodes"]["user"], + num_items=meta["num_nodes"]["item"], + hidden_channels=64, + out_channels=64, + ).to(local_rank) + print(f"Created model on rank {global_rank}") + + # Initialize lazy modules + # FIXME DO NOT DO THIS!!!! 
Use explicitly sized (non-lazy) parameters instead
+    for batch in train_loader:
+        batch = batch.to(local_rank)
+        _ = model(
+            batch.x_dict,
+            batch.edge_index_dict,
+            batch["user", "item"].edge_label_index,
+        )
+        break
+    print(f"Initialized model on rank {global_rank}")
+
+    model = DistributedDataParallel(model, device_ids=[local_rank])
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+    best_val_auc = 0
+    for epoch in range(1, args.epochs + 1):
+        print("Train")
+        loss = train(model, optimizer, train_loader)
+
+        if global_rank == 0:
+            print("Val")
+            val_auc = test(model, val_loader)
+            best_val_auc = max(best_val_auc, val_auc)
+
+        if global_rank == 0:
+            print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}")
+
+    del train_loader
+    del val_loader
+    gc.collect()
+    print("Creating test loader...")
+    test_loader = create_loader(data_dict["test"])
+
+    if global_rank == 0:
+        print("Test")
+        test_auc = test(model, test_loader)
+        print(
+            f"Total {args.epochs:02d} epochs: Final Loss: {loss:.4f}, "
+            f"Best Val AUC: {best_val_auc:.4f}, "
+            f"Test AUC: {test_auc:.4f}"
+        )
+
+    wm_finalize()
+
+    from cugraph.gnn import cugraph_comms_shutdown
+
+    cugraph_comms_shutdown()
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py
index 2effdab9..2c45e6c5 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/link_neighbor_loader.py
@@ -15,6 +15,8 @@
 
 from typing import Union, Tuple, Optional, Callable, List, Dict
 
+import numpy as np
+
 import cugraph_pyg
 from cugraph_pyg.loader import LinkLoader
 from cugraph_pyg.sampler import BaseSampler
@@ -218,6 +220,17 @@ def __init__(
         if weight_attr is not None:
             graph_store._set_weight_attr((feature_store, weight_attr))
 
+        if isinstance(num_neighbors, dict):
+            sorted_keys, _, _ = graph_store._numeric_edge_types
+            fanout_length = len(next(iter(num_neighbors.values())))
+            na = np.zeros(fanout_length * len(sorted_keys), dtype="int32")
+            for i, key in enumerate(sorted_keys):
+                if key in num_neighbors:
+                    for hop in range(fanout_length):
+                        na[hop * len(sorted_keys) + i] = num_neighbors[key][hop]
+
+            num_neighbors = na
+
         sampler = BaseSampler(
             NeighborSampler(
                 graph_store._graph,
diff --git a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
index 961ac34a..478c90c0 100644
--- a/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
@@ -15,6 +15,8 @@
 
 from typing import Union, Tuple, Optional, Callable, List, Dict
 
+import numpy as np
+
 import cugraph_pyg
 from cugraph_pyg.loader import NodeLoader
 from cugraph_pyg.sampler import BaseSampler
@@ -211,6 +213,17 @@ def __init__(
         if weight_attr is not None:
             graph_store._set_weight_attr((feature_store, weight_attr))
 
+        if isinstance(num_neighbors, dict):
+            sorted_keys, _, _ = graph_store._numeric_edge_types
+            fanout_length = len(next(iter(num_neighbors.values())))
+            na = np.zeros(fanout_length * len(sorted_keys), dtype="int32")
+            for i, key in enumerate(sorted_keys):
+                if key in num_neighbors:
+                    for hop in range(fanout_length):
+                        na[hop * len(sorted_keys) + i] = num_neighbors[key][hop]
+
+            num_neighbors = na
+
         sampler = BaseSampler(
             NeighborSampler(
                 graph_store._graph,
diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
index 629d56af..ea67cd56 100644
--- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py
+++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler.py @@ -365,7 +365,11 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): ux = col[pyg_can_etype][: num_sampled_edges[pyg_can_etype][0]] if ux.numel() > 0: - input_type = pyg_can_etype[2] # can only ever be 1 + # can only ever be 1 + if "edge_inverse" in raw_sample_data: + input_type = pyg_can_etype + else: + input_type = pyg_can_etype[2] num_sampled_nodes[self.__dst_types[etype]][0] = torch.max( num_sampled_nodes[self.__dst_types[etype]][0], @@ -383,13 +387,37 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): } num_sampled_edges = {k: v.cpu() for k, v in num_sampled_edges.items()} + input_index = raw_sample_data["input_index"][ + raw_sample_data["input_offsets"][index] : raw_sample_data["input_offsets"][ + index + 1 + ] + ] + + num_seeds = input_index.numel() + input_index = input_index[input_index >= 0] + + num_pos = input_index.numel() + num_neg = num_seeds - num_pos + if num_neg > 0: + edge_label = torch.concat( + [ + torch.full((num_pos,), 1.0), + torch.full((num_neg,), 0.0), + ] + ) + else: + if "input_label" in raw_sample_data: + edge_label = raw_sample_data["input_label"][ + raw_sample_data["input_offsets"][index] : raw_sample_data[ + "input_offsets" + ][index + 1] + ] + else: + edge_label = None + input_index = ( input_type, - raw_sample_data["input_index"][ - raw_sample_data["input_offsets"][index] : raw_sample_data[ - "input_offsets" - ][index + 1] - ], + input_index, ) edge_inverse = ( @@ -413,7 +441,7 @@ def __decode_coo(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): metadata = ( input_index, edge_inverse.view(2, -1), - None, + edge_label, None, # TODO this will eventually include time ) @@ -524,7 +552,14 @@ def __decode_csc(self, raw_sample_data: Dict[str, "torch.Tensor"], index: int): ] ) else: - edge_label = None + if "input_label" in raw_sample_data: + edge_label = raw_sample_data["input_label"][ + raw_sample_data["input_offsets"][index] : raw_sample_data[ + "input_offsets" + ][index + 1] + ] + else: + edge_label = None edge_inverse = ( ( @@ -748,6 +783,7 @@ def sample_from_edges( reader = self.__sampler.sample_from_edges( torch.stack([src, dst]), # reverse of usual convention input_id=input_id, + input_label=index.label, batch_size=self.__batch_size + neg_batch_size, **kwargs, ) @@ -765,5 +801,6 @@ def sample_from_edges( src_types=src_types, dst_types=dst_types, edge_types=edge_types, + vertex_types=sorted(self.__graph_store._num_vertices().keys()), vertex_offsets=self.__graph_store._vertex_offset_array, ) diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py index b3d56ef9..a0021812 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/sampler_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -459,18 +459,11 @@ def neg_sample( int(ceil(seed_src.numel() / batch_size)), ) - if graph_store.is_multi_gpu: - num_neg_global = torch.tensor([num_neg], device="cuda") - torch.distributed.all_reduce(num_neg_global, op=torch.distributed.ReduceOp.SUM) - num_neg = int(num_neg_global) - else: - num_neg_global = num_neg - if node_time is None: result_dict = pylibcugraph.negative_sampling( graph_store._resource_handle, graph_store._graph, - num_neg_global, + num_neg, vertices=None if unweighted else cupy.arange(src_weight.numel(), dtype="int64"), diff --git a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py index 831ee0d6..9eb77f78 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/loader/test_neighbor_loader_mg.py @@ -94,9 +94,9 @@ def run_test_neighbor_loader_mg(rank, uid, world_size, specify_size): assert (feature_store["person", "feat", None][batch.n_id] == batch.feat).all() cugraph_comms_shutdown() + torch.distributed.destroy_process_group() -@pytest.mark.skip(reason="deleteme") @pytest.mark.parametrize("specify_size", [True, False]) @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.mg @@ -164,6 +164,7 @@ def run_test_neighbor_loader_biased_mg(rank, uid, world_size): ).all() cugraph_comms_shutdown() + torch.distributed.destroy_process_group() @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @@ -223,6 +224,7 @@ def run_test_link_neighbor_loader_basic_mg( assert (elx[i] == batch.n_id[batch.edge_label_index.cpu()]).all() cugraph_comms_shutdown() + torch.distributed.destroy_process_group() @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @@ -284,6 +286,7 @@ def run_test_link_neighbor_loader_uneven_mg( assert (elx[:, [i]] == batch.n_id[batch.edge_label_index.cpu()]).all() cugraph_comms_shutdown() + torch.distributed.destroy_process_group() @pytest.mark.skip(reason="broken") @@ -346,6 +349,9 @@ def run_test_link_neighbor_loader_negative_sampling_basic_mg( for i, batch in enumerate(loader): assert batch.edge_label[0] == 1.0 + cugraph_comms_shutdown() + torch.distributed.destroy_process_group() + @pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available") @pytest.mark.mg
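Note on the dict-form num_neighbors support added to both loaders above:
the per-edge-type fanout dict is flattened into the hop-major int32 array
that the underlying cugraph NeighborSampler expects (all edge types for
hop 0, then all edge types for hop 1, and so on; edge types missing from
the dict get a fanout of 0, as implied by the zero-initialized array).
A minimal standalone sketch of that layout, using a hypothetical
three-type graph that is not part of this diff:

    import numpy as np

    # Canonical edge types in numeric order, as
    # GraphStore._numeric_edge_types would return them
    # (hypothetical example graph, for illustration only).
    sorted_keys = [
        ("item", "to", "item"),
        ("item", "rev_to", "user"),
        ("user", "to", "item"),
    ]

    # Two-hop fanouts; ("item", "to", "item") is omitted on purpose
    # to show the zero-fill behavior.
    num_neighbors = {
        ("item", "rev_to", "user"): [8, 4],
        ("user", "to", "item"): [8, 4],
    }

    # Same flattening logic as the loader changes above.
    fanout_length = len(next(iter(num_neighbors.values())))
    na = np.zeros(fanout_length * len(sorted_keys), dtype="int32")
    for i, key in enumerate(sorted_keys):
        if key in num_neighbors:
            for hop in range(fanout_length):
                na[hop * len(sorted_keys) + i] = num_neighbors[key][hop]

    print(na)  # [0 8 8 0 4 4] -> hop-0 fanouts first, then hop-1 fanouts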