[ROCM][NFC] GpuBlaslt matmul thunk cache refactoring part II #23315

Open
wants to merge 1 commit into main from ci_gpublas_lt_cache_refactor2

Conversation

pemeliya
Contributor

@pemeliya pemeliya commented Mar 3, 2025

Now that #21886 has been merged, the GpuBlasLt MatmulPlan cache can be refactored.
It was originally introduced here: #6595

The idea here is that multiple matmul thunks can share the same matmul plans (allocated on the same device).
This can significantly reduce memory overhead for large training runs. Furthermore, the original cache used the stream pointer as a cache key, which can be inefficient when benchmarking the same HLO with XLA tools like multi_host_hlo_runner (which may allocate a new stream for each iteration).
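
For illustration, here is a minimal sketch of the intended sharing. The names below (MatmulPlan, MatmulPlanCache, CacheForDevice) are simplified placeholders for this discussion, not the PR's actual classes:

#include <map>
#include <memory>
#include <string>

struct MatmulPlan {};  // stand-in for se::gpu::BlasLt::MatmulPlan

class MatmulPlanCache {
 public:
  // Keyed by a canonical description of the GEMM (shapes, dtypes, epilogue),
  // not by the stream pointer, so re-created streams still hit the cache.
  MatmulPlan* FindOrNull(const std::string& config_key) {
    auto it = plans_.find(config_key);
    return it == plans_.end() ? nullptr : it->second.get();
  }
  MatmulPlan* Insert(const std::string& config_key,
                     std::unique_ptr<MatmulPlan> plan) {
    return (plans_[config_key] = std::move(plan)).get();
  }

 private:
  std::map<std::string, std::unique_ptr<MatmulPlan>> plans_;
};

// One cache per device, so all thunks running on that device share plans.
MatmulPlanCache& CacheForDevice(int device_ordinal) {
  static std::map<int, MatmulPlanCache> caches;  // simplified: not thread-safe
  return caches[device_ordinal];
}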

I have also added a corresponding gpublas_lt_matmul_thunk_test to check this functionality.
In addition, I refactored CublasLtCmd, which was largely a verbatim copy-paste of CublasLtMatmulThunk.

Finally, I have removed the magic constant 128 and replaced it with GemmConfig::kNumAlgorithms.

@xla-rotation, could you please have a look?

I have also gathered some stats for LLAMA Maxtext model training with 8 GPUs:

2025-03-05 13:33:04.446763: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:131] Total matmul thunks created: 1039
................
2025-03-05 13:34:48.478186: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:188] 0x558aeca66fd8: Adding new MatmulPlan for stream: 0x558a9ee83a40
2025-03-05 13:34:48.478220: I external/xla/xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.cc:67] Plan created: cache size: 29

So, the XLA runtime created 1039 GpuBlasLt thunk instances while the cache contained only 29 unique entries (because most GEMM configs are duplicates). Hence, running with 8 GPUs, we create at least (1039 - 29) * 8 = 8080 fewer matmul plans than before.

Comment on lines 55 to 56
absl::StatusOr<se::gpu::BlasLt::MatmulPlan *>
GetOrCreate(const std::string& key, Func&& create) {
Member

This API is dangerous because it's really easy to poison the cache.

The problem is that the user of this function can't rely on the fact that the returned MatmulPlan is the same as what create() would have produced. It might be true right now, but it's brittle. If someone changed the implementation of create and we ran two instances of XLA in the same process (which some users do), this could lead to very subtle bugs.
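
For illustration, a minimal stand-alone sketch of this failure mode; the Plan and PlanCache types below are simplified stand-ins for the real MatmulPlan cache, not code from this PR:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Plan { std::string built_from; };  // stand-in for a MatmulPlan

class PlanCache {
 public:
  // Whoever uses a key first decides what the cached plan looks like.
  Plan* GetOrCreate(const std::string& key,
                    const std::function<std::unique_ptr<Plan>()>& create) {
    auto& slot = cache_[key];
    if (!slot) slot = create();
    return slot.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Plan>> cache_;
};

int main() {
  PlanCache cache;
  // XLA instance A populates the entry for key "gemm_0".
  Plan* a = cache.GetOrCreate(
      "gemm_0", [] { return std::make_unique<Plan>(Plan{"A's config"}); });
  // Instance B reuses the key with a *different* factory: its factory is never
  // invoked, so it silently receives A's plan -- the cache-poisoning risk.
  Plan* b = cache.GetOrCreate(
      "gemm_0", [] { return std::make_unique<Plan>(Plan{"B's config"}); });
  std::cout << (a == b) << " " << b->built_from << "\n";  // prints: 1 A's config
}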

Member

I'm wondering whether it's too expensive to always create a MatmulPlan and then keep a list of unique plans in the cache. This would avoid the need for a key as well.

Contributor Author

The idea here was that only GpuBlasLtMatmulThunk can access this cache and that the cache keys are generated using the AutotuneCacheKey() constructor, i.e. the one we also use for cache keys during GEMM tuning (gemm_algorithm_picker.cc). So I assumed it is safe to use it in this context: otherwise our autotuned algorithm ID would not match a MatmulPlan, which would cause a runtime error.

Yes, potentially one can "precreate" this cache during XLA compilation by going through all computations and extracting all cublas$lt$matmul custom calls (e.g. in some extra HLO pass). But we would have to do this for all GpuExecutors created by XLA since MatmulPlanCache is bound to a particular device.

Or maybe I do not quite understand what you mean by always creating a MatmulPlan?

Comment on lines 49 to 50
auto& res = meta[device_id];
if (!res) res.reset(new MatmulPlanCache());
Member

What about different sessions/contexts on the same GPU? They can't share the plan as far as I know, so this might break things for users that run multiple instances of XLA in the same process.

Contributor Author

Previously we used the Stream pointer as a cache key. Each stream is bound to a particular StreamExecutor, which I assume is unique per device (but we could also use a pointer to the StreamExecutor as the meta-cache key instead of the device ordinal to make it safer). MatmulPlans are context-agnostic: the context activation happens when fetching the algorithms and before the actual matmul.

Contributor Author

To make this safer, I have changed the cache interface to use a StreamExecutor pointer as the key. Therefore, if a user runs several XLA instances within the same process, each of them will create a different StreamExecutor, and hence the MatmulPlan caches won't be shared even if they point to the same device ordinal. Though I assume that, most likely, each XLA instance will be bound to a different device anyway (which sounds like the more natural setting).
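
A minimal sketch of that keying, under the assumption that the meta-cache stays a process-wide map; se::StreamExecutor is the real StreamExecutor type, while MatmulPlanCache is treated as an opaque placeholder here and everything else is illustrative only:

#include <memory>

#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "xla/stream_executor/stream_executor.h"

namespace se = stream_executor;  // alias commonly used in XLA sources

class MatmulPlanCache {};  // opaque stand-in for the per-device plan cache

MatmulPlanCache& CacheForExecutor(se::StreamExecutor* executor) {
  static absl::Mutex mu;
  static auto* caches =
      new absl::flat_hash_map<se::StreamExecutor*,
                              std::unique_ptr<MatmulPlanCache>>();
  absl::MutexLock lock(&mu);
  auto& slot = (*caches)[executor];
  if (!slot) slot = std::make_unique<MatmulPlanCache>();
  // Two XLA instances in one process create different StreamExecutors and
  // therefore get distinct caches, even on the same device ordinal.
  return *slot;
}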

@pemeliya pemeliya force-pushed the ci_gpublas_lt_cache_refactor2 branch 2 times, most recently from eb93eb1 to 28f50f8 Compare March 12, 2025 14:43
@pemeliya pemeliya requested a review from beckerhe March 12, 2025 17:09
@pemeliya pemeliya force-pushed the ci_gpublas_lt_cache_refactor2 branch from 28f50f8 to 49ee5c1 Compare March 20, 2025 15:55
@pemeliya pemeliya force-pushed the ci_gpublas_lt_cache_refactor2 branch from 49ee5c1 to 1997f00 Compare April 2, 2025 09:11
@pemeliya
Contributor Author

pemeliya commented Apr 2, 2025

@beckerhe, I wonder what else needs to be done for this PR? Is it still pending internal review? As mentioned, I have added extra unit tests for gpublaslt_matmul_thunk and matmul_plan_cache.

adapted cublaslt

added test

enabled tests on rocm

update

restored command_buffer thunk tests (to be updated in a new PR)

moved matmul plan cache to a separate file

updated matmul_thunk test

updated thunk test

added a separate matmul_plan cache unit test

made SetAlgorithm function non-const
@pemeliya pemeliya force-pushed the ci_gpublas_lt_cache_refactor2 branch from 1997f00 to 6ca1c6d Compare April 2, 2025 15:26