From bd24cb73b6ab977d550d71793d467c3069779be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Grobol?= Date: Thu, 26 Mar 2020 22:43:45 +0100 Subject: [PATCH] fix failure in mlstm introduced by #17 And add a smoke test for graph-based to avoid this in the future --- tox.ini | 5 +++-- uuparser/mstlstm.py | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tox.ini b/tox.ini index 0f81385..37a8c91 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,9 @@ [tox] minversion = 3.4.0 -envlist = py38 +envlist = py38, py37 [testenv] deps = cython commands = - uuparser --dynet-seed 123456789 --outdir {envtmpdir}/smoketestoutput --trainfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --devfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --testfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu \ No newline at end of file + uuparser --dynet-seed 123456789 --epochs 3 --outdir {envtmpdir}/transition-smoketest-output --trainfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --devfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --testfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu + uuparser --dynet-seed 123456789 --epochs 3 --graph-based --outdir {envtmpdir}/graph-smoketest-output --trainfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --devfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu --testfile test/fixtures/truncated-sv_talbanken-ud-dev.conllu diff --git a/uuparser/mstlstm.py b/uuparser/mstlstm.py index 9e9cbaf..d2614e0 100644 --- a/uuparser/mstlstm.py +++ b/uuparser/mstlstm.py @@ -1,17 +1,19 @@ from operator import itemgetter -import time, random, decoder -from chuliu_edmonds import chuliu_edmonds_one_root +import time, random import numpy as np -from multilayer_perceptron import biMLP from collections import defaultdict from copy import deepcopy -from uuparser import utils +from loguru import logger + +from uuparser import utils, decoder +from uuparser.chuliu_edmonds import chuliu_edmonds_one_root +from uuparser.multilayer_perceptron import biMLP class MSTParserLSTM: def __init__(self, vocab, options): import dynet as dy - from feature_extractor import FeatureExtractor + from uuparser.feature_extractor import FeatureExtractor global dy self.model = dy.ParameterCollection() self.trainer = dy.AdamTrainer(self.model, alpha=options.learning_rate)