Skip to content

Commit 09f9ed0

Browse files
committed
Fix GPU handle pool singleton aliasing with tag-based template separation
Problem: - BLAS and SOLVER handle pools were sharing the same static singleton instance - C++ template instantiation created only one pool when HandleType/StreamType were identical - This caused resource contamination between hipBLAS and hipSOLVER operations - Led to potential memory corruption and unexpected GPU library behavior Solution: - Added Tag template parameter to HandlePool class with default DefaultTag - Introduced BlasTag and SolverTag for unique pool instantiations - Updated all handle pool typedefs to use three-argument template - Each tag type now gets its own static singleton pool instance
1 parent 3d6b521 commit 09f9ed0

File tree

5 files changed

+60
-22
lines changed

5 files changed

+60
-22
lines changed

jaxlib/gpu/blas_handle_pool.cc

100644100755
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ limitations under the License.
2424
namespace jax {
2525

2626
template <>
27-
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
28-
gpuStream_t stream) {
29-
BlasHandlePool* pool = Instance();
27+
/*static*/ absl::StatusOr<HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>::Handle>
28+
HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>::Borrow(gpuStream_t stream) {
29+
auto* pool = Instance(HandleKind::BLAS);
3030
absl::MutexLock lock(&pool->mu_);
3131
gpublasHandle_t handle;
3232
if (pool->handles_[stream].empty()) {
@@ -36,7 +36,11 @@ template <>
3636
pool->handles_[stream].pop_back();
3737
}
3838
if (stream) {
39-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
39+
if (pool->kind_ == HandleKind::BLAS) {
40+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
41+
} else {
42+
return absl::InternalError("BlasHandlePool kind is not BLAS");
43+
}
4044
}
4145
return Handle(pool, handle, stream);
4246
}

jaxlib/gpu/blas_handle_pool.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
// Tag type for BLAS pool
3+
struct BlasTag {};
14
/* Copyright 2024 The JAX Authors.
25
36
Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +25,7 @@ limitations under the License.
2225

2326
namespace jax {
2427

25-
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
28+
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>;
2629

2730
template <>
2831
absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(

jaxlib/gpu/handle_pool.h

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,30 @@ limitations under the License.
2424
#include "absl/synchronization/mutex.h"
2525

2626
namespace jax {
27+
// Enum to distinguish between pool types
28+
enum class HandleKind {
29+
UNKNOWN,
30+
BLAS,
31+
SOLVER
32+
};
33+
34+
// Tag types for unique pool instantiations
35+
struct DefaultTag {};
36+
struct BlasTag {};
37+
struct SolverTag {};
2738

2839
// To avoid creating cublas/cusolver contexts in the middle of execution, we
2940
// maintain a pool of them.
30-
template <typename HandleType, typename StreamType>
41+
42+
// The Tag template parameter ensures unique pool instantiations for different
43+
// handle types (BLAS, SOLVER, etc.). Without this tag, C++ template
44+
// instantiation would create a single shared static pool when HandleType and
45+
// StreamType are the same, leading to resource contamination between different
46+
// GPU library contexts (e.g., hipBLAS and hipSOLVER sharing the same pool).
47+
template <typename HandleType, typename StreamType, typename Tag = DefaultTag>
3148
class HandlePool {
49+
HandleKind kind() const { return kind_; }
50+
void set_kind(HandleKind kind) { kind_ = kind; }
3251
public:
3352
HandlePool() = default;
3453

@@ -66,11 +85,11 @@ class HandlePool {
6685
HandleType get() { return handle_; }
6786

6887
private:
69-
friend class HandlePool<HandleType, StreamType>;
70-
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle,
88+
friend class HandlePool<HandleType, StreamType, Tag>;
89+
Handle(HandlePool<HandleType, StreamType, Tag>* pool, HandleType handle,
7190
StreamType stream)
7291
: pool_(pool), handle_(handle), stream_(stream) {}
73-
HandlePool<HandleType, StreamType>* pool_ = nullptr;
92+
HandlePool<HandleType, StreamType, Tag>* pool_ = nullptr;
7493
HandleType handle_ = nullptr;
7594
StreamType stream_ = nullptr;
7695
};
@@ -80,31 +99,40 @@ class HandlePool {
8099
static absl::StatusOr<Handle> Borrow(StreamType stream);
81100

82101
private:
83-
static HandlePool<HandleType, StreamType>* Instance();
102+
static HandlePool<HandleType, StreamType, Tag>* Instance(HandleKind kind = HandleKind::UNKNOWN);
84103

85104
void Return(HandleType handle, StreamType stream);
86105

87106
absl::Mutex mu_;
88107
std::map<StreamType, std::vector<HandleType>> handles_ ABSL_GUARDED_BY(mu_);
108+
HandleKind kind_ = HandleKind::UNKNOWN;
89109
};
90110

91-
template <typename HandleType, typename StreamType>
92-
/*static*/ HandlePool<HandleType, StreamType>*
93-
HandlePool<HandleType, StreamType>::Instance() {
94-
static auto* pool = new HandlePool<HandleType, StreamType>;
95-
return pool;
111+
template <typename HandleType, typename StreamType, typename Tag>
112+
/*static*/ HandlePool<HandleType, StreamType, Tag>*
113+
HandlePool<HandleType, StreamType, Tag>::Instance(HandleKind kind) {
114+
static std::map<HandleKind, HandlePool<HandleType, StreamType, Tag>*> pools;
115+
auto it = pools.find(kind);
116+
if (it == pools.end()) {
117+
auto* pool = new HandlePool<HandleType, StreamType, Tag>;
118+
pool->set_kind(kind);
119+
pools[kind] = pool;
120+
return pool;
121+
}
122+
return it->second;
96123
}
97124

98-
template <typename HandleType, typename StreamType>
99-
void HandlePool<HandleType, StreamType>::Return(HandleType handle,
100-
StreamType stream) {
125+
template <typename HandleType, typename StreamType, typename Tag>
126+
void HandlePool<HandleType, StreamType, Tag>::Return(HandleType handle,
127+
StreamType stream) {
101128
absl::MutexLock lock(&mu_);
102129
handles_[stream].push_back(handle);
103130
}
104131

105132
// template <typename HandleType, typename StreamType>
106133
// HandlePool<HandleType, StreamType>::Borrow(StreamType stream)
107134

135+
108136
} // namespace jax
109137

110138
#endif // JAXLIB_GPU_HANDLE_POOL_H_

jaxlib/gpu/solver_handle_pool.cc

100644100755
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "jaxlib/gpu/solver_handle_pool.h"
17-
1817
#include "absl/status/statusor.h"
1918
#include "absl/synchronization/mutex.h"
2019
#include "jaxlib/gpu/gpu_kernel_helpers.h"
@@ -30,7 +29,7 @@ namespace jax {
3029
template <>
3130
/*static*/ absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
3231
gpuStream_t stream) {
33-
SolverHandlePool* pool = Instance();
32+
SolverHandlePool* pool = Instance(HandleKind::SOLVER);
3433
absl::MutexLock lock(&pool->mu_);
3534
gpusolverDnHandle_t handle;
3635
if (pool->handles_[stream].empty()) {
@@ -40,7 +39,11 @@ template <>
4039
pool->handles_[stream].pop_back();
4140
}
4241
if (stream) {
43-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
42+
if (pool->kind_ == HandleKind::SOLVER) {
43+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
44+
} else {
45+
return absl::InternalError("SolverHandlePool kind is not SOLVER");
46+
}
4447
}
4548
return Handle(pool, handle, stream);
4649
}

jaxlib/gpu/solver_handle_pool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ limitations under the License.
2626

2727
namespace jax {
2828

29-
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
29+
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t, SolverTag>;
3030

3131
template <>
3232
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(

0 commit comments

Comments
 (0)