diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index f20370275b..653da250d8 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -196,10 +196,6 @@ def get_used_memory(self):
             peak_memory = get_used_memory_func(
                 tvm.device("cuda", 0)
             ).debug_get_from_remote(0)
-
-            # TODO: temp hack to switch the VM allocator to eager recycling mode on all devices
-            for i in range(1, self.num_shards):
-                get_used_memory_func(tvm.device("cuda", i)).debug_get_from_remote(i)
         else:
             params = self.params
 
@@ -230,7 +226,23 @@ def profile_memory_usage(self, seq_lens):
             input_ids = copy_to_worker_0(self.disco_session, input_ids)
             positions = copy_to_worker_0(self.disco_session, positions)
             seq_lens = copy_to_worker_0(self.disco_session, seq_lens)
 
+            start_profiling_func = self.disco_session.get_global_func(
+                "vm.memory_manager.start_profiling"
+            )
+            stop_profiling_func = self.disco_session.get_global_func(
+                "vm.memory_manager.stop_profiling"
+            )
+        else:
+            start_profiling_func = tvm.get_global_func(
+                "vm.memory_manager.start_profiling"
+            )
+            stop_profiling_func = tvm.get_global_func(
+                "vm.memory_manager.stop_profiling"
+            )
+
+        start_profiling_func()
         self.mod["evaluate"](input_ids, positions, seq_lens, self.params)
+        stop_profiling_func()
 
         return self.get_used_memory()