Mesh cnn #2293

Open: wants to merge 27 commits into base: master

@Omri-L Omri-L commented Mar 22, 2021

A full implementation of the MeshCNN paper, based on https://github.com/ranahanocka/MeshCNN (done in full coordination with, and under the guidance of, the author).
The implementation includes example scripts for training classification and segmentation tasks, as well as the datasets used in the paper.

@math-araujo

Hey @Omri-L . Does your implementation fix this issue of the original repo? Thanks in advance!

@Omri-L
Author

Omri-L commented Mar 27, 2021

Hi,
Unfortunately, no. I did not have access to multiple GPUs while integrating it. I will try to add it to my to-do list.

@ranahanocka

@rusty1s let us know if anything is missing or if something else should be done here!

@rusty1s
Member

rusty1s commented Mar 30, 2021

Thank you very much for this PR. Since this is a big one, is there any chance we can split this into smaller PRs? I'm thinking of separate PRs for datasets, example scripts, operators, and transforms. Let me know if this is possible. It would help me tremendously for merging :)

@Omri-L
Author

Omri-L commented Mar 31, 2021

Hi @rusty1s ,
Yes, of course we can split it into separate PRs. How would you like me to do that?
Should I open new, separate PRs?
For example:

  1. Data structures + transforms
  2. Datasets
  3. Operators (e.g., MeshConv, MeshPool, MeshUnPool)
  4. Example scripts (e.g., MeshCNN models and MeshCNN trainers)

Sounds right?

@rusty1s
Member

rusty1s commented Mar 31, 2021

Sounds great :)

@Omri-L
Author

Omri-L commented Apr 3, 2021

Hi @rusty1s ,
I've reverted all files and changes from my last commit that are not related to data structures and transforms.
The new commit is 8c786c5.
I hope this is what you meant. If you would prefer that I close this PR and open a new one, just tell me.

Thanks,
Omri.

@rusty1s
Member

rusty1s commented Apr 4, 2021

Thanks. This looks good to me. I will have a closer look in the upcoming week.

@rusty1s rusty1s self-requested a review May 17, 2021 09:38
Member

@rusty1s rusty1s left a comment

Hi @Omri-L and @ranahanocka and sorry for my delayed review.

Thanks once again for this PR. I am really happy to see this effort of merging MeshCNN into PyTorch Geometric.

This is still a pretty complex PR for me to merge, in particular because I feel the code is tightly coupled to the specific MeshCNN use case. I would love to see this fit into PyG a little better.

For example, I think we can directly make use of torch_geometric.data.Data to hold mesh data, and convert a given mesh (e.g. given by FAUST) to the required MeshCNN representation via a single transform. In particular, I'm thinking of the following interface:

dataset = FAUST(..., pre_transform=MeshCNNPrepare())

which converts a mesh data object, i.e., Data(pos=[num_nodes, 3], face=[3, num_faces]) into a variant that can be processed by MeshCNN, i.e.:

Data(edge_index=[2, num_edges], edge_attr=[num_edges, 5], neighbor_e_id=[num_edges, 4], ...)

where neighbor_e_id denotes the indices of adjacent edges. Please let me know if this is possible.
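
To make the proposed representation concrete, here is a minimal, dependency-free sketch of how edges and a `neighbor_e_id` table could be derived from a face list. It is only an illustration under stated assumptions, not the `MeshCNNPrepare` implementation from this PR: the function name is hypothetical, plain Python lists stand in for tensors, each undirected edge is stored once, the mesh is assumed to be manifold (at most two faces per edge), and missing neighbors of boundary edges are marked with -1.

```python
def mesh_edges_and_neighbors(faces):
    """Return (edges, neighbor_e_id) for triangles given as vertex-index triples."""
    edge_id = {}        # sorted vertex pair -> edge index
    edges = []          # edge index -> (u, v) with u < v
    neighbor_e_id = []  # edge index -> indices of neighboring edges
    for face in faces:
        # Look up (or create) the three edges of this face, in face order.
        face_edges = []
        for k in range(3):
            key = tuple(sorted((face[k], face[(k + 1) % 3])))
            if key not in edge_id:
                edge_id[key] = len(edges)
                edges.append(key)
                neighbor_e_id.append([])
            face_edges.append(edge_id[key])
        # Every edge gains the other two edges of this face as neighbors.
        for k in range(3):
            neighbor_e_id[face_edges[k]].extend(
                [face_edges[(k + 1) % 3], face_edges[(k + 2) % 3]])
    # Boundary edges touch only one face: fill the missing slots with -1.
    for row in neighbor_e_id:
        row.extend([-1] * (4 - len(row)))
    return edges, neighbor_e_id


# Two triangles that share the edge between vertices 0 and 2.
faces = [(0, 1, 2), (0, 2, 3)]
edges, neighbor_e_id = mesh_edges_and_neighbors(faces)
```

In this example only the shared edge has four real neighbors; the four boundary edges each have two real neighbors and two -1 entries.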

import torch


def get_meshes_edge_index(meshes):
Member

Do we really need this function or can we implement MeshCNN without padding?

Author

Yes, but this padding is not really used in this PR (it will be in the next PR, after I resolve the issues here). We need to pad the edge indices because in some datasets the number of input edges can differ between mesh examples.
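
The padding described here can be sketched as follows. This is a hedged illustration rather than the code from the PR: `pad_edge_features`, `max_edges`, and `pad_value` are hypothetical names, plain lists stand in for tensors, and every mesh is assumed to have at least one edge.

```python
def pad_edge_features(meshes, max_edges, pad_value=0.0):
    """Pad each mesh's per-edge feature rows up to max_edges rows.

    meshes: list of meshes, each a non-empty list of feature rows (one per edge).
    """
    padded = []
    for feats in meshes:
        if len(feats) > max_edges:
            raise ValueError("mesh has more edges than max_edges")
        num_channels = len(feats[0])
        missing = max_edges - len(feats)
        rows = [list(row) for row in feats]
        # Append one padding row per missing edge.
        rows.extend([pad_value] * num_channels for _ in range(missing))
        padded.append(rows)
    return padded


# Two meshes with 1 and 2 edges, two channels per edge.
batch = pad_edge_features([[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]], max_edges=2)
```

After padding, every mesh in the batch has the same number of rows, so the batch can be stacked into one dense array.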

Member

@Omri-L, do you think we can drop the files that are not used in this PR? That should make it easier to review.

Author

Hi @wsad1, this updated PR contains all the files that are needed. My comment from June 4th referred to a smaller PR, but here everything is included and all the files are needed.
Thanks!

(maximum 2 faces), normalized by the total
area of all the faces. Used for
segmentation accuracy.
edge_index (np.array(2, Num_edges x 5,
Member

How about we hold the mesh graph data in edge_index (just as in standard PyG), and have an additional index matrix neighbor_e_id of shape [num_edges, 4] that holds the indices of direct neighbors for every edge?

Author

I am not sure I understand. I followed the introduction example:

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

Can you please elaborate?

Member

Maybe I misunderstood the code, but my impression is that you have a regular edge_index tensor of shape [2, num_edges] denoting mesh connections. For MeshCNN, we need to collect the directly neighboring edges, e.g., 4 edges in the case of two neighboring faces. I think one can just create a tensor of shape [num_edges, 4] that holds the indices of the edges in edge_index that are attached to each specific edge. WDYT?
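
As a rough illustration of how such a [num_edges, 4] index table could be consumed, the sketch below gathers the features of each edge's four neighbors and combines them into order-invariant values. The pairing (a with c, b with d) follows my reading of the MeshCNN paper and should be treated as an assumption, as should the choice to let missing neighbors (index -1) contribute zero. `gather_symmetric` is a hypothetical helper, and one scalar feature per edge is used for brevity.

```python
def gather_symmetric(x, neighbor_e_id):
    """x: one scalar feature per edge. Returns five values per edge."""
    out = []
    for e, (a, b, c, d) in enumerate(neighbor_e_id):
        # Assumption: a missing neighbor (index -1) contributes a zero feature.
        fa, fb, fc, fd = (x[i] if i >= 0 else 0.0 for i in (a, b, c, d))
        # Sums and absolute differences do not depend on neighbor ordering.
        out.append([x[e], fa + fc, fb + fd, abs(fa - fc), abs(fb - fd)])
    return out


# Neighbor table for two triangles sharing one edge (edge 2 is the shared one).
neighbor_e_id = [
    [1, 2, -1, -1],
    [2, 0, -1, -1],
    [0, 1, 3, 4],
    [4, 2, -1, -1],
    [2, 3, -1, -1],
]
features = gather_symmetric([1.0, 2.0, 3.0, 4.0, 5.0], neighbor_e_id)
```

The point of the index table is that this gather step needs nothing beyond the per-edge features and the table itself.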

Semantic changes
@Omri-L
Author

Omri-L commented Jun 4, 2021

Hi @rusty1s ,
Sorry for my late reply. I will study your comments and try my best.
Not sure what FAUST is and how I can use it. Can you please elaborate?

@rusty1s
Member

rusty1s commented Jun 5, 2021

Thanks for digging into this :) Appreciate it a lot.

The FAUST dataset is just an example of a mesh dataset already integrated in PyG. One can use others as well, such as ModelNet.

@Omri-L
Author

Omri-L commented Sep 8, 2021

Hi @rusty1s ,
I've updated the pull request and resolved the issues we discussed in this PR (comments above).
Below is a summary of the PR. I know this is a big PR, so let me know if something is unclear or if there is anything I can fix.
I suggest starting with sections 1 + 2 + 7, then 3, then 4 + 5 + 6 + 9, and finishing with 8 + 10.

1. torch_geometric.data
Added files:

  • meshcnndatahandler.py - includes the MeshData class and the MeshCNNHandler class.
    MeshData - uses torch_geometric.data.Data plus other attributes for the MeshCNN structure. I've changed the relevant attributes to be torch.tensor instead of np.array.
    I didn't use your suggestion of:
    Data(edge_index=[2, num_edges], edge_attr=[num_edges, 5], neighbor_e_id=[num_edges, 4], ...)
    (there is no neighbor_e_id). I can do it, but I worked hard to remove this kind of data structure from the original MeshCNN repository since it does not fit the torch_geometric.data.Data structure.

MeshCNNHandler - holds mesh data and implements other MeshCNN methods on this data structure.

2. torch_geometric.transforms
Added files:

  • mesh_prepare.py - includes the MeshCNNPrepare class with a changed interface. One can now use this class as a transform to convert a different data structure (e.g. a pos tensor and a faces tensor) to the MeshCNN data structure. See the example in "examples/meshcnn_faust.py" (as you asked).
  • mesh_augmentations.py - includes all the augmentations on the mesh structure used in the MeshCNN paper.

3. torch_geometric.datasets
Added files:

  • mesh_cnn_base_datasets.py - includes the MeshCnnBaseDataset, MeshCnnClassificationDataset, and MeshCnnSegmentationDataset classes. These are base classes that hold the different kinds of MeshCNN datasets.
  • mesh_cnn_dataloader.py - includes MeshDataLoader class.
  • mesh_cnn_datasets.py - includes MeshShrech16Dataset, MeshCubesDataset, MeshHumanSegDataset, MeshCoSegDataset classes, which are the datasets used in the MeshCNN paper.

4. torch_geometric.nn.conv
Added file: mesh_conv.py - implements MeshConv operation with MessagePassing.

5. torch_geometric.nn.pool
Added file: mesh_pool.py - implements MeshPool operation.

6. torch_geometric.nn.unpool
Added file: mesh_unpool.py - implements MeshUnpool operation.

7. torch_geometric.utils
Added file: mesh_features.py - includes the implementation of MeshCNN geometric features.

8. torch_geometric.nn.models
Added file: mesh_models.py - implements some different MeshCNN models which use the mesh operations.

9. torch_geometric.nn.mesh
Added file: mesh_edge_padding.py - I've changed the padding value from 0 to -1. The MeshCNN operations should support this kind of padding.

10. examples
Added files:

  • mesh_cnn_classification.py
  • mesh_cnn_segmentation.py
    These two files are examples of how to use MeshCNN in this repository; each implements a main script for training and testing.
  • meshcnn_faust.py - another example of how to use MeshCNNPrepare on a general mesh structure.

@rusty1s
Member

rusty1s commented Sep 9, 2021

Thank you! I will have a closer look in the upcoming week and will let you know.

@codecov-commenter

codecov-commenter commented Sep 20, 2021

Codecov Report

Merging #2293 (57f2f33) into master (07e273a) will decrease coverage by 3.50%.
The diff coverage is 22.52%.

@@            Coverage Diff             @@
##           master    #2293      +/-   ##
==========================================
- Coverage   73.56%   70.05%   -3.51%     
==========================================
  Files         287      300      +13     
  Lines       13719    15252    +1533     
==========================================
+ Hits        10092    10685     +593     
- Misses       3627     4567     +940     
Impacted Files Coverage Δ
torch_geometric/transforms/mesh_augmentations.py 10.30% <10.30%> (ø)
torch_geometric/utils/mesh_utils.py 11.76% <11.76%> (ø)
torch_geometric/transforms/mesh_prepare.py 15.47% <15.47%> (ø)
torch_geometric/nn/models/mesh_models.py 15.60% <15.60%> (ø)
torch_geometric/utils/mesh_features.py 19.23% <19.23%> (ø)
torch_geometric/nn/pool/mesh_pool.py 20.45% <20.45%> (ø)
torch_geometric/data/meshcnnhandler.py 21.80% <21.80%> (ø)
torch_geometric/nn/unpool/mesh_unpool.py 23.52% <23.52%> (ø)
torch_geometric/nn/mesh/mesh_edge_padding.py 25.00% <25.00%> (ø)
torch_geometric/nn/mesh/mesh_union.py 31.42% <31.42%> (ø)
... and 40 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 07e273a...57f2f33.

@rusty1s rusty1s requested a review from wsad1 September 21, 2021 18:32
Member

@rusty1s rusty1s left a comment

Thanks @Omri-L. Really appreciate the changes to allow for an interface as described in meshcnn_faust.py. I went over datasets and loaders, and think we can polish it a bit further. In general, it would be great to follow the interfaces of other datasets, so that they can be re-used for other models as well besides MeshCNN, i.e., I would prefer to move all MeshCNNPrepare calls outside the datasets.

Let me know if you need any help in doing so.

root (str): Root folder for dataset.
dataset_url (str): dataset URL link (see supported links in
MeshCnnBaseDataset)
n_input_edges (int): Number of input edges of mesh. It should be the
Member

Not really sure why we need this? Can you clarify?

Author

Which one?
n_input_edges is there because MeshCNN must receive the same number of edges for every mesh in a batch.
Therefore, n_input_edges is the maximal number of edges of any mesh object in the dataset. Any mesh object with fewer input edges will be padded.

Member

Why does MeshCNN need a fixed number of edges? This seems to be more of a design decision than an actual limitation of MeshCNN. Am I correct?
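
For comparison, PyG-style mini-batching avoids padding by concatenating all graphs into one large graph and shifting indices by a running offset. The sketch below illustrates that idea for per-edge features and a neighbor table. It is an assumption-laden, dependency-free illustration and not PyG's collation code: `collate_meshes` is a hypothetical name, and -1 is assumed to mark a missing neighbor. In PyG itself, the per-attribute offset is controlled by `Data.__inc__`.

```python
def collate_meshes(meshes):
    """Concatenate meshes without padding.

    meshes: list of (edge_features, neighbor_e_id) pairs.
    Returns (x, neighbors, batch), where batch[i] is the mesh index of edge i.
    """
    x, neighbors, batch = [], [], []
    offset = 0
    for mesh_idx, (feats, nbrs) in enumerate(meshes):
        x.extend(feats)
        for row in nbrs:
            # Shift valid indices so they point into the concatenated list;
            # keep -1 (missing neighbor) unchanged.
            neighbors.append([i + offset if i >= 0 else -1 for i in row])
        batch.extend([mesh_idx] * len(feats))
        offset += len(feats)
    return x, neighbors, batch


# Two single-triangle meshes with three edges each.
triangle = ([[0.1], [0.2], [0.3]],
            [[1, 2, -1, -1], [2, 0, -1, -1], [0, 1, -1, -1]])
x, neighbors, batch = collate_meshes([triangle, triangle])
```

With this layout, meshes with different numbers of edges can share a batch, and the `batch` vector records which mesh each edge belongs to.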

phase (str, optional): Train or Test phase - if 'train' it will return
the train dataset, if 'test' it will return the
test dataset. Default is 'train'.
num_aug (int, optional): Number of augmentations to apply on mesh.
Member

I personally feel like these should all be handled by transforms, similar to other datasets in PyG. We should aim to make the interface consistent (transform, pre_transform).

prefix, np.random.randint(0, self.num_aug)))

if self.num_aug > 1:
mesh_prepare = MeshCNNPrepare(aug_slide_verts=self.slide_verts,
Member

Specific MeshCNNPrepare transforms should be moved outside the datasets.
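
The separation being asked for can be sketched with a toy dataset that knows nothing about MeshCNN and simply accepts callables. In PyG, `pre_transform` is applied once when the data is processed, while `transform` is applied on every access. Everything below is a hypothetical stand-in (class name, helper names, dictionary-based meshes), not code from this PR or from PyG.

```python
class ToyMeshDataset:
    """Minimal dataset that delegates all mesh-specific work to callables."""

    def __init__(self, raw_items, transform=None, pre_transform=None):
        self.transform = transform
        # pre_transform runs once, when the raw data is processed.
        self.items = [pre_transform(m) if pre_transform else m
                      for m in raw_items]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        # transform runs on every access (a natural place for augmentation).
        return self.transform(item) if self.transform else item


def count_faces(mesh):
    # Stand-in for an expensive, deterministic preparation step.
    return {**mesh, "num_faces": len(mesh["face"])}


def scale_positions(factor):
    # Stand-in for an augmentation; returns a new dict, leaving the input intact.
    def transform(mesh):
        return {**mesh, "pos": [[factor * c for c in p] for p in mesh["pos"]]}
    return transform


raw = [{"pos": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        "face": [(0, 1, 2)]}]
dataset = ToyMeshDataset(raw, transform=scale_positions(2.0),
                         pre_transform=count_faces)
```

Because the dataset only stores and returns items, the same class could serve other models by passing different callables.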

return mesh_data


class MeshCnnSegmentationDataset(MeshCnnBaseDataset):
Member

See comments in MeshCNNClassificationDataset.

from torch_geometric.data.meshcnnhandler import MeshCNNHandler


class MeshDataLoader:
Member

What prevents us from directly using the torch_geometric.data.DataLoader (as you already do in meshcnn_faust.py)?

Author

I am using a different __iter__ function. Do you have a suggestion for how to avoid creating this class?

Member

As far as I see, the __iter__ function already re-implements a lot of the functionality from the PyG DataLoader, and most of its complexity stems from converting MeshCNNData via MeshCNNDataHandler. Can you give some insights on the separation of MeshCNNData and MeshCNNDataHandler, and why we need it in the first place? What happens if we convert meshes via MeshCNNDataHandler before data loading?

MeshCnnClassificationDataset, MeshCnnSegmentationDataset)


class MeshShrech16Dataset:
Member

@rusty1s rusty1s Sep 21, 2021

We can probably clean this up to avoid a lot of duplicated code. What about only providing MeshCNNClassificationDataset and MeshCNNSegmentationDataset in which you can decide on which dataset to use, e.g.:

dataset = MeshCNNClassificationDataset(root, name='shrech16',
                                       train=True, transform=MeshCNNPrepare(...))

Author

Yes, you are right, but I meant to create base classes such as MeshCNNClassificationdataset and specific classes for specific datasets (such as shrec16, cubes...). Do you think I should remove them? Please see the fix I made in my last commits.

Member

Ok, this works as well.

Member

Suggested change
- class MeshShrech16Dataset:
+ class MeshShrech16Dataset(MeshCNNClassificationDataset):

Making use of inheritance should result in cleaner and more readable code. WDYT?
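
A small sketch of the inheritance pattern suggested here: a generic classification dataset selects the data by name, and a dataset-specific subclass only pins that name. The class names, the list of supported names, and the constructor arguments are hypothetical and chosen to mirror the interface discussed in this thread. No downloading or processing is modeled.

```python
class ToyClassificationDataset:
    """Generic base class: the concrete dataset is chosen via `name`."""

    names = ("shrech16", "cubes")  # hypothetical list of supported datasets

    def __init__(self, root, name, train=True, transform=None):
        if name not in self.names:
            raise ValueError(f"unknown dataset name: {name!r}")
        self.root = root
        self.name = name
        self.train = train
        self.transform = transform


class ToyShrech16Dataset(ToyClassificationDataset):
    """Dataset-specific subclass: fixes the name and reuses everything else."""

    def __init__(self, root, train=True, transform=None):
        super().__init__(root, name="shrech16", train=train,
                         transform=transform)


dataset = ToyShrech16Dataset("data/shrech16", train=False)
```

Both styles of construction stay available: callers can use the base class with a name, or the subclass as a shortcut.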

Member

@wsad1 wsad1 left a comment

Thanks for the great work. I have a few comments.

import torch


def get_meshes_edge_index(meshes):
Member

@Omri-L, do you think we can drop the files that are not used in this PR? That should make it easier to review.

@Omri-L
Author

Omri-L commented Sep 25, 2021

Hi @rusty1s and @wsad1
Thank you again for your time on this PR.
I solved many of the issues you mentioned above. Please let me know if I can do more.

Thanks.

Member

@wsad1 wsad1 left a comment

@Omri-L thanks for being patient and for the updates. As this is a big PR, I think we'll have a few more rounds of reviews. Here are some comments.

Comment on lines +37 to +41
for i in range(self.neighborhood_size):
    setattr(self, 'Linear{}'.format(i),
            Linear(in_features=self.in_channels,
                   out_features=out_channels,
                   bias=self.bias))
Member

Do you think we could use a torch.nn.ModuleList here? It is a more standard way of holding a list of submodules. Something like:
self.linears = nn.ModuleList([nn.Linear(self.in_channels, self.out_channels, bias) for i in range(self.neighborhood_size)])

self.batch_flatten_edge_index(edge_index,
num_features,
num_features_with_padding,
edge_features.device)
Member

Should this be edge_index.device?

Suggested change
- edge_features.device)
+ edge_index.device)

edge_features = edge_features.squeeze(-1)
batch_size, feature_dim, num_features = edge_features.shape

edge_features = self.batch_flatten_features(edge_features)
Member

Could the flattened edge_features be passed to forward? That way, this step could be skipped on every forward pass.

from torch_geometric.data import Data


class MeshCNNData(Data):
Member

Is this class needed? Could the Data class be used instead, e.g.:
MeshCnnData = Data(pos=pos, edges=edges, sides=sides, edges_count=edges_count, ve=ve, ...)
What do you think?

self.n_input_channels = None
self.size = None
self.n_classes = None
super(MeshCnnBaseDataset, self).__init__(self.root)
Member

It seems like you never actually load the data into memory:

self.data, self.slices = torch.load(self.processed_paths[0])

See https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html#creating-in-memory-datasets

continue
else:
data = self.get_data_from_file(path)
self.mesh_prepare_transform(data, load_file)
Member

You need to collate and save the processed files here.
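
The "collate, then slice" idea behind in-memory datasets can be shown without any framework: all examples are concatenated into one flat store, and a list of slice boundaries allows any single example to be recovered. In PyG, the linked guide performs this with `InMemoryDataset.collate` and `torch.save` inside `process()`. The functions below are hypothetical, list-based stand-ins and do not reproduce PyG's implementation.

```python
def collate(items):
    """Concatenate variable-length items and record their boundaries."""
    data, slices = [], [0]
    for item in items:
        data.extend(item)
        slices.append(len(data))  # end offset of this item
    return data, slices


def get(data, slices, idx):
    """Recover item `idx` from the flat store."""
    return data[slices[idx]:slices[idx + 1]]


# Three "meshes" with 3, 1, and 2 entries each.
data, slices = collate([[1, 2, 3], [4], [5, 6]])
second = get(data, slices, 1)
```

Saving the pair (data, slices) once and loading it in the constructor is what lets an indexed lookup replace per-file loading.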

def len(self):
    return len(self.paths_labels)

def get(self, idx):
Member

This function can be dropped once the dataset class follows the InMemoryDataset interface.

@KevinCrp

Hi, any update concerning MeshCNN?


7 participants