File tree Expand file tree Collapse file tree 1 file changed +16
-4
lines changed
Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -196,10 +196,6 @@ def get_used_memory(self):
196196 peak_memory = get_used_memory_func (
197197 tvm .device ("cuda" , 0 )
198198 ).debug_get_from_remote (0 )
199-
200- # TODO: temp hack to switch the VM allocator to eager recycling mode on all devices
201- for i in range (1 , self .num_shards ):
202- get_used_memory_func (tvm .device ("cuda" , i )).debug_get_from_remote (i )
203199 else :
204200 params = self .params
205201
@@ -230,7 +226,23 @@ def profile_memory_usage(self, seq_lens):
230226 positions = copy_to_worker_0 (self .disco_session , positions )
231227 seq_lens = copy_to_worker_0 (self .disco_session , seq_lens )
232228
229+ start_profiling_func = self .disco_session .get_global_func (
230+ "vm.memory_manager.start_profiling"
231+ )
232+ stop_profiling_func = self .disco_session .get_global_func (
233+ "vm.memory_manager.stop_profiling"
234+ )
235+ else :
236+ start_profiling_func = tvm .get_global_func (
237+ "vm.memory_manager.start_profiling"
238+ )
239+ stop_profiling_func = tvm .get_global_func (
240+ "vm.memory_manager.stop_profiling"
241+ )
242+
243+ start_profiling_func ()
233244 self .mod ["evaluate" ](input_ids , positions , seq_lens , self .params )
245+ stop_profiling_func ()
234246
235247 return self .get_used_memory ()
236248
You can’t perform that action at this time.
0 commit comments