diff --git a/graphlearn_torch/python/distributed/dist_neighbor_loader.py b/graphlearn_torch/python/distributed/dist_neighbor_loader.py index 9eae94fd..9d93d8e8 100644 --- a/graphlearn_torch/python/distributed/dist_neighbor_loader.py +++ b/graphlearn_torch/python/distributed/dist_neighbor_loader.py @@ -17,8 +17,14 @@ import torch -from ..sampler import NodeSamplerInput, SamplingType, SamplingConfig, \ - RemoteNodeSplitSamplerInput, RemoteNodePathSamplerInput +from ..sampler import ( + AllNodeSamplerInputs, + NodeSamplerInput, + SamplingType, + SamplingConfig, + RemoteNodeSplitSamplerInput, + RemoteNodePathSamplerInput, +) from ..typing import InputNodes, NumNeighbors, Split from .dist_dataset import DistDataset @@ -27,7 +33,7 @@ class DistNeighborLoader(DistLoader): - r""" A distributed loader that preform sampling from nodes. + r"""A distributed loader that performs sampling from nodes. Args: data (DistDataset, optional): The ``DistDataset`` object of a partition of @@ -37,10 +43,14 @@ class DistNeighborLoader(DistLoader): The number of neighbors to sample for each node in each iteration. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The node seeds for + input_nodes (torch.Tensor or Tuple[str, torch.Tensor], optional): The node seeds for which neighbors are sampled to create mini-batches. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node seeds. + Must *not* be provided when ``node_sampler_input`` is provided. + node_sampler_input (AllNodeSamplerInputs, optional): The input for node sampler. + May be used to override the functionality when sampling. + Must *not* be provided when ``input_nodes`` is provided. batch_size (int): How many samples per batch to load (default: ``1``). shuffle (bool): Set to ``True`` to have the data reshuffled at every epoch (default: ``False``). 
@@ -74,7 +84,8 @@ class DistNeighborLoader(DistLoader): def __init__(self, data: Optional[DistDataset], num_neighbors: NumNeighbors, - input_nodes: InputNodes, + input_nodes: Optional[InputNodes] = None, + node_sampler_input: Optional[AllNodeSamplerInputs] = None, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, @@ -85,13 +96,19 @@ def __init__(self, to_device: Optional[torch.device] = None, random_seed: Optional[int] = None, worker_options: Optional[AllDistSamplingWorkerOptions] = None): - + if (input_nodes is None) == (node_sampler_input is None): + raise ValueError( + f"Exactly one of input_nodes or node_sampler_input must be provided. Received input_nodes: {type(input_nodes)}, node_sampler_input: {type(node_sampler_input)}" + ) if isinstance(input_nodes, tuple): input_type, input_seeds = input_nodes - else: + elif input_nodes is not None: input_type, input_seeds = None, input_nodes - if isinstance(worker_options, RemoteDistSamplingWorkerOptions): + if node_sampler_input is not None: + _check_input_type(node_sampler_input, worker_options) + input_data = node_sampler_input + elif isinstance(worker_options, RemoteDistSamplingWorkerOptions): if isinstance(input_seeds, Split): input_data = RemoteNodeSplitSamplerInput(split=input_seeds, input_type=input_type) if isinstance(worker_options.server_rank, List): @@ -113,6 +130,29 @@ def __init__(self, with_weight=with_weight, edge_dir=edge_dir, seed=random_seed ) - super().__init__( - data, input_data, sampling_config, to_device, worker_options - ) + super().__init__(data, input_data, sampling_config, to_device, worker_options) + + +def _check_input_type(node_sampler_input: AllNodeSamplerInputs, worker_options: AllDistSamplingWorkerOptions): + if isinstance(worker_options, RemoteDistSamplingWorkerOptions): + if isinstance(worker_options.server_rank, List): + if not isinstance(node_sampler_input, List): + raise ValueError( + f"When worker options is {type(worker_options)} with 
server_rank as List, node_sampler_input must be a list, but got {type(node_sampler_input)}" + ) + if not all( + isinstance(sampler_input, RemoteNodeSplitSamplerInput) + for sampler_input in node_sampler_input + ): + raise ValueError( + f"When worker options is RemoteDistSamplingWorkerOptions with server_rank as List, node_sampler_input must be a list and all elements of the list must be of type RemoteNodeSplitSamplerInput, but got {[type(nsi) for nsi in node_sampler_input]}" + ) + elif isinstance(node_sampler_input, List): + if not all(isinstance(sampler_input, RemoteNodePathSamplerInput) for sampler_input in node_sampler_input): + raise ValueError( + f"When worker options is {type(worker_options)} and node_sampler_input is a list, all elements of the list must be of type RemoteNodePathSamplerInput, but got {[type(nsi) for nsi in node_sampler_input]}" + ) + elif not isinstance(node_sampler_input, RemoteNodePathSamplerInput): + raise ValueError( + f"When worker options is {type(worker_options)} and node_sampler_input is not a list, node_sampler_input must be of type RemoteNodePathSamplerInput, but got {type(node_sampler_input)}" + ) \ No newline at end of file diff --git a/graphlearn_torch/python/sampler/base.py b/graphlearn_torch/python/sampler/base.py index 5b831b3b..e416085e 100644 --- a/graphlearn_torch/python/sampler/base.py +++ b/graphlearn_torch/python/sampler/base.py @@ -24,6 +24,8 @@ from ..typing import NodeType, EdgeType, NumNeighbors, Split from ..utils import CastMixin +AllNodeSamplerInputsBase = Union['NodeSamplerInput', 'RemoteNodeSplitSamplerInput', 'RemoteNodePathSamplerInput'] +AllNodeSamplerInputs = Union[AllNodeSamplerInputsBase, List['RemoteNodePathSamplerInput'], List['RemoteNodeSplitSamplerInput']] class EdgeIndex(NamedTuple): r""" PyG's :class:`~torch_geometric.loader.EdgeIndex` used in old data loader