Feature add new one hot function meeting multi-dimensions (ranks) #2613


Merged
21 commits merged into tracel-ai:main on Jan 15, 2025

Conversation

tiruka
Contributor

@tiruka tiruka commented Dec 15, 2024

Target of This Pull Request

First, I attempted to implement a one-hot operation for ONNX. However, I realized that the existing one-hot function did not meet the requirements and, in fact, did not support multidimensional inputs at all. After exploring possible solutions, including the ONNX specification, PyTorch, and TensorFlow, I concluded that a new one-hot function was necessary. That led to this implementation, which I am now submitting as a pull request.
(PyTorch also does not implement a complete one-hot function, though.)

I hope this will be useful for Burn and its community.

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Indirectly related to onnx issues #1714

Changes

Newly implemented a one-hot method for numeric tensors. It belongs to the numeric API because the return value is defined by on_value and off_value rather than by the input tensor itself, so the output can be either int or float.
This function covers all aspects defined by ONNX, including depth, on_value, off_value, and axis, and complies with the OneHot operator specification introduced in ONNX version 11 and later. With it, Burn can handle multidimensional one-hot encoding while also providing a concise and efficient implementation of the ONNX operator. For these reasons, I considered it essential to create this function.
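As a rough illustration of the intended semantics (this is my own NumPy sketch, not the actual Burn implementation), an ONNX-style one-hot with configurable depth, on_value, off_value, and axis can be expressed as:

```python
import numpy as np

def one_hot(indices, depth, on_value=1.0, off_value=0.0, axis=-1):
    """ONNX-style one-hot sketch: indices of any rank, configurable
    depth, on/off values, and target axis for the new dimension."""
    indices = np.asarray(indices)
    # ONNX allows negative indices in [-depth, -1]; they wrap around.
    idx = np.where(indices < 0, indices + depth, indices)
    # Compare each index against 0..depth-1 on a new trailing axis.
    mask = idx[..., None] == np.arange(depth)
    out = np.where(mask, on_value, off_value)
    # Move the one-hot dimension from the end to the requested axis.
    return np.moveaxis(out, -1, axis)
```

For rank-1 input this reduces to the familiar identity-matrix lookup; for higher ranks the class dimension is simply inserted at `axis`.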

I considered removing and updating the existing one-hot method, but I decided to take a more conservative approach by leaving the existing method as it is and adding a new one instead.

Testing

Adding tests on crates/burn-tensor/src/tests/ops/one_hot.rs and passing run-checks all.


codecov bot commented Dec 15, 2024

Codecov Report

Attention: Patch coverage is 94.26230% with 7 lines in your changes missing coverage. Please review.

Project coverage is 83.19%. Comparing base (f1558ad) to head (4bffb30).
Report is 51 commits behind head on main.

Files with missing lines Patch % Lines
crates/burn-tensor/src/tensor/api/check.rs 53.33% 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2613      +/-   ##
==========================================
+ Coverage   81.86%   83.19%   +1.33%     
==========================================
  Files         833      819      -14     
  Lines      106450   106801     +351     
==========================================
+ Hits        87146    88857    +1711     
+ Misses      19304    17944    -1360     

☔ View full report in Codecov by Sentry.

Member

@laggui laggui left a comment


I'll take some time to look at this later, but just a couple comments before reviewing the actual code.

  1. We don't need to have all of the ops comply with the ONNX spec.
  2. Introducing this as a new operation means that we now have multiple definitions for one-hot. One definition should take over, otherwise it makes everything cluttered.
  3. Regarding the motivation, do you actually need this one-hot definition? Or is it simply for ONNX conversion? If only the latter, then it can probably just live in the ONNX import code.

@tiruka
Contributor Author

tiruka commented Dec 19, 2024

@laggui Thank you for your comments.

The existing one_hot function only operates on rank-1 tensors, which limits its usability.

impl<B: Backend> Tensor<B, 1, Int> {
    // ...
    pub fn one_hot(/* ... */) {
        // ...
    }
}

For the current float version of one_hot, I cannot come up with any use case.

In PyTorch, for example, the function is minimally designed to support multiple dimensions, and this aspect is something that needs improvement in our framework as well.
PyTorch example:

F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])

Furthermore, another major framework, TensorFlow, not only supports multiple dimensions but also provides flexibility with parameters such as axis and values.
TensorFlow example:

indices = [[0, 2], [1, -1]]
depth = 3
tf.one_hot(indices, depth,
           on_value=1.0, off_value=0.0,
           axis=-1)  # output: [2 x 2 x 3]
# [[[1.0, 0.0, 0.0],   # one_hot(0)
#   [0.0, 0.0, 1.0]],  # one_hot(2)
#  [[0.0, 1.0, 0.0],   # one_hot(1)
#   [0.0, 0.0, 0.0]]]  # one_hot(-1)
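The TensorFlow output above can be reproduced with plain NumPy broadcasting (this snippet is my own sketch for illustration, not code from the PR). Note that TensorFlow treats the out-of-range index -1 as matching no class, so its row is all off_value:

```python
import numpy as np

indices = np.array([[0, 2], [1, -1]])
depth = 3
# Compare each index against 0..depth-1 on a new trailing axis;
# -1 matches no class, giving an all-off row (TensorFlow behavior).
mask = indices[..., None] == np.arange(depth)
out = np.where(mask, 1.0, 0.0)        # axis=-1: shape (2, 2, 3)
out_axis0 = np.moveaxis(out, -1, 0)   # axis=0 would give shape (3, 2, 2)
```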

Further use cases

The ability to configure multiple dimensions, axis, and values is an expected feature in popular frameworks, and I believe this would greatly benefit Burn users, myself included, by helping the framework stay aligned with modern expectations. This one-hot function is not intended only for ONNX.

Regarding the concern about having multiple definitions, I also have the same sentiment and agree that unification is necessary. My proposed new function is designed to support both int and float types, making it closer to the one_hot definitions found in other frameworks. As such, I would advocate deprecating the existing implementation and unifying it with this new version. If there is agreement on this approach, I would be happy to submit changes either as part of this PR or in a separate PR to address these points.

I look forward to your feedback and hope for your support in making this improvement.

Member

@laggui laggui left a comment


Ok, makes sense! Thanks for the detailed response.

I think it is especially useful for the rank > 1 use cases, the rest of the configurable stuff seems less relevant to me. But I understand that there could be value in supporting the broad spec.

See my comments below 🙂

Regarding the multiple definitions, I think I would deprecate the other definitions since this can do it all. Just make sure to adapt the existing tests.

@tiruka
Contributor Author

tiruka commented Dec 25, 2024

@laggui I have modified the code; please review it again (perhaps after your Christmas vacation, enjoy!).

@tiruka changed the title from "Feature add new one hot function meeting full requirements." to "Feature add new one hot function meeting multi-dimensions (ranks)" on Dec 26, 2024
Member

@laggui laggui left a comment


Hope you had a nice holiday break! Thanks for addressing my previous comments 🙂

I have some follow-up changes. Mostly form over content.

@tiruka
Contributor Author

tiruka commented Jan 8, 2025

@laggui
I am back now. Thank you for your review; I have modified the code accordingly.

Member

@laggui laggui left a comment


Alright, this should be good to go after this round! 🙂

Member

@laggui laggui left a comment


LGTM! Made a minor change to the tensor check message.

@laggui laggui merged commit ad81344 into tracel-ai:main Jan 15, 2025
11 checks passed
@tiruka tiruka deleted the feature-add-new-one-hot-function branch January 15, 2025 21:35
@tiruka
Contributor Author

tiruka commented Jan 15, 2025

@laggui Thank you for all your help! I really appreciate your support in completing this PR. 👍

@laggui laggui mentioned this pull request Feb 7, 2025
2 tasks