Skip to content

Commit 33c4da7

Browse files
committed
Fix GPU handle pool singleton aliasing with tag-based template separation
Problem:
- The BLAS and SOLVER handle pools were sharing the same static singleton instance.
- C++ template instantiation created only one pool when HandleType/StreamType were identical.
- This caused resource contamination between hipBLAS and hipSOLVER operations.
- This led to potential memory corruption and unexpected GPU library behavior.

Solution:
- Added a Tag template parameter to the HandlePool class, defaulting to DefaultTag.
- Introduced BlasTag and SolverTag for unique pool instantiations.
- Updated all handle pool typedefs to use the three-argument template.
- Each tag type now gets its own static singleton pool instance.
1 parent 3d6b521 commit 33c4da7

File tree

5 files changed

+62
-22
lines changed

5 files changed

+62
-22
lines changed

jaxlib/gpu/blas_handle_pool.cc

File mode changed: 100644 → 100755
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ limitations under the License.
2424
namespace jax {
2525

2626
template <>
27-
/*static*/ absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(
28-
gpuStream_t stream) {
29-
BlasHandlePool* pool = Instance();
27+
/*static*/ absl::StatusOr<HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>::Handle>
28+
HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>::Borrow(gpuStream_t stream) {
29+
auto* pool = Instance(HandleKind::BLAS);
3030
absl::MutexLock lock(&pool->mu_);
3131
gpublasHandle_t handle;
3232
if (pool->handles_[stream].empty()) {
@@ -36,7 +36,11 @@ template <>
3636
pool->handles_[stream].pop_back();
3737
}
3838
if (stream) {
39-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
39+
if (pool->kind_ == HandleKind::BLAS) {
40+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpublasSetStream(handle, stream)));
41+
} else {
42+
return absl::InternalError("BlasHandlePool kind is not BLAS");
43+
}
4044
}
4145
return Handle(pool, handle, stream);
4246
}

jaxlib/gpu/blas_handle_pool.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
// Tag type for BLAS pool
3+
struct BlasTag {};
14
/* Copyright 2024 The JAX Authors.
25
36
Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +25,7 @@ limitations under the License.
2225

2326
namespace jax {
2427

25-
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t>;
28+
using BlasHandlePool = HandlePool<gpublasHandle_t, gpuStream_t, BlasTag>;
2629

2730
template <>
2831
absl::StatusOr<BlasHandlePool::Handle> BlasHandlePool::Borrow(

jaxlib/gpu/handle_pool.h

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,38 @@ limitations under the License.
1818

1919
#include <map>
2020
#include <vector>
21+
#include <iostream>
22+
#include <typeinfo>
2123

2224
#include "absl/base/thread_annotations.h"
2325
#include "absl/status/statusor.h"
2426
#include "absl/synchronization/mutex.h"
2527

2628
namespace jax {
29+
// Enum to distinguish between pool types
30+
enum class HandleKind {
31+
UNKNOWN,
32+
BLAS,
33+
SOLVER
34+
};
35+
36+
// Tag types for unique pool instantiations
37+
struct DefaultTag {};
38+
struct BlasTag {};
39+
struct SolverTag {};
2740

2841
// To avoid creating cublas/cusolver contexts in the middle of execution, we
2942
// maintain a pool of them.
30-
template <typename HandleType, typename StreamType>
43+
44+
// The Tag template parameter ensures unique pool instantiations for different
45+
// handle types (BLAS, SOLVER, etc.). Without this tag, C++ template
46+
// instantiation would create a single shared static pool when HandleType and
47+
// StreamType are the same, leading to resource contamination between different
48+
// GPU library contexts (e.g., hipBLAS and hipSOLVER sharing the same pool).
49+
template <typename HandleType, typename StreamType, typename Tag = DefaultTag>
3150
class HandlePool {
51+
HandleKind kind() const { return kind_; }
52+
void set_kind(HandleKind kind) { kind_ = kind; }
3253
public:
3354
HandlePool() = default;
3455

@@ -66,11 +87,11 @@ class HandlePool {
6687
HandleType get() { return handle_; }
6788

6889
private:
69-
friend class HandlePool<HandleType, StreamType>;
70-
Handle(HandlePool<HandleType, StreamType>* pool, HandleType handle,
90+
friend class HandlePool<HandleType, StreamType, Tag>;
91+
Handle(HandlePool<HandleType, StreamType, Tag>* pool, HandleType handle,
7192
StreamType stream)
7293
: pool_(pool), handle_(handle), stream_(stream) {}
73-
HandlePool<HandleType, StreamType>* pool_ = nullptr;
94+
HandlePool<HandleType, StreamType, Tag>* pool_ = nullptr;
7495
HandleType handle_ = nullptr;
7596
StreamType stream_ = nullptr;
7697
};
@@ -80,31 +101,40 @@ class HandlePool {
80101
static absl::StatusOr<Handle> Borrow(StreamType stream);
81102

82103
private:
83-
static HandlePool<HandleType, StreamType>* Instance();
104+
static HandlePool<HandleType, StreamType, Tag>* Instance(HandleKind kind = HandleKind::UNKNOWN);
84105

85106
void Return(HandleType handle, StreamType stream);
86107

87108
absl::Mutex mu_;
88109
std::map<StreamType, std::vector<HandleType>> handles_ ABSL_GUARDED_BY(mu_);
110+
HandleKind kind_ = HandleKind::UNKNOWN;
89111
};
90112

91-
template <typename HandleType, typename StreamType>
92-
/*static*/ HandlePool<HandleType, StreamType>*
93-
HandlePool<HandleType, StreamType>::Instance() {
94-
static auto* pool = new HandlePool<HandleType, StreamType>;
95-
return pool;
113+
template <typename HandleType, typename StreamType, typename Tag>
114+
/*static*/ HandlePool<HandleType, StreamType, Tag>*
115+
HandlePool<HandleType, StreamType, Tag>::Instance(HandleKind kind) {
116+
static std::map<HandleKind, HandlePool<HandleType, StreamType, Tag>*> pools;
117+
auto it = pools.find(kind);
118+
if (it == pools.end()) {
119+
auto* pool = new HandlePool<HandleType, StreamType, Tag>;
120+
pool->set_kind(kind);
121+
pools[kind] = pool;
122+
return pool;
123+
}
124+
return it->second;
96125
}
97126

98-
template <typename HandleType, typename StreamType>
99-
void HandlePool<HandleType, StreamType>::Return(HandleType handle,
100-
StreamType stream) {
127+
template <typename HandleType, typename StreamType, typename Tag>
128+
void HandlePool<HandleType, StreamType, Tag>::Return(HandleType handle,
129+
StreamType stream) {
101130
absl::MutexLock lock(&mu_);
102131
handles_[stream].push_back(handle);
103132
}
104133

105134
// template <typename HandleType, typename StreamType>
106135
// HandlePool<HandleType, StreamType>::Borrow(StreamType stream)
107136

137+
108138
} // namespace jax
109139

110140
#endif // JAXLIB_GPU_HANDLE_POOL_H_

jaxlib/gpu/solver_handle_pool.cc

File mode changed: 100644 → 100755
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "jaxlib/gpu/solver_handle_pool.h"
17-
1817
#include "absl/status/statusor.h"
1918
#include "absl/synchronization/mutex.h"
2019
#include "jaxlib/gpu/gpu_kernel_helpers.h"
@@ -30,7 +29,7 @@ namespace jax {
3029
template <>
3130
/*static*/ absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
3231
gpuStream_t stream) {
33-
SolverHandlePool* pool = Instance();
32+
SolverHandlePool* pool = Instance(HandleKind::SOLVER);
3433
absl::MutexLock lock(&pool->mu_);
3534
gpusolverDnHandle_t handle;
3635
if (pool->handles_[stream].empty()) {
@@ -40,7 +39,11 @@ template <>
4039
pool->handles_[stream].pop_back();
4140
}
4241
if (stream) {
43-
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
42+
if (pool->kind_ == HandleKind::SOLVER) {
43+
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSetStream(handle, stream)));
44+
} else {
45+
return absl::InternalError("SolverHandlePool kind is not SOLVER");
46+
}
4447
}
4548
return Handle(pool, handle, stream);
4649
}

jaxlib/gpu/solver_handle_pool.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ limitations under the License.
2626

2727
namespace jax {
2828

29-
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t>;
29+
using SolverHandlePool = HandlePool<gpusolverDnHandle_t, gpuStream_t, SolverTag>;
3030

3131
template <>
3232
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(

0 commit comments

Comments
 (0)