Skip to content

Commit d08b600

Browse files
committed
Fix GPU handle pool singleton aliasing with tag-based template separation
Problem: - BLAS and SOLVER handle pools were sharing the same static singleton instance - C++ template instantiation created only one pool when HandleType/StreamType were identical - This caused resource contamination between hipBLAS and hipSOLVER operations - Led to potential memory corruption and unexpected GPU library behavior Solution: - Added Tag template parameter to HandlePool class with default DefaultTag - Introduced BlasTag and SolverTag for unique pool instantiations - Updated all handle pool typedefs to use three-argument template - Each tag type now gets its own static singleton pool instance
1 parent 3d6b521 commit d08b600

File tree

5 files changed

+29
-18
lines changed

5 files changed

+29
-18
lines changed

jaxlib/gpu/blas_handle_pool.cc

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,4 @@ template <>
4141
return Handle(pool, handle, stream);
4242
}
4343

44-
} // namespace jax
44+
} // namespace jax

jaxlib/gpu/blas_handle_pool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License.
2222

2323
namespace jax {
2424

25-
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
25+
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>;
2626

2727
template <>
2828
absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(

jaxlib/gpu/handle_pool.h

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,20 @@ limitations under the License.
2525

2626
namespace jax {
2727

28+
// Tag types for unique pool instantiations
29+
struct DefaultTag {};
30+
struct BlasTag {};
31+
struct SolverTag {};
32+
2833
// To avoid creating cublas/cusolver contexts in the middle of execution, we
2934
// maintain a pool of them.
30-
template <typename HandleType, typename StreamType>
35+
36+
// The Tag template parameter ensures unique pool instantiations for different
37+
// handle types (BLAS, SOLVER, etc.). Without this tag, C++ template
38+
// instantiation would create a single shared static pool when HandleType and
39+
// StreamType are the same, leading to resource contamination between different
40+
// GPU library contexts (e.g., hipBLAS and hipSOLVER sharing the same pool).
41+
template <typename HandleType, typename StreamType, typename Tag = DefaultTag>
3142
class HandlePool {
3243
public:
3344
HandlePool() = default;
@@ -66,11 +77,11 @@ class HandlePool {
6677
HandleType get() { return handle_; }
6778

6879
private:
69-
friend class HandlePool<HandleType, StreamType>;
70-
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle,
80+
friend class HandlePool<HandleType, StreamType, Tag>;
81+
Handle(HandlePool<HandleType, StreamType, Tag>* pool, HandleType handle,
7182
StreamType stream)
7283
: pool_(pool), handle_(handle), stream_(stream) {}
73-
HandlePool<HandleType, StreamType>* pool_ = nullptr;
84+
HandlePool<HandleType, StreamType, Tag>* pool_ = nullptr;
7485
HandleType handle_ = nullptr;
7586
StreamType stream_ = nullptr;
7687
};
@@ -80,31 +91,32 @@ class HandlePool {
8091
static absl::StatusOr<Handle> Borrow(StreamType stream);
8192

8293
private:
83-
static HandlePool<HandleType, StreamType>* Instance();
94+
static HandlePool<HandleType, StreamType, Tag>* Instance();
8495

8596
void Return(HandleType handle, StreamType stream);
8697

8798
absl::Mutex mu_;
8899
std::map<StreamType, std::vector<HandleType>> handles_ ABSL_GUARDED_BY(mu_);
89100
};
90101

91-
template <typename HandleType, typename StreamType>
92-
/*static*/ HandlePool<HandleType, StreamType>*
93-
HandlePool<HandleType, StreamType>::Instance() {
94-
static auto* pool = new HandlePool<HandleType, StreamType>;
102+
template <typename HandleType, typename StreamType, typename Tag>
103+
/*static*/ HandlePool<HandleType, StreamType, Tag>*
104+
HandlePool<HandleType, StreamType, Tag>::Instance() {
105+
static auto* pool = new HandlePool<HandleType, StreamType, Tag>;
95106
return pool;
96107
}
97108

98-
template <typename HandleType, typename StreamType>
99-
void HandlePool<HandleType, StreamType>::Return(HandleType handle,
100-
StreamType stream) {
109+
template <typename HandleType, typename StreamType, typename Tag>
110+
void HandlePool<HandleType, StreamType, Tag>::Return(HandleType handle,
111+
StreamType stream) {
101112
absl::MutexLock lock(&mu_);
102113
handles_[stream].push_back(handle);
103114
}
104115

105116
// template <typename HandleType, typename StreamType>
106117
// HandlePool<HandleType, StreamType>::Borrow(StreamType stream)
107118

119+
108120
} // namespace jax
109121

110122
#endif // JAXLIB_GPU_HANDLE_POOL_H_

jaxlib/gpu/solver_handle_pool.cc

100644100755
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "jaxlib/gpu/solver_handle_pool.h"
17-
1817
#include "absl/status/statusor.h"
1918
#include "absl/synchronization/mutex.h"
2019
#include "jaxlib/gpu/gpu_kernel_helpers.h"
@@ -40,7 +39,7 @@ template <>
4039
pool->handles_[stream].pop_back();
4140
}
4241
if (stream) {
43-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
42+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
4443
}
4544
return Handle(pool, handle, stream);
4645
}

jaxlib/gpu/solver_handle_pool.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ limitations under the License.
2626

2727
namespace jax {
2828

29-
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
29+
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t, SolverTag>;
3030

3131
template <>
3232
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
3333
gpuStream_t stream);
3434

3535
#ifdef JAX_GPU_CUDA
36-
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t>;
36+
using SpSolverHandlePool = HandlePool<cusolverSpHandle_t, gpuStream_t, SolverTag>;
3737

3838
template <>
3939
absl::StatusOr<SpSolverHandlePool::Handle> SpSolverHandlePool::Borrow(

0 commit comments

Comments
 (0)