Skip to content

Commit 4154bfe

Browse files
committed
toggle memory profiling start/stop explicitly
1 parent bd0e221 commit 4154bfe

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

serve/mlc_serve/model/tvm_model.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,6 @@ def get_used_memory(self):
196196
peak_memory = get_used_memory_func(
197197
tvm.device("cuda", 0)
198198
).debug_get_from_remote(0)
199-
200-
# TODO: temp hack to switch the VM allocator to eager recycling mode on all devices
201-
for i in range(1, self.num_shards):
202-
get_used_memory_func(tvm.device("cuda", i)).debug_get_from_remote(i)
203199
else:
204200
params = self.params
205201

@@ -230,7 +226,23 @@ def profile_memory_usage(self, seq_lens):
230226
positions = copy_to_worker_0(self.disco_session, positions)
231227
seq_lens = copy_to_worker_0(self.disco_session, seq_lens)
232228

229+
start_profiling_func = self.disco_session.get_global_func(
230+
"vm.memory_manager.start_profiling"
231+
)
232+
stop_profiling_func = self.disco_session.get_global_func(
233+
"vm.memory_manager.stop_profiling"
234+
)
235+
else:
236+
start_profiling_func = tvm.get_global_func(
237+
"vm.memory_manager.start_profiling"
238+
)
239+
stop_profiling_func = tvm.get_global_func(
240+
"vm.memory_manager.stop_profiling"
241+
)
242+
243+
start_profiling_func()
233244
self.mod["evaluate"](input_ids, positions, seq_lens, self.params)
245+
stop_profiling_func()
234246

235247
return self.get_used_memory()
236248

0 commit comments

Comments
 (0)