typing: add initial types (#488)
johnfraney authored Jan 20, 2025
1 parent 909d4c8 commit 5765091
Showing 11 changed files with 152 additions and 96 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ Source = "https://github.com/sloria/TextBlob"
 [project.optional-dependencies]
 docs = ["sphinx==8.1.3", "sphinx-issues==5.0.0", "PyYAML==6.0.2"]
 tests = ["pytest", "numpy"]
-dev = ["textblob[tests]", "tox", "pre-commit>=3.5,<5.0"]
+dev = ["textblob[tests]", "tox", "pre-commit>=3.5,<5.0", "pyright", "ruff"]
 
 [build-system]
 requires = ["flit_core<4"]
@@ -86,6 +86,7 @@ select = [
   "I",  # isort
   "UP", # pyupgrade
   "W",  # pycodestyle warning
+  "TC", # flake8-type-checking
 ]
 
 [tool.ruff.lint.per-file-ignores]
@@ -96,3 +97,6 @@ markers = [
   "slow: marks tests as slow (deselect with '-m \"not slow\"')",
   "numpy: marks tests that require numpy",
 ]
+
+[tool.pyright]
+include = ["src/**", "tests/**"]
23 changes: 11 additions & 12 deletions src/textblob/_text.py
@@ -124,7 +124,7 @@ def keys(self):
     def values(self):
         return self._lazy("values")
 
-    def update(self, *args):
+    def update(self, *args, **kwargs):
         return self._lazy("update", *args)
 
     def pop(self, *args):
@@ -324,10 +324,10 @@ def penntreebank2universal(token, tag):
     ("cry", -1.00): set((":'(", ":'''(", ";'(")),
 }
 
-RE_EMOTICONS = [
+TEMP_RE_EMOTICONS = [
     r" ?".join([re.escape(each) for each in e]) for v in EMOTICONS.values() for e in v
 ]
-RE_EMOTICONS = re.compile(r"(%s)($|\s)" % "|".join(RE_EMOTICONS))
+RE_EMOTICONS = re.compile(r"(%s)($|\s)" % "|".join(TEMP_RE_EMOTICONS))
 
 # Handle sarcasm punctuation (!).
 RE_SARCASM = re.compile(r"\( ?\! ?\)")
@@ -490,9 +490,9 @@ class Lexicon(lazydict):
     def __init__(
         self,
         path="",
-        morphology=None,
-        context=None,
-        entities=None,
+        morphology="",
+        context="",
+        entities="",
         NNP="NNP",
         language=None,
     ):
@@ -724,7 +724,7 @@ def apply(self, tokens):
             t[i] = [t[i][0], r[1]]
         return t[len(o) : -len(o)]
 
-    def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None):
+    def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None, *args):
         """Inserts a new rule that updates words with tag1 to tag2,
         given constraints x and y, e.g., Context.append("TO < NN", "VB")
         """
@@ -739,7 +739,7 @@ def insert(self, i, tag1, tag2, cmd="prevtag", x=None, y=None):
     def append(self, *args, **kwargs):
         self.insert(len(self) - 1, *args, **kwargs)
 
-    def extend(self, rules=None):
+    def extend(self, rules=None, *args):
         if rules is None:
             rules = []
         for r in rules:
@@ -1570,9 +1570,8 @@ def parse(
 
 TOKENS = "tokens"
 
-
 class TaggedString(str):
-    def __new__(self, string, tags=None, language=None):
+    def __new__(cls, string, tags=None, language=None):
         """Unicode string with tags and language attributes.
         For example: TaggedString("cat/NN/NP", tags=["word", "pos", "chunk"]).
         """
@@ -1588,7 +1587,7 @@ def __new__(self, string, tags=None, language=None):
             for s in string
         ]
         string = "\n".join(" ".join("/".join(token) for token in s) for s in string)
-        s = str.__new__(self, string)
+        s = str.__new__(cls, string)
         s.tags = list(tags)
         s.language = language
         return s
@@ -1634,7 +1633,7 @@ def language(self):
         return self._language
 
     @classmethod
-    def train(self, s, path="spelling.txt"):
+    def train(cls, s, path="spelling.txt"):
         """Counts the words in the given string and saves the probabilities at the given path.
         This can be used to generate a new model for the Spelling() constructor.
         """
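
The `__new__` and `train` fixes above matter because the first argument in those positions is the class object, not an instance: `str` is immutable, so construction happens in `__new__`, and naming that argument `self` misleads both readers and type checkers. The widened `update`/`insert`/`extend` signatures keep the overrides compatible with their dict/list base methods, and renaming the intermediate list to `TEMP_RE_EMOTICONS` lets `RE_EMOTICONS` keep a single type (a compiled pattern). A standalone sketch of the corrected `__new__` pattern; this `Tagged` class is illustrative, not textblob's `TaggedString`:

```python
from __future__ import annotations


class Tagged(str):
    """Hypothetical str subclass mirroring the TaggedString fix."""

    language: str | None

    def __new__(cls, string: str, language: str | None = None):
        # `cls`, not `self`, must flow into str.__new__ so that
        # subclasses construct instances of the correct type.
        s = str.__new__(cls, string)
        s.language = language
        return s


t = Tagged("cat/NN", language="en")
print(type(t).__name__, t.language)  # Tagged en
```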
30 changes: 19 additions & 11 deletions src/textblob/base.py
@@ -5,10 +5,16 @@
 All base classes are defined in the same module, ``textblob.base``.
 """
 
+from __future__ import annotations
+
 from abc import ABCMeta, abstractmethod
+from typing import TYPE_CHECKING
 
 import nltk
 
+if TYPE_CHECKING:
+    from typing import Any, AnyStr
+
 ##### POS TAGGERS #####
 
 
@@ -19,11 +25,11 @@ class BaseTagger(metaclass=ABCMeta):
     """
 
     @abstractmethod
-    def tag(self, text, tokenize=True):
+    def tag(self, text: str, tokenize=True) -> list[tuple[str, str]]:
         """Return a list of tuples of the form (word, tag)
         for a given set of text or BaseBlob instance.
         """
-        return
+        ...
 
 
 ##### NOUN PHRASE EXTRACTORS #####
@@ -36,29 +42,29 @@
     """
 
     @abstractmethod
-    def extract(self, text):
+    def extract(self, text: str) -> list[str]:
         """Return a list of noun phrases (strings) for a body of text."""
-        return
+        ...
 
 
 ##### TOKENIZERS #####
 
 
-class BaseTokenizer(nltk.tokenize.api.TokenizerI, metaclass=ABCMeta):
+class BaseTokenizer(nltk.tokenize.api.TokenizerI, metaclass=ABCMeta):  # pyright: ignore
     """Abstract base class from which all Tokenizer classes inherit.
     Descendant classes must implement a ``tokenize(text)`` method
     that returns a list of noun phrases as strings.
     """
 
     @abstractmethod
-    def tokenize(self, text):
+    def tokenize(self, text: str) -> list[str]:
         """Return a list of tokens (strings) for a body of text.
         :rtype: list
         """
-        return
+        ...
 
-    def itokenize(self, text, *args, **kwargs):
+    def itokenize(self, text: str, *args, **kwargs):
         """Return a generator that generates tokens "on-demand".
         .. versionadded:: 0.6.0
@@ -81,6 +87,8 @@ class BaseSentimentAnalyzer(metaclass=ABCMeta):
     results of analysis.
     """
 
+    _trained: bool
+
     kind = DISCRETE
 
     def __init__(self):
@@ -91,7 +99,7 @@ def train(self):
         self._trained = True
 
     @abstractmethod
-    def analyze(self, text):
+    def analyze(self, text) -> Any:
         """Return the result of analysis. Typically returns either a
         tuple, float, or dictionary.
         """
@@ -111,6 +119,6 @@ class BaseParser(metaclass=ABCMeta):
     """
 
     @abstractmethod
-    def parse(self, text):
+    def parse(self, text: AnyStr):
         """Parses the text."""
-        return
+        ...
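
Two patterns recur in this file: abstract bodies now end in `...` rather than a bare `return` (which yields None and contradicts the declared return types), and the abstract signatures carry the types every subclass must satisfy. A hedged sketch of a conforming subclass; `WhitespaceTagger` is hypothetical, and it assumes textblob and nltk are installed:

```python
from __future__ import annotations

from textblob.base import BaseTagger


class WhitespaceTagger(BaseTagger):
    """Toy tagger: pyright checks this override against BaseTagger.tag."""

    def tag(self, text: str, tokenize=True) -> list[tuple[str, str]]:
        # Tag every whitespace-separated token as a noun; illustration only.
        return [(word, "NN") for word in text.split()]


print(WhitespaceTagger().tag("the cat sat"))
# [('the', 'NN'), ('cat', 'NN'), ('sat', 'NN')]
```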
8 changes: 4 additions & 4 deletions src/textblob/blob.py
@@ -138,9 +138,9 @@ def lemmatize(self, pos=None):
         lemmatizer = nltk.stem.WordNetLemmatizer()
         return lemmatizer.lemmatize(self.string, tag)
 
-    PorterStemmer = nltk.stem.porter.PorterStemmer()
-    LancasterStemmer = nltk.stem.lancaster.LancasterStemmer()
-    SnowballStemmer = nltk.stem.snowball.SnowballStemmer("english")
+    PorterStemmer = nltk.stem.PorterStemmer()
+    LancasterStemmer = nltk.stem.LancasterStemmer()
+    SnowballStemmer = nltk.stem.SnowballStemmer("english")
 
     # added 'stemmer' on lines of lemmatizer
     # based on nltk
@@ -308,7 +308,7 @@ def _initialize_models(
     obj.tokenizer = _validated_param(
         tokenizer,
         "tokenizer",
-        base_class=(BaseTokenizer, nltk.tokenize.api.TokenizerI),
+        base_class=(BaseTokenizer, nltk.tokenize.api.TokenizerI),  # pyright: ignore
         default=BaseBlob.tokenizer,
         base_class_name="BaseTokenizer",
     )
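
The stemmer change is behavior-neutral: nltk re-exports PorterStemmer, LancasterStemmer, and SnowballStemmer at the `nltk.stem` package level, and the shorter spelling is one pyright can resolve without digging into the submodules. A quick check, assuming nltk is installed:

```python
import nltk

# Same class object under two spellings; only the one the type
# checker resolves cleanly changes.
assert nltk.stem.PorterStemmer is nltk.stem.porter.PorterStemmer

print(nltk.stem.PorterStemmer().stem("running"))  # run
```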
4 changes: 2 additions & 2 deletions src/textblob/classifiers.py
@@ -510,8 +510,8 @@ def update(
 
 
 class MaxEntClassifier(NLTKClassifier):
-    __doc__ = nltk.classify.maxent.MaxentClassifier.__doc__
-    nltk_class = nltk.classify.maxent.MaxentClassifier
+    __doc__ = nltk.classify.MaxentClassifier.__doc__
+    nltk_class = nltk.classify.MaxentClassifier
 
     def prob_classify(self, text):
         """Return the label probability distribution for classifying a string
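
Same re-export story here: `nltk.classify.MaxentClassifier` is the package-level name for `nltk.classify.maxent.MaxentClassifier`, so the classifier behaves identically. A usage sketch; the training data is made up, NLTK's maxent training needs numpy, and it prints iteration logs while fitting:

```python
from textblob.classifiers import MaxEntClassifier

train = [
    ("I love this sandwich.", "pos"),
    ("This is an amazing place!", "pos"),
    ("I do not like this restaurant.", "neg"),
    ("I am tired of this stuff.", "neg"),
]
classifier = MaxEntClassifier(train)  # trains the NLTK model on first use
print(classifier.classify("This was an amazing meal!"))  # likely "pos"
```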
13 changes: 12 additions & 1 deletion src/textblob/decorators.py
@@ -1,9 +1,18 @@
 """Custom decorators."""
 
+from __future__ import annotations
+
 from functools import wraps
+from typing import TYPE_CHECKING
 
 from textblob.exceptions import MissingCorpusError
 
+if TYPE_CHECKING:
+    from collections.abc import Callable
+    from typing import TypeVar
+
+    ReturnType = TypeVar("ReturnType")
+
 
 class cached_property:
     """A property that is only computed once per instance and then replaces
@@ -24,7 +33,9 @@ def __get__(self, obj, cls):
         return value
 
 
-def requires_nltk_corpus(func):
+def requires_nltk_corpus(
+    func: Callable[..., ReturnType],
+) -> Callable[..., ReturnType]:
     """Wraps a function that requires an NLTK corpus. If the corpus isn't found,
     raise a :exc:`MissingCorpusError`.
     """
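
This annotation is what makes `requires_nltk_corpus` transparent to callers: `Callable[..., ReturnType]` ties the wrapper's return type to the wrapped function's, so pyright no longer degrades decorated functions to Any. Defining `ReturnType` inside the TYPE_CHECKING block works because the `from __future__ import annotations` import keeps annotations lazy; parameter types are still erased to `...`, which a ParamSpec could preserve at the cost of more plumbing. A sketch of the effect; `tag_text` is hypothetical:

```python
from textblob.decorators import requires_nltk_corpus


@requires_nltk_corpus
def tag_text(text: str) -> list[str]:
    """Pretend this needed an NLTK corpus; splitting keeps the demo offline."""
    return text.split()


tokens = tag_text("hello world")  # pyright infers list[str], not Any
print(tokens)
```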