Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement new end-to-end parser class #227

Merged
merged 19 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,22 @@ jobs:
run: pip install .[extras] .[test]
- name: Locate bobcat pre-trained model cache
id: loc-bobcat-cache
run: echo "dir=$(python -c 'from lambeq.text2diagram.model_downloader import ModelDownloader; print(ModelDownloader("bert").model_dir)')" >> $GITHUB_OUTPUT
run: echo "dir=$(python -c 'from lambeq.text2diagram.model_based_reader.model_downloader import ModelDownloader; print(ModelDownloader("bobcat").model_dir)')" >> $GITHUB_OUTPUT
- name: Restore bobcat pre-trained model from cache
id: bobcat-cache
uses: actions/cache@v4
with:
path: ${{ steps.loc-bobcat-cache.outputs.dir }}
key: bobcat-bert-v1
key: bobcat-v1
- name: Locate oncilla pre-trained model cache
id: loc-oncilla-cache
run: echo "dir=$(python -c 'from lambeq.text2diagram.model_based_reader.model_downloader import ModelDownloader; print(ModelDownloader("oncilla").model_dir)')" >> $GITHUB_OUTPUT
- name: Restore oncilla pre-trained model from cache
id: oncilla-cache
uses: actions/cache@v4
with:
path: ${{ steps.loc-oncilla-cache.outputs.dir }}
key: oncilla-v1
- name: Test with pytest
run: >
coverage run --source=${{ env.SRC_DIR }}
Expand Down
3 changes: 3 additions & 0 deletions lambeq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
'CCGBankParser',
'DepCCGParseError',
'DepCCGParser',
'OncillaParseError',
'OncillaParser',
'WebParseError',
'WebParser',

Expand Down Expand Up @@ -122,6 +124,7 @@
CCGType, CCGRule, CCGRuleUseError, CCGTree,
CCGParser,
BobcatParseError, BobcatParser,
OncillaParseError, OncillaParser,
CCGBankParseError, CCGBankParser,
DepCCGParseError, DepCCGParser,
WebParseError, WebParser,
Expand Down
3 changes: 0 additions & 3 deletions lambeq/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
'Ty',
'Word',

'PregroupTreeNode',

'draw',
'draw_equation',
'to_gif',
Expand All @@ -35,6 +33,5 @@

from lambeq.backend.grammar import (Box, Cap, Category, Cup, Diagram,
Frame, Functor, Id, Spider, Swap, Ty, Word)
from lambeq.backend.pregroup_tree import PregroupTreeNode
from lambeq.backend.symbol import lambdify, Symbol
from lambeq.backend.drawing import draw, draw_equation, to_gif
2 changes: 1 addition & 1 deletion lambeq/backend/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

if TYPE_CHECKING:
import discopy
from lambeq.backend.pregroup_tree import PregroupTreeNode
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode


@dataclass
Expand Down
9 changes: 9 additions & 0 deletions lambeq/bobcat/tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CCG tagger model
================

Model for tagging text with CCG tags. This work is based on
the PyTorch BERT model available in Huggingface transformers
(https://huggingface.co/transformers) which is released under
Apache License 2.0.
"""

from __future__ import annotations

Expand Down
2 changes: 1 addition & 1 deletion lambeq/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@
from lambeq.backend import grammar, tensor
from lambeq.rewrite import RemoveSwapsRewriter
from lambeq.text2diagram.base import Reader
from lambeq.text2diagram.bobcat_parser import BobcatParser
from lambeq.text2diagram.ccg_parser import CCGParser
from lambeq.text2diagram.ccg_tree import CCGTree
from lambeq.text2diagram.depccg_parser import DepCCGParser
from lambeq.text2diagram.linear_reader import (cups_reader,
stairs_reader)
from lambeq.text2diagram.model_based_reader import BobcatParser
from lambeq.text2diagram.spiders_reader import spiders_reader
from lambeq.text2diagram.tree_reader import TreeReader
from lambeq.tokeniser import SpacyTokeniser
Expand Down
6 changes: 4 additions & 2 deletions lambeq/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from typing import Any, List, Union


SentenceType = Union[str, List[str]]
SentenceBatchType = Union[List[str], List[List[str]]]
TokenisedSentenceType = List[str]
SentenceType = Union[str, TokenisedSentenceType]
TokenisedSentenceBatchType = List[TokenisedSentenceType]
SentenceBatchType = Union[List[str], TokenisedSentenceBatchType]


def tokenised_sentence_type_check(sentence: SentenceType) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion lambeq/experimental/discocirc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@

from lambeq import AtomicType
from lambeq.backend.grammar import Box, Diagram, Frame, Id, Spider, Ty
from lambeq.backend.pregroup_tree import PregroupTreeNode
from lambeq.core.utils import SentenceBatchType, SentenceType
from lambeq.experimental.discocirc import (CoreferenceResolver,
SpacyCoreferenceResolver,
TreeRewriter,
TreeRewriteRule)
from lambeq.text2diagram import BobcatParser, CCGParser, Reader
from lambeq.text2diagram.pregroup_tree import PregroupTreeNode


NOUN = AtomicType.NOUN
Expand Down
11 changes: 11 additions & 0 deletions lambeq/oncilla/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lambeq.oncilla.parser import (
BertForSentenceToTree,
prepare_parent_logits_mask,
SentenceToTreeBertConfig,
)

__all__ = [
'BertForSentenceToTree',
'prepare_parent_logits_mask',
'SentenceToTreeBertConfig',
]
Loading
Loading