Skip to content

Enable getting global indices from SamplerOutput #10200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from

Conversation

zaristei
Copy link
Contributor

The current implementation of SamplerOutput returns local indices of node ids in the node list for the row,col, batch, orig_row, and orig_col fields. This leads to some confusion for using the edges/seed nodes retrieved from the sampler. This PR adds some helper methods that will derive the global indices for these fields given the information already stored within the SamplerOutput.
Unittests have been also written for both homogeneous and heterogeneous outputs.

Copy link
Contributor

@Kh4L Kh4L left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm!

Copy link

codecov bot commented Apr 16, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 85.62%. Comparing base (c211214) to head (236650c).
Report is 57 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #10200      +/-   ##
==========================================
- Coverage   86.11%   85.62%   -0.50%     
==========================================
  Files         496      498       +2     
  Lines       33655    34188     +533     
==========================================
+ Hits        28981    29272     +291     
- Misses       4674     4916     +242     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Contributor

@puririshi98 puririshi98 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but ill let @akihironitta or @rusty1s review and merge this one

Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose?

@@ -160,3 +160,48 @@ def remap_keys(
k if k in exclude else mapping.get(k, k): v
for k, v in inputs.items()
}


def local_to_global_node_idx(node_values: Tensor,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need a helper function to do node_values[local_indices]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its helpful to refactor as a separate function, because otherwise it's not immediately obvious to readers of the code that global indexing is functionally equivalent to locally indexing into the node field.

return torch.index_select(node_values, dim=0, index=local_indices)


def global_to_local_node_idx(node_values: Tensor,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would that be useful?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the downstream PR, I create a sampler that needs to merge together more than one SamplerOutput. This helper is useful for doing operations of that sort, while also being a exact reverse mapping of local_to_global_node_idx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zaristei
Copy link
Contributor Author

The purpose of these changes is to make it easier for users of SamplerOutput to access the global row/col ids, as well as global seed node ids from the sampler. Although the docstring indicates that the row and col fields are indices of nodes, it's not immediately obvious from that documentation that one can get the global indices by indexing into the nodes list, or that the same can be done for to get seed nodes from batch, for instance.

@zaristei zaristei force-pushed the zaristei/sample_output_global_index branch from ab5ae61 to 3b8b10e Compare April 22, 2025 00:16
@zaristei zaristei force-pushed the zaristei/sample_output_global_index branch from 2df71ab to e6cbdf2 Compare April 22, 2025 19:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants