diff --git a/senta/train.py b/senta/train.py index f0c4056..8789241 100644 --- a/senta/train.py +++ b/senta/train.py @@ -180,7 +180,8 @@ def get_support_task(self): tasks = list(self._params.get("task_name").keys()) return tasks - def init_model(self, model_class="ernie_1.0_skep_large_ch", task="sentiment_classify", use_cuda=False): + def init_model( + self, model_class="ernie_1.0_skep_large_ch", task="sentiment_classify", use_cuda=False, load_locally=False): """ init_model """ @@ -192,7 +193,8 @@ def init_model(self, model_class="ernie_1.0_skep_large_ch", task="sentiment_clas # step 1: get_init_model, if download data_url = model_dict.get("model_file_http_url") md5_url = model_dict.get("model_md5_http_url") - is_download_data = download_data(data_url, md5_url) + if not load_locally: + is_download_data = download_data(data_url, md5_url) # step 2 get model_class register.import_modules() diff --git a/setup.py b/setup.py index 2821dbd..845508d 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ setuptools.setup( name="Senta", - version="2.0.0", + version="2.0.1", author="Baidu NLP", author_email="gaocan01@baidu.com", description="A sentiment classification tools made by Baidu NLP.",