-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Enable getting global indices from SamplerOutput #10200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Enable getting global indices from SamplerOutput #10200
Conversation
cc3127a
to
dce2fe6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #10200 +/- ##
==========================================
- Coverage 86.11% 85.62% -0.50%
==========================================
Files 496 498 +2
Lines 33655 34188 +533
==========================================
+ Hits 28981 29272 +291
- Misses 4674 4916 +242 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but ill let @akihironitta or @rusty1s review and merge this one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose?
@@ -160,3 +160,48 @@ def remap_keys( | |||
k if k in exclude else mapping.get(k, k): v | |||
for k, v in inputs.items() | |||
} | |||
|
|||
|
|||
def local_to_global_node_idx(node_values: Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need a helper function to do node_values[local_indices]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think its helpful to refactor as a separate function, because otherwise it's not immediately obvious to readers of the code that global indexing is functionally equivalent to locally indexing into the node field.
return torch.index_select(node_values, dim=0, index=local_indices) | ||
|
||
|
||
def global_to_local_node_idx(node_values: Tensor, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When would that be useful?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the downstream PR, I create a sampler that needs to merge together more than one SamplerOutput. This helper is useful for doing operations of that sort, while also being a exact reverse mapping of local_to_global_node_idx
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The purpose of these changes is to make it easier for users of |
ab5ae61
to
3b8b10e
Compare
2df71ab
to
e6cbdf2
Compare
The current implementation of SamplerOutput returns local indices of node ids in the
node
list for therow
,col
,batch
,orig_row
, andorig_col
fields. This leads to some confusion for using the edges/seed nodes retrieved from the sampler. This PR adds some helper methods that will derive the global indices for these fields given the information already stored within the SamplerOutput.Unittests have been also written for both homogeneous and heterogeneous outputs.