Mesh cnn #2293
base: master
@rusty1s let us know if there's something missing / else that should be done here!
Thank you very much for this PR. Since this is a big one, is there any chance we can split this into smaller PRs? I'm thinking of separate PRs for datasets, example scripts, operators, and transforms. Let me know if this is possible. It would help me tremendously for merging :)
Hi @rusty1s,
Sounds right?
Sounds great :)
Thanks. This looks good to me. I will have a closer look in the upcoming week.
Hi @Omri-L and @ranahanocka, and sorry for my delayed review.
Thanks once again for this PR. I am really happy to see this effort of merging MeshCNN into PyTorch Geometric.
This is still a pretty complex PR for me to merge, in particular because I feel the code is tightly coupled to the specific MeshCNN use-case. I would love to see this fit into PyG a little bit better.
For example, I think we can directly make use of torch_geometric.data.Data to hold mesh data, and convert a given mesh (e.g. given by FAUST) to the required MeshCNN representation via a single transform. In particular, I'm thinking of the following interface:
dataset = FAUST(..., pre_transform=MeshCNNPrepare())
which converts a mesh data object, i.e., Data(pos=[num_nodes, 3], face=[3, num_faces]) into a variant that can be processed by MeshCNN, i.e.:
Data(edge_index=[2, num_edges], edge_attr=[num_edges, 5], neighbor_e_id=[num_edges, 4], ...)
where neighbor_e_id denotes the indices of adjacent edges. Please let me know if this is possible.
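A minimal sketch of how neighbor_e_id could be derived from the face list, assuming a manifold triangle mesh and using plain Python lists in place of tensors. The function name build_mesh_edges and the -1 padding for boundary edges are illustrative assumptions, not part of this PR or of PyG.

```python
def build_mesh_edges(faces):
    """Return (edges, neighbor_e_id) for a triangle mesh.

    faces: list of (v0, v1, v2) vertex-index triples.
    edges: undirected edges as sorted (u, v) pairs, in order of first use.
    neighbor_e_id: one row of 4 edge indices per edge; -1 marks a missing
    neighbor on a boundary edge (an assumption of this sketch).
    """
    edge_id = {}     # (u, v) -> edge index
    edges = []
    face_edges = []  # per face: its three edge ids in cyclic order

    for face in faces:
        ids = []
        for k in range(3):
            u, v = face[k], face[(k + 1) % 3]
            key = (min(u, v), max(u, v))
            if key not in edge_id:
                edge_id[key] = len(edges)
                edges.append(key)
            ids.append(edge_id[key])
        face_edges.append(ids)

    neighbor_e_id = [[] for _ in edges]
    for ids in face_edges:
        for k in range(3):
            # The other two edges of this face, in cyclic order.
            neighbor_e_id[ids[k]] += [ids[(k + 1) % 3], ids[(k + 2) % 3]]

    # Boundary edges touch only one face, so pad them up to four entries.
    neighbor_e_id = [row + [-1] * (4 - len(row)) for row in neighbor_e_id]
    return edges, neighbor_e_id


# Two triangles sharing the edge (1, 2).
edges, neighbor_e_id = build_mesh_edges([(0, 1, 2), (1, 3, 2)])
```

Only the shared edge (1, 2) ends up with four real neighbors here; the four boundary edges each carry two neighbors plus padding.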
import torch


def get_meshes_edge_index(meshes):
Do we really need this function or can we implement MeshCNN without padding?
Yes, but this padding is not actually used in this PR (it will be in the next PR, after I resolve the issues here). We need to pad the edge indices because in some datasets the number of input edges differs between mesh examples.
@Omri-L, do you think we can drop files that are not used in this PR? That should make it easier to review.
Hi @wsad1, this updated PR contains all the files that are needed. My comment from June 4th was about a smaller PR; this one includes everything and needs all of the files.
Thanks!
torch_geometric/transforms/mesh.py
Outdated
(maximum 2 faces), normalized by the total
area of all the faces. Used for
segmentation accuracy.
edge_index (np.array(2, Num_edges x 5,
How about we hold the mesh graph data in edge_index (just as in standard PyG), and have an additional index matrix neighbor_e_id of shape [num_edges, 4] that holds the indices of direct neighbors for every edge?
I am not sure I understand. I followed the introduction example:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
Can you please elaborate?
Maybe I misunderstood the code, but my impression is that you have a regular edge_index tensor of shape [2, num_edges] denoting mesh connections. For MeshCNN, we need to collect the direct neighboring edges, e.g., 4 edges in case of two neighboring faces. I think one can just create a tensor of shape [num_edges, 4] that holds the indices of edges denoted in edge_index that are attached to the specific edge. WDYT?
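To illustrate how such a [num_edges, 4] index could be consumed, here is a sketch that gathers the four neighbor features of every edge and builds the order-invariant combinations described in the MeshCNN paper: (e, |a - c|, a + c, |b - d|, b + d). It assumes each row is laid out as [a, b, c, d], with a, b taken from one incident face and c, d from the other. Scalar features and plain lists keep it dependency-free; the function name and the tetrahedron sample data are made up for illustration.

```python
def symmetric_edge_features(x, neighbor_e_id):
    """x: one scalar feature per edge; neighbor_e_id: rows [a, b, c, d]."""
    out = []
    for e, (a, b, c, d) in enumerate(neighbor_e_id):
        fa, fb, fc, fd = x[a], x[b], x[c], x[d]
        # Differences and sums do not change when the two faces are swapped.
        out.append([x[e], abs(fa - fc), fa + fc, abs(fb - fd), fb + fd])
    return out


# A tetrahedron has 6 edges and every edge has exactly 4 neighbors.
x = [1, 2, 3, 4, 5, 6]
neighbor_e_id = [
    [3, 1, 4, 2],
    [0, 3, 5, 2],
    [0, 4, 1, 5],
    [0, 1, 5, 4],
    [0, 2, 3, 5],
    [1, 2, 3, 4],
]
features = symmetric_edge_features(x, neighbor_e_id)

# Swapping the two faces of edge 0 leaves its features unchanged.
swapped = symmetric_edge_features(x, [[4, 2, 3, 1]] + neighbor_e_id[1:])
```

With tensors, the gather step would be a single indexing operation rather than a Python loop.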
Semantic changes
Hi @rusty1s,
Thanks for digging into this :) Appreciate it a lot. The
… Fixes in mesh pooling and mesh classes to support new data structure with tensors. 4. Update FAUST example with MeshCNN data structure.
Hi @rusty1s,
1. torch_geometric.data
   MeshCNNHandler - Holds mesh data and implements other MeshCNN methods on this data structure.
2. torch_geometric.transforms
3. torch_geometric.datasets
4. torch_geometric.nn.conv
5. torch_geometric.nn.pool
6. torch_geometric.nn.unpool
7. torch_geometric.utils
8. torch_geometric.nn.models
9. torch_geometric.nn.mesh
10. examples
Thank you! I will have a closer look in the upcoming week and will let you know.
Codecov Report
@@ Coverage Diff @@
## master #2293 +/- ##
==========================================
- Coverage 73.56% 70.05% -3.51%
==========================================
Files 287 300 +13
Lines 13719 15252 +1533
==========================================
+ Hits 10092 10685 +593
- Misses 3627 4567 +940
Thanks @Omri-L. Really appreciate the changes to allow for an interface as described in meshcnn_faust.py. I went over datasets and loaders, and think we can polish it a bit further. In general, it would be great to follow the interfaces of other datasets, so that they can be re-used for other models as well besides MeshCNN, i.e., I would prefer to move all MeshCNNPrepare calls outside the datasets.
Let me know if you need any help in doing so.
root (str): Root folder for dataset.
dataset_url (str): dataset URL link (see supported links in
    MeshCnnBaseDataset)
n_input_edges (int): Number of input edges of mesh. It should be the
Not really sure why we need this? Can you clarify?
Which one?
n_input_edges is there since MeshCNN must get the same number of edges in a batch.
Therefore, n_input_edges is the maximal number of edges of a mesh object in the dataset. Any other mesh object with a smaller number of input edges will be padded.
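The padding described here could look roughly like the sketch below. It assumes per-edge features stored as a list of equal-length rows and zero padding; the function name pad_edge_features and the returned validity mask are assumptions for illustration, not the PR's implementation.

```python
def pad_edge_features(edge_features, n_input_edges, pad_value=0.0):
    """Pad a [num_edges, channels] list of rows up to n_input_edges rows.

    Also returns a mask that is True for real edges and False for padding,
    so later layers can tell the two apart (an addition of this sketch).
    """
    num_edges = len(edge_features)
    if num_edges > n_input_edges:
        raise ValueError("mesh has more edges than n_input_edges")

    channels = len(edge_features[0]) if edge_features else 0
    missing = n_input_edges - num_edges
    padded = edge_features + [[pad_value] * channels for _ in range(missing)]
    mask = [True] * num_edges + [False] * missing
    return padded, mask


padded, mask = pad_edge_features([[1.0, 2.0], [3.0, 4.0]], n_input_edges=4)
```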
Why does MeshCNN need a fixed number of edges? This seems to be more of a design decision than an actual limitation of MeshCNN. Am I correct?
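For reference, PyG-style mini-batching avoids padding by concatenating all examples and shifting indices by a running offset. A minimal sketch with plain lists, assuming each mesh is a dict with per-edge features x and a neighbor_e_id table (both key names and the -1 boundary marker are illustrative assumptions, not this PR's data structure):

```python
def batch_meshes(meshes):
    """Concatenate meshes with different edge counts into one batch."""
    x, neighbor_e_id, batch = [], [], []
    offset = 0
    for mesh_idx, mesh in enumerate(meshes):
        num_edges = len(mesh["x"])
        x.extend(mesh["x"])
        for row in mesh["neighbor_e_id"]:
            # Shift valid edge indices; keep -1 as the "no neighbor" marker.
            neighbor_e_id.append([j + offset if j >= 0 else -1 for j in row])
        # batch[i] records which mesh edge i belongs to.
        batch.extend([mesh_idx] * num_edges)
        offset += num_edges
    return x, neighbor_e_id, batch


mesh_a = {"x": [10, 11],
          "neighbor_e_id": [[1, -1, -1, -1], [0, -1, -1, -1]]}
mesh_b = {"x": [20, 21, 22],
          "neighbor_e_id": [[1, 2, -1, -1], [0, 2, -1, -1], [0, 1, -1, -1]]}
x, neighbor_e_id, batch = batch_meshes([mesh_a, mesh_b])
```

The batch vector plays the role of the padding mask: pooling or readout per mesh can group edges by it instead of relying on a fixed size.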
phase (str, optional): Train or Test phase - if 'train' it will return
    the train dataset, if 'test' it will return the
    test dataset. Default is 'train'.
num_aug (int, optional): Number of augmentations to apply on mesh.
I personally feel like these should all be handled by transforms, similar to other datasets in PyG. We should aim to make the interface consistent (transform, pre_transform).
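As a rough illustration of the transform / pre_transform split used by PyG datasets: pre_transform is applied once when the data is processed, while transform is applied on every access, which is where random augmentations would live. ToyMeshDataset, prepare, and augment below are hypothetical stand-ins, not PyG or PR code.

```python
class ToyMeshDataset:
    """Tiny stand-in for a dataset with the PyG-style transform hooks."""

    def __init__(self, raw_meshes, transform=None, pre_transform=None):
        self.transform = transform
        # pre_transform runs once per mesh, at processing time.
        self._data = [pre_transform(m) if pre_transform else m
                      for m in raw_meshes]

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        mesh = self._data[idx]
        # transform runs on every access and must not mutate stored data.
        return self.transform(mesh) if self.transform else mesh


calls = {"pre": 0, "aug": 0}


def prepare(mesh):
    # Placeholder for an expensive, deterministic conversion step.
    calls["pre"] += 1
    return {**mesh, "prepared": True}


def augment(mesh):
    # Placeholder for an augmentation; here it just scales vertex positions.
    calls["aug"] += 1
    return {**mesh, "pos": [[2 * c for c in p] for p in mesh["pos"]]}


dataset = ToyMeshDataset([{"pos": [[0.0, 1.0, 2.0]]}],
                         transform=augment, pre_transform=prepare)
first = dataset[0]
second = dataset[0]
```

With this split, options such as num_aug would be arguments of the transform rather than of the dataset.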
prefix, np.random.randint(0, self.num_aug)))

if self.num_aug > 1:
    mesh_prepare = MeshCNNPrepare(aug_slide_verts=self.slide_verts,
Specific MeshCNNPrepare transforms should be moved outside the datasets.
return mesh_data


class MeshCnnSegmentationDataset(MeshCnnBaseDataset):
See comments in MeshCNNClassificationDataset.
from torch_geometric.data.meshcnnhandler import MeshCNNHandler


class MeshDataLoader:
What prevents us from directly using the torch_geometric.data.DataLoader (as you already do in meshcnn_faust.py)?
I am using a different __iter__ function. Do you have a suggestion for how to avoid creating this class?
As far as I see, the __iter__ function already re-implements a lot of the functionality from the PyG DataLoader, and most of its complexity stems from converting MeshCNNData via MeshCNNDataHandler. Can you give some insights on the separation of MeshCNNData and MeshCNNDataHandler, and why we need it in the first place? What happens if we convert meshes via MeshCNNDataHandler before data loading?
MeshCnnClassificationDataset, MeshCnnSegmentationDataset)


class MeshShrech16Dataset:
We can probably clean this up to avoid a lot of duplicated code. What about only providing MeshCNNClassificationDataset and MeshCNNSegmentationDataset in which you can decide on which dataset to use, e.g.:
dataset = MeshCNNClassificationDataset(root, name='shrech16',
train=True, transform=MeshCNNPrepare(...))
Yes, you are right, but I meant to create base classes such as MeshCNNClassificationDataset and specific classes for specific datasets (such as shrec16, cubes, ...). Do you think I should remove them? Please see the fix in my latest commits.
Ok, this works as well.
- class MeshShrech16Dataset:
+ class MeshShrech16Dataset(MeshCNNClassificationDataset):
Making use of inheritance should result in cleaner and more readable code. WDYT?
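A small sketch of what the inheritance-based layout might look like. The base class here is a simplified stand-in with an assumed constructor (root, name, train, transform), loosely following the interface proposed earlier in this thread; it is not the PR's actual class.

```python
class MeshCNNClassificationDataset:
    """Simplified stand-in for the shared base class (assumed signature)."""

    def __init__(self, root, name, train=True, transform=None):
        self.root = root
        self.name = name
        self.train = train
        self.transform = transform

    def __repr__(self):
        split = "train" if self.train else "test"
        return f"{type(self).__name__}(name={self.name!r}, split={split!r})"


class MeshShrech16Dataset(MeshCNNClassificationDataset):
    """Dataset-specific subclass: only fixes the dataset name."""

    def __init__(self, root, train=True, transform=None):
        super().__init__(root, name="shrech16", train=train,
                         transform=transform)


dataset = MeshShrech16Dataset("data/shrech16", train=False)
description = repr(dataset)
```

All download, processing, and indexing logic would then live in the base class, and each specific dataset reduces to a few lines.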
Thanks for the great work. I have a few comments.
MeshCNNBaseDataset(Dataset) --> InMemoryDataset Co-authored-by: Matthias Fey
…set, MeshCnnSegmentationDataset and MeshCNNPrepare
@Omri-L thanks for being patient and for the updates. As this is a big PR, I think we'll have a few more rounds of reviews. Here are some comments.
for i in range(self.neighborhood_size):
    setattr(self, 'Linear{}'.format(i),
            Linear(in_features=self.in_channels,
                   out_features=out_channels,
                   bias=self.bias))
Do you think we can use a torch.nn.ModuleList here, as it's a more standard way of holding a list of model parameters? Something like:
self.linears = nn.ModuleList([nn.Linear(self.in_channels, out_channels, bias=self.bias) for _ in range(self.neighborhood_size)])
self.batch_flatten_edge_index(edge_index,
                              num_features,
                              num_features_with_padding,
                              edge_features.device)
Should this be edge_index.device?
- edge_features.device)
+ edge_index.device)
edge_features = edge_features.squeeze(-1)
batch_size, feature_dim, num_features = edge_features.shape

edge_features = self.batch_flatten_features(edge_features)
Could the flattened edge_features be passed on to forward? That way this step could be avoided in every forward pass.
from torch_geometric.data import Data


class MeshCNNData(Data):
Is this class needed? Could the Data class be used instead?
MeshCnnData = Data(pos=pos, edges=edges, sides=sides, edges_count=edges_count, ve=ve, ...)
What do you think?
self.n_input_channels = None
self.size = None
self.n_classes = None
super(MeshCnnBaseDataset, self).__init__(self.root)
It seems like you never actually load the data into memory:
self.data, self.slices = torch.load(self.processed_paths[0])
continue
else:
    data = self.get_data_from_file(path)
    self.mesh_prepare_transform(data, load_file)
You need to collate and save the processed files here.
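The collate step can be pictured as concatenating all examples into one store and recording slice boundaries, which is what lets an in-memory dataset keep everything in a single (data, slices) pair. A dependency-free sketch, with plain lists in place of Data objects and pickle standing in for torch.save / torch.load; the helper names collate and get_example are assumptions of this sketch, not the PyG implementation.

```python
import os
import pickle
import tempfile


def collate(samples):
    """Concatenate per-sample lists and record where each sample ends."""
    data, slices = [], [0]
    for sample in samples:
        data.extend(sample)
        slices.append(len(data))
    return data, slices


def get_example(data, slices, idx):
    """Recover sample idx from the concatenated store."""
    return data[slices[idx]:slices[idx + 1]]


# "process": collate once and save to a single processed file.
samples = [[1, 2, 3], [4], [5, 6]]
path = os.path.join(tempfile.mkdtemp(), "data.pkl")
with open(path, "wb") as f:
    pickle.dump(collate(samples), f)

# "__init__": load the processed file back into memory.
with open(path, "rb") as f:
    data, slices = pickle.load(f)

restored = [get_example(data, slices, i) for i in range(len(slices) - 1)]
```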
def len(self):
    return len(self.paths_labels)

def get(self, idx):
This function can be dropped once the dataset class follows the InMemoryDataset interface.
Hi, any update concerning MeshCNN?
Full implementation of the MeshCNN paper based on https://github.com/ranahanocka/MeshCNN (with full coordination and guidance from the author).
The implementation includes example scripts for training classification and segmentation tasks, as well as the datasets used in the paper.