diff --git a/.gitignore b/.gitignore index fef0fd4..44a9087 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Model downloads +models/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -139,4 +142,4 @@ cython_debug/ #others nohup.out -.DS_Store \ No newline at end of file +.DS_Store diff --git a/pyproject.toml b/pyproject.toml index b37c972..59ce077 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,18 @@ -[metadata] +[project] name = 'zeroshot_topics' version = '0.0.1' description = 'Topic Inference with Zeroshot models' author = 'AnjanaRita' author_email = 'ritaanjana1993@gmail.com' -license = 'MIT/Apache-2.0' +license = 'MIT' url = 'https://github.com/AnjanaRita/zeroshot_topics' [requires] python_version = ['2.7', '3.5', '3.6', 'pypy', 'pypy3'] [build-system] -requires = ['setuptools', 'wheel'] +build-backend = 'hatchling.build' +requires = ['hatchling', 'setuptools', 'wheel'] [tool.hatch.commands] prerelease = 'hatch build' diff --git a/requirements.txt b/requirements.txt index 9cf8170..1e006e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -attrs==20.3.0 -nltk==3.6.5 -keybert==0.5.0 -transformers==4.11.0 \ No newline at end of file +attrs==22.1.0 +nltk==3.7 +keybert==0.7.0 +transformers==4.25.1 diff --git a/zeroshot_topics/zeroshot_tm.py b/zeroshot_topics/zeroshot_tm.py index 7555c6c..2318b74 100644 --- a/zeroshot_topics/zeroshot_tm.py +++ b/zeroshot_topics/zeroshot_tm.py @@ -7,7 +7,8 @@ from nltk.corpus import wordnet as wn -classifier = load_zeroshot_model() +classifier = None +# load_zeroshot_model() @attr.s @@ -22,10 +23,15 @@ class ZeroShotTopicFinder: """ model = attr.ib(default='all-MiniLM-L6-v2') + classifier_model = attr.ib(default='valhalla/distilbart-mnli-12-6') + def __attrs_post_init__(self): self.model = KeyBERT(self.model) + global classifier + classifier = classifier or load_zeroshot_model(self.classifier_model) + def find_topic(self, text, n_topic=2): """ Infer the topic in a given string.