Annotating functions that remove a specified dimension #298

Open
jonasjuerss opened this issue Feb 17, 2025 · 1 comment

jonasjuerss commented Feb 17, 2025

Hi, I was wondering whether there is a nice way to annotate functions that operate along a given dimension and remove that dimension in the process. Examples of such functions are sum(), mean(), and argmax(). I would imagine something along the lines of:

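```python
import jaxtyping, torch, typeguard
from jaxtyping import Float
from torch import Tensor
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def my_mean(x: Float[Tensor, " *dims"], dim: int) -> Float[Tensor, "*dims[:{dim}] *dims[{dim+1}:]"]:
    return torch.mean(x, dim=dim)
```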

where the * in the return annotation denotes unpacking rather than matching an arbitrary number of dimensions.
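Until something like this is supported, one workaround is to check the shape relationship at runtime instead of in the annotation. The sketch below is a hypothetical decorator (`check_reduces_dim` is not part of jaxtyping or typeguard). It assumes the wrapped function takes the input as its first argument and the reduced dimension as its second (`dim`). It asserts that the output shape equals the input shape with that axis removed, which is the intended meaning of `*dims[:dim] *dims[dim+1:]`. Negative `dim` is normalised first because PyTorch accepts values such as `dim=-1`.

```python
import functools
from typing import Any, Callable


def check_reduces_dim(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical decorator: check that fn(x, dim) removes exactly axis `dim` of x.

    Works with any array-like exposing `.shape` (e.g. torch.Tensor, numpy.ndarray).
    """

    @functools.wraps(fn)
    def wrapper(x: Any, dim: int, *args: Any, **kwargs: Any) -> Any:
        in_shape = tuple(x.shape)
        ndim = len(in_shape)
        if not -ndim <= dim < ndim:
            raise IndexError(f"dim={dim} is out of range for a {ndim}-d input")
        d = dim % ndim  # normalise negative dims, e.g. -1 -> ndim - 1
        out = fn(x, dim, *args, **kwargs)
        expected = in_shape[:d] + in_shape[d + 1:]
        if tuple(out.shape) != expected:
            # TypeError is an arbitrary choice here; jaxtyping has its own error type.
            raise TypeError(f"expected output shape {expected}, got {tuple(out.shape)}")
        return out

    return wrapper
```

Stacked on a plain `def my_mean(x, dim): return torch.mean(x, dim=dim)`, this would also flag a call that passes `keepdim=True` through, since that keeps a size-1 axis instead of removing it.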

This may be vaguely related to #184 if one could do something like

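```python
import jaxtyping, torch, typeguard
from jaxtyping import Float
from torch import Tensor
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def my_mean(x: Float[Tensor, " *A my_dim *B"], dim: int) -> Float[Tensor, "*A *B"]:
    return torch.mean(x, dim=dim)
```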

to provide at least a partially useful hint, even though this wouldn't specify which dimension is removed.
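To see why the `*A my_dim *B` → `*A *B` form can't pin down the removed dimension, here is a small plain-Python illustration. `removable_axes` is a hypothetical helper, not jaxtyping code. It lists every axis whose deletion from the input shape yields the output shape. Each such axis is a valid binding of `*A my_dim *B`, so whenever more than one exists, the annotation alone cannot tell which axis was reduced.

```python
from __future__ import annotations

from typing import Sequence


def removable_axes(in_shape: Sequence[int], out_shape: Sequence[int]) -> list[int]:
    """Return every axis i such that in_shape with axis i deleted equals out_shape."""
    in_shape, out_shape = tuple(in_shape), tuple(out_shape)
    return [
        i
        for i in range(len(in_shape))
        if in_shape[:i] + in_shape[i + 1:] == out_shape
    ]
```

For example, `removable_axes((2, 3, 4), (2, 4))` gives `[1]`, but `removable_axes((3, 3, 4), (3, 4))` gives `[0, 1]`. A bug that reduces axis 0 instead of axis 1 of a `(3, 3, 4)` input would therefore still satisfy the `*A *B` hint.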

patrick-kidger (Owner) commented

This is a really nice example of something that's both useful and that we totally don't support at the moment :D

Unfortunately I think adding something like this is pretty hard. I'd be open to taking a PR that does this, although the symbolic resolution handling in jaxtyping is fairly finicky. FWIW, on syntax I think we'd look for something like Float[Tensor, "{dims[:dim]} {dims[dim+1:]}"], for consistency with how our f-strings already work.
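As a rough illustration of what resolving the suggested `"{dims[:dim]} {dims[dim+1:]}"` syntax involves, the sketch below expands each `{...}` token by evaluating it against already-bound values. Here those values are a `dims` tuple for the input's variadic dimensions and the `dim` argument. `expand_spec` is hypothetical and is not how jaxtyping resolves its f-strings internally. It only handles braced expressions and integer literals.

```python
from __future__ import annotations

import re
from typing import Any, Mapping

_BRACED = re.compile(r"\{(.+)\}")


def expand_spec(spec: str, env: Mapping[str, Any]) -> tuple[int, ...]:
    """Expand a space-separated shape spec into a concrete tuple of ints.

    Tokens of the form {expr} are evaluated against `env`; a result that is a
    tuple is spliced in, anything else is treated as one dimension size.
    All other tokens must be integer literals.
    """
    shape: list[int] = []
    for token in spec.split():
        match = _BRACED.fullmatch(token)
        if match is None:
            shape.append(int(token))
            continue
        # Restricted eval: no builtins, only the names bound in env are visible.
        value = eval(match.group(1), {"__builtins__": {}}, dict(env))
        if isinstance(value, tuple):
            shape.extend(int(v) for v in value)
        else:
            shape.append(int(value))
    return tuple(shape)
```

With `dims=(2, 3, 4)` and `dim=1`, the spec expands to `(2, 4)`. The sketch also surfaces a pitfall with PyTorch-style negative indices. For `dim=-1` the second slice `dims[dim+1:]` is `dims[0:]`, the whole tuple, so the spec expands to `(2, 3, 2, 3, 4)` instead of `(2, 3)`. Presumably `dim` would need normalising (e.g. `dim % len(dims)`) before the expressions are evaluated.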
