@@ -70,11 +70,7 @@ uct_cuda_ipc_get_dev_cache(uct_cuda_ipc_component_t *component,
70
70
int ret ;
71
71
72
72
key .uuid = rkey -> uuid ;
73
- #if HAVE_CUDA_FABRIC
74
73
key .type = rkey -> ph .handle_type ;
75
- #else
76
- key .type = 0 ;
77
- #endif
78
74
79
75
iter = kh_put (cuda_ipc_uuid_hash , hash , key , & ret );
80
76
if (ret == UCS_KH_PUT_KEY_PRESENT ) {
@@ -112,11 +108,10 @@ static ucs_status_t
112
108
uct_cuda_ipc_mem_add_reg (void * addr , uct_cuda_ipc_memh_t * memh ,
113
109
uct_cuda_ipc_lkey_t * * key_p )
114
110
{
115
- CUipcMemHandle * legacy_handle ;
116
111
uct_cuda_ipc_lkey_t * key ;
117
112
ucs_status_t status ;
118
113
#if HAVE_CUDA_FABRIC
119
- #define UCT_CUDA_IPC_QUERY_NUM_ATTRS 2
114
+ #define UCT_CUDA_IPC_QUERY_NUM_ATTRS 3
120
115
CUmemGenericAllocationHandle handle ;
121
116
CUmemoryPool mempool ;
122
117
CUpointer_attribute attr_type [UCT_CUDA_IPC_QUERY_NUM_ATTRS ];
@@ -130,10 +125,16 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
130
125
return UCS_ERR_NO_MEMORY ;
131
126
}
132
127
133
- legacy_handle = (CUipcMemHandle * )& key -> ph ;
134
128
UCT_CUDADRV_FUNC_LOG_ERR (cuMemGetAddressRange (& key -> d_bptr , & key -> b_len ,
135
129
(CUdeviceptr )addr ));
136
130
131
+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuPointerGetAttribute (& key -> ph .buffer_id ,
132
+ CU_POINTER_ATTRIBUTE_BUFFER_ID ,
133
+ (CUdeviceptr )addr ));
134
+ if (status != UCS_OK ) {
135
+ goto err ;
136
+ }
137
+
137
138
#if HAVE_CUDA_FABRIC
138
139
/* cuda_ipc can handle VMM, mallocasync, and legacy pinned device so need to
139
140
* pack appropriate handle */
@@ -142,6 +143,8 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
142
143
attr_data [0 ] = & legacy_capable ;
143
144
attr_type [1 ] = CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES ;
144
145
attr_data [1 ] = & allowed_handle_types ;
146
+ attr_type [2 ] = CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE ;
147
+ attr_data [2 ] = & mempool ;
145
148
146
149
status = UCT_CUDADRV_FUNC_LOG_ERR (
147
150
cuPointerGetAttributes (ucs_static_array_size (attr_data ), attr_type ,
@@ -151,8 +154,6 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
151
154
}
152
155
153
156
if (legacy_capable ) {
154
- key -> ph .handle_type = UCT_CUDA_IPC_KEY_HANDLE_TYPE_LEGACY ;
155
- legacy_handle = & key -> ph .handle .legacy ;
156
157
goto legacy_path ;
157
158
}
158
159
@@ -184,9 +185,7 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
184
185
goto common_path ;
185
186
}
186
187
187
- status = UCT_CUDADRV_FUNC_LOG_ERR (cuPointerGetAttribute (& mempool ,
188
- CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE , (CUdeviceptr )addr ));
189
- if ((status != UCS_OK ) || (mempool == 0 )) {
188
+ if (mempool == 0 ) {
190
189
/* cuda_ipc can only handle UCS_MEMORY_TYPE_CUDA, which has to be either
191
190
* legacy type, or VMM type, or mempool type. Return error if memory
192
191
* does not belong to any of the three types */
@@ -216,16 +215,18 @@ uct_cuda_ipc_mem_add_reg(void *addr, uct_cuda_ipc_memh_t *memh,
216
215
goto common_path ;
217
216
#endif
218
217
legacy_path :
219
- status = UCT_CUDADRV_FUNC (cuIpcGetMemHandle (legacy_handle , (CUdeviceptr )addr ),
220
- UCS_LOG_LEVEL_ERROR );
218
+ key -> ph .handle_type = UCT_CUDA_IPC_KEY_HANDLE_TYPE_LEGACY ;
219
+ status = UCT_CUDADRV_FUNC_LOG_ERR (
220
+ cuIpcGetMemHandle (& key -> ph .handle .legacy , (CUdeviceptr )addr ));
221
221
if (status != UCS_OK ) {
222
222
goto err ;
223
223
}
224
224
225
225
common_path :
226
226
ucs_list_add_tail (& memh -> list , & key -> link );
227
- ucs_trace ("registered addr:%p/%p length:%zd dev_num:%d" ,
228
- addr , (void * )key -> d_bptr , key -> b_len , (int )memh -> dev_num );
227
+ ucs_trace ("registered addr:%p/%p length:%zd dev_num:%d buffer_id:%llu" ,
228
+ addr , (void * )key -> d_bptr , key -> b_len , (int )memh -> dev_num ,
229
+ key -> ph .buffer_id );
229
230
230
231
* key_p = key ;
231
232
return UCS_OK ;
0 commit comments