Skip to content

Commit a42ac67

Browse files
authored
✨ OPUS MSA-into-dialect script
1 parent d1ab7a6 commit a42ac67

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
This script is used to 'translate' MSA sentences from opus100 into dialectal Arabic using the AraT5v2-MSA-Dialect model.
4+
The translated sentences are then saved to a json file and pushed to the Hugging Face Hub.
5+
"""
6+
import pandas as pd
7+
import torch
8+
from datasets import load_dataset
9+
from torch.utils.data import DataLoader
10+
from tqdm import tqdm
11+
from transformers import (
12+
AutoModelForSeq2SeqLM,
13+
T5Tokenizer,
14+
)
15+
16+
17+
# Hugging Face access token. It is passed as `token=` to both `from_pretrained`
# calls and to `push_to_hub` further down in this script.
# NOTE(review): the value is empty here; presumably it has to be filled in
# before running (the dataset is pushed with private=True) - confirm.
HF_TOKEN = ""
18+
19+
20+
def remove_parenthesis(example):
    """Strip parentheses from the MSA text unless both texts contain a pair.

    The record is left untouched only when the "msa" text and the "dialect"
    text each contain both "(" and ")". In every other case all "(" and ")"
    characters are removed from the "msa" text; the "dialect" text is never
    modified.

    Args:
        example: Mapping with string values under the keys "msa" and
            "dialect". It is modified in place.

    Returns:
        The same ``example`` object that was passed in.
    """
    msa = example["msa"]
    dialect = example["dialect"]

    # True only when each of the two texts has an opening AND a closing
    # parenthesis somewhere in it (position/balance is not checked).
    both_parenthesised = all(
        "(" in text and ")" in text for text in (msa, dialect)
    )
    if not both_parenthesised:
        example["msa"] = msa.replace("(", "").replace(")", "")
    return example
25+
26+
27+
model_name = "Murhaf/AraT5v2-MSA-Dialect"
tokenizer = T5Tokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)

dataset = load_dataset("opus100", "ar-en", split="train")

# Keep only rows whose Arabic side is 6-449 characters long and contains no
# "{" character.
dataset = dataset.filter(lambda example: 5 < len(example["translation"]["ar"]) < 450)
dataset = dataset.filter(lambda x: "{" not in x["translation"]["ar"])

# Number of sentences tokenized and generated per step. This is the single
# place the batch size is defined; it feeds the DataLoader below.
batch_size = 256

# shuffle=False keeps the generated outputs in the same order as `dataset`,
# which the `df["dialect"] = dialect` assignment below relies on.
ds = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Use the GPU when one is available. Only the model is moved here; the
# tokenized inputs are moved to the same device batch by batch in the loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Task prefix prepended to every input sentence before tokenization.
prefix = "ترجمة للعامية: "
dialect = []

# Iterate over the sentences in batches
for batch in tqdm(ds):
    batch_sentences = batch["translation"]["ar"]

    # Tokenize batch of sentences
    inputs = tokenizer(
        [prefix + sent for sent in batch_sentences],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Move inputs to appropriate device
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate outputs (no gradients are needed for inference)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=60)

    # Decode the whole batch at once and store the outputs
    dialect.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

data = dataset.to_dict()

# Build the dataframe once at the end: one column per key of the
# "translation" records, plus the generated text.
df = pd.DataFrame(data["translation"])
# `remove_parenthesis` reads the Arabic source text under the key "msa", so
# expose the "ar" column under that name before exporting.
df = df.rename(columns={"ar": "msa"})
df["dialect"] = dialect

# One JSON object per line; force_ascii=False writes the Arabic text as-is
# instead of \u escapes.
df.to_json(
    "dialectal_opus.json",
    orient="records",
    lines=True,
    force_ascii=False,
)

dataset = load_dataset("json", data_files="dialectal_opus.json")
dataset = dataset.map(remove_parenthesis)
dataset.push_to_hub("opus100_msa_dialect_silver", private=True, token=HF_TOKEN)

0 commit comments

Comments
 (0)