# -*- coding: utf-8 -*-
"""
This script "translates" MSA sentences from the OPUS-100 corpus into dialectal Arabic using the AraT5v2-MSA-Dialect model.
The translated sentences are then saved to a JSON file and pushed to the Hugging Face Hub.
"""
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    T5Tokenizer,
)


HF_TOKEN = ""  # Hugging Face access token (fill in before running)
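# A minimal alternative to hardcoding the token (an assumption, not part of the
# original script) is to read it from the environment:
#   import os
#   HF_TOKEN = os.environ.get("HF_TOKEN", "")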


def remove_parenthesis(example):
    """Strip parentheses from the "msa" field unless both "msa" and "dialect" contain parentheses."""
    if "(" in example["msa"] and ")" in example["msa"] and "(" in example["dialect"] and ")" in example["dialect"]:
        return example
    example["msa"] = example["msa"].replace("(", "").replace(")", "")
    return example


model_name = "Murhaf/AraT5v2-MSA-Dialect"
tokenizer = T5Tokenizer.from_pretrained(model_name, token=HF_TOKEN)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)

dataset = load_dataset("opus100", "ar-en", split="train")

# Keep Arabic sentences between 5 and 450 characters long and drop ones containing "{"
dataset = dataset.filter(lambda example: 5 < len(example["translation"]["ar"]) < 450)
dataset = dataset.filter(lambda x: "{" not in x["translation"]["ar"])

ds = DataLoader(dataset, batch_size=256, shuffle=False)
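# A sketch of what one batch from this DataLoader looks like (PyTorch's default
# collation keeps strings as Python lists rather than tensors):
#   batch = {"translation": {"ar": ["<msa sentence 1>", ...], "en": ["<english sentence 1>", ...]}}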

# Check if a GPU is available and move the model to it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

prefix = "ترجمة للعامية: "  # task prefix, roughly: "translation into colloquial (dialectal) Arabic: "
dialect = []

# Iterate over the sentences in batches (batch size is set on the DataLoader above)
for batch in tqdm(ds):
    batch_sentences = batch["translation"]["ar"]

    # Tokenize batch of sentences
    inputs = tokenizer(
        [prefix + sent for sent in batch_sentences],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )

    # Move inputs to the appropriate device
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate outputs
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=60)
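
    # Possible variation (an assumption, not part of the original script): beam search
    # may trade speed for output quality, e.g.:
    #   outputs = model.generate(**inputs, max_new_tokens=60, num_beams=4)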

    # Decode and store outputs
    decoded_outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
    dialect.extend(decoded_outputs)

data = dataset.to_dict()

# Build the dataframe once at the end; rename "ar" to "msa" so the column names match
# the keys expected by remove_parenthesis and the MSA/dialect naming of the pushed dataset
df = pd.DataFrame(data["translation"]).rename(columns={"ar": "msa"})
df["dialect"] = dialect

df.to_json(
    "dialectal_opus.json",
    orient="records",
    lines=True,
    force_ascii=False,
)

dataset = load_dataset("json", data_files="dialectal_opus.json")
dataset = dataset.map(remove_parenthesis)
dataset.push_to_hub("opus100_msa_dialect_silver", private=True, token=HF_TOKEN)
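
# A sketch of how the pushed dataset could be loaded back later (an assumption: the repo
# ends up under your user namespace, e.g. "<username>/opus100_msa_dialect_silver"):
#   from datasets import load_dataset
#   ds = load_dataset("<username>/opus100_msa_dialect_silver", token=HF_TOKEN)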