Skip to content

Commit

Permalink
Fix marker ordering in the `sub` scheme for BPE
Browse files Browse the repository at this point in the history
  • Loading branch information
Chenghao Mou authored and committed on Jul 10, 2019
1 parent dcd45e9 commit 1c709e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 14 deletions.
8 changes: 5 additions & 3 deletions elisa_dnt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
help="[Post]File path to the dnt conf file")
parser.add_argument('--output', type=str,
help="[Post]File path to the output file")

parser.add_argument('--ordered', action='store_true', dest='ordered', default=False,
help='Sub parameter, use markers orderly as how LI tokens appear; '
'suggest True for translation, False for bpe')
parser.add_argument('--src', type=str,
help='[Pre]File path to the source file')
parser.add_argument('--src_output', type=str,
Expand All @@ -42,7 +44,7 @@
options = generate_options()

if args.step == "post":
restore(args.dnt_src, args.dnt_ini, args.output, args.scheme)
restore(args.dnt_src, args.dnt_ini, args.output, args.scheme, ordered=args.ordered)
exit(0)

if args.visual:
Expand All @@ -65,7 +67,7 @@
path = args.src

split(args.src, args.src_output, args.ini_output, scheme=args.scheme,
ref=args.tgt if args.scheme == "sub" and args.pb_cross else "", rules=rules)
ref=args.tgt if args.scheme == "sub" and args.pb_cross else "", rules=rules, ordered=args.ordered)

if args.visual:
if args.tgt == "":
Expand Down
25 changes: 15 additions & 10 deletions elisa_dnt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def find(string: str, rules: dict, scheme='del') -> list:
return merged_matches


def mark(string: str, matches: list, scheme: str = "sub") -> tuple:
def mark(string: str, matches: list, scheme: str = "sub", ordered: bool = True) -> tuple:
global MARKERS
if scheme == "sub":

Expand All @@ -83,7 +83,8 @@ def mark(string: str, matches: list, scheme: str = "sub") -> tuple:
modification.append((text, key))

for value, key in sorted(modification, key=lambda x: (len(x[0]), x[0]), reverse=True):
string = string.replace(value, f"{key}")
if ordered: string = string.replace(value, f"{key}")
else: string = string.replace(value, MARKERS[0])

return string, [m[0] for m in modification], None

Expand Down Expand Up @@ -175,7 +176,7 @@ def colorize(match, text):
return res


def split(corpus_path, corpus_output, ini_output, scheme: str, ref: str, rules: dict):
def split(corpus_path, corpus_output, ini_output, scheme: str, ref: str, rules: dict, ordered: bool=True):
with open(corpus_path) as source, open(corpus_output, "w") as o_source, open(ini_output, "w") as o_source_ini:

if ref == "":
Expand All @@ -184,7 +185,7 @@ def split(corpus_path, corpus_output, ini_output, scheme: str, ref: str, rules:
total_sents += 1
src = src.strip('\n')
src_matches = find(src, rules, scheme)
src_after, src_mod, src_lead = mark(src, src_matches, scheme=scheme)
src_after, src_mod, src_lead = mark(src, src_matches, scheme=scheme, ordered=ordered)
if scheme == "del":
for seg in src_after:
o_source.write(seg + "\n")
Expand Down Expand Up @@ -224,7 +225,7 @@ def split(corpus_path, corpus_output, ini_output, scheme: str, ref: str, rules:
x_matches = list(set(src_matches_text).intersection(set(tgt_matches_text)))
x_src_matches = [m for m in src_matches if src_line[m.start(0):m.end(0)] in x_matches]

src_after, src_mod, src_lead = mark(src_line, x_src_matches, scheme=scheme)
src_after, src_mod, src_lead = mark(src_line, x_src_matches, scheme=scheme, ordered=ordered)

o_source.write(src_after + "\n")

Expand All @@ -234,7 +235,7 @@ def split(corpus_path, corpus_output, ini_output, scheme: str, ref: str, rules:
o_source_ini.write("IGNORE\n")


def restore(dnt_path, ini_path, output, scheme="del"):
def restore(dnt_path, ini_path, output, scheme="del", ordered:bool=True):
global MARKERS

with open(output, "w") as o, open(dnt_path) as i_source, open(ini_path) as i_source_ini:
Expand Down Expand Up @@ -294,12 +295,16 @@ def restore(dnt_path, ini_path, output, scheme="del"):
new_translation = translation
for char in translation:
if char in MARKERS:
if ord(char) - 0x4DC0 >= len(segments):
if ord(char) - 0x4DC0 >= len(segments) or segments == []:
warnings.warn("Wired source sentence: {}".format(translation), Warning)
warnings.warn(" ".join(segments), Warning)
continue
new_translation = new_translation.replace(char,
segments[min(ord(char) - 0x4DC0, len(segments) - 1)])
if ordered:
new_translation = new_translation.replace(char, segments[min(ord(char) - 0x4DC0, len(segments) - 1)])
else:
new_translation = new_translation.replace(char, segments[0], 1)
segments.pop(0)

o.write(new_translation + '\n')


Expand All @@ -314,5 +319,5 @@ def restore(dnt_path, ini_path, output, scheme="del"):
spans = [txt[m.start:m.end] for m in matches]

print(spans)
print(mark(txt, find(txt, rules, 'sub'), 'sub'))
print(mark(txt, find(txt, rules, 'sub'), 'sub', ordered=False))
print(visual(txt, matches, options, rules))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='elisa-dnt',
version='0.1.5',
version='0.1.6',
packages=['elisa_dnt'],
url='https://github.com/ChenghaoMou/elisa-dnt',
license='',
Expand Down

0 comments on commit 1c709e4

Please sign in to comment.