Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust sharding rules for ragged_dot operations. #23956

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Mar 20, 2025

Adjust sharding rules for ragged_dot operations.

Mode 3 (the ragged dimension is a batch dimension)

Do not propagate shardings to/from the group_sizes (the 3rd operand) since it is redundant. We will discard this tensor and convert ragged_dot into a standard dot in the partitioner.

Mode 1 (the ragged dimension is a lhs non-contracting dimension)

The ragged dimension and the group_size dimension needs full replication in the partitioner.

Mode 2 (the ragged dimension is a contracting dimension)

The ragged dimension and the group_size dimension needs full replication in the partitioner. We do not mark the ragged dimension as replicated although it is a contracting dimension.

@copybara-service copybara-service bot force-pushed the test_738587239 branch 2 times, most recently from b4ead77 to 1ccbf1a Compare March 20, 2025 02:48
# Mode 3 (the ragged dimension is a batch dimension)
Do not propagate shardings to/from the group_sizes (the 3rd operand) since it is redundant. We will discard this tensor and convert ragged_dot into a standard dot in the partitioner.

# Mode 1 (the ragged dimension is a lhs non-contracting dimension)
The ragged dimension and the group_size dimension needs full replication in the partitioner.

# Mode 2 (the ragged dimension is a contracting dimension)
The ragged dimension and the group_size dimension needs full replication in the partitioner. We do not mark the ragged dimension as replicated although it is a contracting dimension.

PiperOrigin-RevId: 738587239
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant