Skip to content

Allow custom NodeSamplerInput to be piped into DistNeighborLoader #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 51 additions & 11 deletions graphlearn_torch/python/distributed/dist_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@

import torch

from ..sampler import NodeSamplerInput, SamplingType, SamplingConfig, \
RemoteNodeSplitSamplerInput, RemoteNodePathSamplerInput
from ..sampler import (
AllNodeSamplerInputs,
NodeSamplerInput,
SamplingType,
SamplingConfig,
RemoteNodeSplitSamplerInput,
RemoteNodePathSamplerInput,
)
from ..typing import InputNodes, NumNeighbors, Split

from .dist_dataset import DistDataset
Expand All @@ -27,7 +33,7 @@


class DistNeighborLoader(DistLoader):
r""" A distributed loader that preform sampling from nodes.
r"""A distributed loader that performs sampling from nodes.

Args:
data (DistDataset, optional): The ``DistDataset`` object of a partition of
Expand All @@ -37,10 +43,14 @@ class DistNeighborLoader(DistLoader):
The number of neighbors to sample for each node in each iteration.
In heterogeneous graphs, may also take in a dictionary denoting
the amount of neighbors to sample for each individual edge type.
input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The node seeds for
input_nodes (torch.Tensor or Tuple[str, torch.Tensor], optional): The node seeds for
which neighbors are sampled to create mini-batches. In heterogeneous
graphs, needs to be passed as a tuple that holds the node type and
node seeds.
Must *not* be provided when ``node_sampler_input`` is provided.
node_sampler_input (AllNodeSamplerInputs, optional): The input for the node sampler.
May be used to override the functionality when sampling.
Must *not* be provided when ``input_nodes`` is provided.
batch_size (int): How many samples per batch to load (default: ``1``).
shuffle (bool): Set to ``True`` to have the data reshuffled at every
epoch (default: ``False``).
Expand Down Expand Up @@ -74,7 +84,8 @@ class DistNeighborLoader(DistLoader):
def __init__(self,
data: Optional[DistDataset],
num_neighbors: NumNeighbors,
input_nodes: InputNodes,
input_nodes: Optional[InputNodes] = None,
node_sampler_input: Optional[AllNodeSamplerInputs] = None,
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
Expand All @@ -85,13 +96,19 @@ def __init__(self,
to_device: Optional[torch.device] = None,
random_seed: Optional[int] = None,
worker_options: Optional[AllDistSamplingWorkerOptions] = None):

if (input_nodes is None) == (node_sampler_input is None):
raise ValueError(
f"Exactly one of input_nodes or node_sampler_input must be provided. Received input_nodes: {type(input_nodes)}, node_sampler_input: {type(node_sampler_input)}"
)
if isinstance(input_nodes, tuple):
input_type, input_seeds = input_nodes
else:
elif input_nodes is not None:
input_type, input_seeds = None, input_nodes

if isinstance(worker_options, RemoteDistSamplingWorkerOptions):
if node_sampler_input is not None:
_check_input_type(node_sampler_input, worker_options)
input_data = node_sampler_input
elif isinstance(worker_options, RemoteDistSamplingWorkerOptions):
if isinstance(input_seeds, Split):
input_data = RemoteNodeSplitSamplerInput(split=input_seeds, input_type=input_type)
if isinstance(worker_options.server_rank, List):
Expand All @@ -113,6 +130,29 @@ def __init__(self,
with_weight=with_weight, edge_dir=edge_dir, seed=random_seed
)

super().__init__(
data, input_data, sampling_config, to_device, worker_options
)
super().__init__(data, input_data, sampling_config, to_device, worker_options)


def _check_input_type(node_sampler_input: AllNodeSamplerInputs, worker_options: AllDistSamplingWorkerOptions) -> None:
    r"""Validate a caller-supplied ``node_sampler_input`` against ``worker_options``.

    Only ``RemoteDistSamplingWorkerOptions`` impose constraints; for any other
    worker options the input is accepted without inspection.

    For remote worker options the rules are:

    * ``server_rank`` is a list: ``node_sampler_input`` must be a list whose
      elements are all ``RemoteNodeSplitSamplerInput``.
    * ``server_rank`` is not a list and ``node_sampler_input`` is a list: all
      elements must be ``RemoteNodePathSamplerInput``.
    * ``server_rank`` is not a list and ``node_sampler_input`` is a single
      object: it must be a ``RemoteNodePathSamplerInput``.

    Args:
        node_sampler_input: The sampler input (or list of inputs) to validate.
        worker_options: The worker options the loader is constructed with.

    Raises:
        ValueError: If ``node_sampler_input`` violates one of the rules above.
    """
    if not isinstance(worker_options, RemoteDistSamplingWorkerOptions):
        return

    # NOTE(review): these rules are stricter than the inputs the loader appears
    # to build itself from ``input_nodes`` (e.g. a single
    # RemoteNodeSplitSamplerInput for a non-list server_rank) -- confirm that
    # rejecting those combinations here is intended.
    if isinstance(worker_options.server_rank, list):
        if not isinstance(node_sampler_input, list):
            # Report the type of the actual options object; formatting
            # type(<class>) would only ever print "<class 'type'>".
            raise ValueError(
                f"When worker options is {type(worker_options)} with server_rank as List, node_sampler_input must be a list, but got {type(node_sampler_input)}"
            )
        if not all(
            isinstance(sampler_input, RemoteNodeSplitSamplerInput)
            for sampler_input in node_sampler_input
        ):
            # The message names the type that is actually enforced above.
            raise ValueError(
                f"When worker options is RemoteDistSamplingWorkerOptions with server_rank as List, node_sampler_input must be a list and all elements of the list must be of type RemoteNodeSplitSamplerInput, but got {[type(nsi) for nsi in node_sampler_input]}"
            )
    elif isinstance(node_sampler_input, list):
        if not all(
            isinstance(sampler_input, RemoteNodePathSamplerInput)
            for sampler_input in node_sampler_input
        ):
            raise ValueError(
                f"When worker options is {type(worker_options)} and node_sampler_input is a list, all elements of the list must be of type RemoteNodePathSamplerInput, but got {[type(nsi) for nsi in node_sampler_input]}"
            )
    elif not isinstance(node_sampler_input, RemoteNodePathSamplerInput):
        raise ValueError(
            f"When worker options is {type(worker_options)} and node_sampler_input is not a list, node_sampler_input must be of type RemoteNodePathSamplerInput, but got {type(node_sampler_input)}"
        )
2 changes: 2 additions & 0 deletions graphlearn_torch/python/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from ..typing import NodeType, EdgeType, NumNeighbors, Split
from ..utils import CastMixin

# Any single node sampler input object (forward references, since the classes
# are named here by string rather than by object).
AllNodeSamplerInputsBase = Union[
    'NodeSamplerInput',
    'RemoteNodeSplitSamplerInput',
    'RemoteNodePathSamplerInput',
]
# Either a single sampler input, or a homogeneous list of remote sampler inputs.
AllNodeSamplerInputs = Union[
    AllNodeSamplerInputsBase,
    List['RemoteNodePathSamplerInput'],
    List['RemoteNodeSplitSamplerInput'],
]

class EdgeIndex(NamedTuple):
r""" PyG's :class:`~torch_geometric.loader.EdgeIndex` used in old data loader
Expand Down
Loading