From ddfbcda4ebd82058855443a5b26a9011ddc026fc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 21 Mar 2024 16:45:15 -0400 Subject: [PATCH] [iOS][Android] Add validation of library file for iOS and Android build (#1993) This PR adds validation of symbols in iOS and android build. During static library build, we need the right model_lib for us to point to the packaged model executables. Not doing so correctly will results in vm_load_executable not found which is not informative. This PR we validate the compiled model lib by dumping the global symbols and ensure the list of model libs matches with each other. In future we should perhaps lift the validation to mlc_llm package. --- android/library/prepare_model_lib.py | 62 +++++++++++++++++- .../library/src/main/assets/app-config.json | 12 ++-- ios/prepare_model_lib.py | 64 ++++++++++++++++++- 3 files changed, 126 insertions(+), 12 deletions(-) diff --git a/android/library/prepare_model_lib.py b/android/library/prepare_model_lib.py index 9363be74c8..dc14397a16 100644 --- a/android/library/prepare_model_lib.py +++ b/android/library/prepare_model_lib.py @@ -3,20 +3,76 @@ from tvm.contrib import ndk +def get_model_libs(lib_path): + global_symbol_map = ndk.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name in global_symbol_map.keys(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + def main(): - app_config = json.load(open("src/main/assets/app-config.json", "r")) + app_config_path = "src/main/assets/app-config.json" + app_config = json.load(open(app_config_path, "r")) artifact_path = os.path.abspath(os.path.join("../..", "dist")) tar_list = [] + model_set = set() - for model_lib_path in app_config["model_lib_path_for_prepare_libs"].values(): + for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): path = os.path.join(artifact_path, model_lib_path) if not os.path.isfile(path): raise RuntimeError(f"Cannot find android library {path}") tar_list.append(path) + model_set.add(model) - ndk.create_staticlib(os.path.join("build", "model_lib", "libmodel_android.a"), tar_list) + lib_path = os.path.join("build", "model_lib", "libmodel_android.a") + ndk.create_staticlib(lib_path, tar_list) print(f"Creating lib from {tar_list}..") + available_model_libs = get_model_libs(lib_path) + print(f"Validating the library {lib_path}...") + print( + f"List of available model libs packaged: {available_model_libs}," + " if we have '-' in the model_lib string, it will be turned into '_'" + ) + global_symbol_map = ndk.get_global_symbol_section_map(lib_path) + error_happened = False + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + print( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs field, " + "This will cause the specific model not being able to load, " + f"please check {app_config_path}." + ) + error_happened = True + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + print( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" + f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" + ) + error_happened = True + + if not error_happened: + print("Validation pass") + else: + print("Validation failed") + exit(255) + if __name__ == "__main__": main() diff --git a/android/library/src/main/assets/app-config.json b/android/library/src/main/assets/app-config.json index 8dcdf6dabf..68442c234e 100644 --- a/android/library/src/main/assets/app-config.json +++ b/android/library/src/main/assets/app-config.json @@ -26,16 +26,16 @@ }, { "model_url": "https://huggingface.co/mlc-ai/phi-2-q4f16_1-MLC", - "model_lib": "phi_q4f16_1", + "model_lib": "phi_msft_q4f16_1", "estimated_vram_bytes": 2036816936, "model_id": "phi-2-q4f16_1" } ], "model_lib_path_for_prepare_libs": { - "gemma_q4f16_1": "prebuilt_libs/gemma-2b-it/gemma-2b-it-q4f16_1-android.tar", - "llama_q4f16_1": "prebuilt_libs/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-android.tar", - "gpt_neox_q4f16_1": "prebuilt_libs/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar", - "phi_q4f16_1": "prebuilt_libs/phi-2/phi-2-q4f16_1-android.tar", - "Mistral-7B-Instruct-v0.2-q4f16_1": "prebuilt_libs/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar" + "gemma_q4f16_1": "prebuilt/lib/gemma-2b-it/gemma-2b-it-q4f16_1-android.tar", + "llama_q4f16_1": "prebuilt/lib/Llama-2-7b-chat-hf/Llama-2-7b-chat-hf-q4f16_1-android.tar", + "gpt_neox_q4f16_1": "prebuilt/lib/RedPajama-INCITE-Chat-3B-v1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar", + "phi_msft_q4f16_1": "prebuilt/lib/phi-2/phi-2-q4f16_1-android.tar", + "mistral_q4f16_1": "prebuilt/lib/Mistral-7B-Instruct-v0.2/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar" } } \ No newline at end of file diff --git a/ios/prepare_model_lib.py b/ios/prepare_model_lib.py index 1db56cd08a..0e66879ddc 100644 --- a/ios/prepare_model_lib.py +++ b/ios/prepare_model_lib.py @@ -1,13 +1,29 @@ import json import os +import sys from tvm.contrib import cc +def get_model_libs(lib_path): + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + libs = [] + suffix = "___tvm_dev_mblob" + for name in global_symbol_map.keys(): + if name.endswith(suffix): + model_lib = name[: -len(suffix)] + if model_lib.startswith("_"): + model_lib = model_lib[1:] + libs.append(model_lib) + return libs + + def main(): - app_config = json.load(open("MLCChat/app-config.json", "r")) + app_config_path = "MLCChat/app-config.json" + app_config = json.load(open(app_config_path, "r")) artifact_path = os.path.abspath(os.path.join("..", "dist")) tar_list = [] + model_set = set() for model, model_lib_path in app_config["model_lib_path_for_prepare_libs"].items(): paths = [ @@ -20,10 +36,52 @@ def main(): raise RuntimeError( f"Cannot find iOS lib for {model} from the following candidate paths: {paths}" ) - tar_list.append(valid_paths[0]) + tar_list.append(valid_paths[ls0]) + model_set.add(model) - cc.create_staticlib(os.path.join("build", "lib", "libmodel_iphone.a"), tar_list) + lib_path = os.path.join("build", "lib", "libmodel_iphone.a") + + cc.create_staticlib(lib_path, tar_list) + available_model_libs = get_model_libs(lib_path) print(f"Creating lib from {tar_list}..") + print(f"Validating the library {lib_path}...") + print( + f"List of available model libs packaged: {available_model_libs}," + " if we have '-' in the model_lib string, it will be turned into '_'" + ) + global_symbol_map = cc.get_global_symbol_section_map(lib_path) + error_happened = False + for item in app_config["model_list"]: + model_lib = item["model_lib"] + model_id = item["model_id"] + if model_lib not in model_set: + print( + f"ValidationError: model_lib={model_lib} specified for model_id={model_id} " + "is not included in model_lib_path_for_prepare_libs field, " + "This will cause the specific model not being able to load, " + f"please check {app_config_path}." + ) + error_happened = True + + model_prefix_pattern = model_lib.replace("-", "_") + "___tvm_dev_mblob" + if ( + model_prefix_pattern not in global_symbol_map + and "_" + model_prefix_pattern not in global_symbol_map + ): + model_lib_path = app_config["model_lib_path_for_prepare_libs"][model_lib] + print( + "ValidationError:\n" + f"\tmodel_lib {model_lib} requested in {app_config_path} is not found in {lib_path}\n" + f"\tspecifically the model_lib for {model_lib_path} in model_lib_path_for_prepare_libs.\n" + f"\tcurrent available model_libs in {lib_path}: {available_model_libs}" + ) + error_happened = True + + if not error_happened: + print("Validation pass") + else: + print("Validation failed") + exit(255) if __name__ == "__main__":