diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
new file mode 100644
index 0000000..1376460
--- /dev/null
+++ b/.github/workflows/ci-build.yml
@@ -0,0 +1,35 @@
+name: Build unstable
+
+on: [push]
+
+concurrency:
+  group: unstable
+# cancel-in-progress: true
+
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.9
+      uses: actions/setup-python@v4
+      with:
+        python-version: "3.9"
+    - name: Clean up more disk space
+      run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install --upgrade flake8 pytest pycodestyle
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+#    - name: Test with pytest
+#      run: |
+#        python -m pytest --rootdir .
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3bf780b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.idea
+.env
\ No newline at end of file
diff --git a/README.md b/README.md
index 2f9cb51..72fe7b7 100644
--- a/README.md
+++ b/README.md
@@ -3,26 +3,41 @@
 A python tool for downloading & processing the [stackexchange data dumps](https:
 
 Download the whole processed dataset [here](https://eaidata.bmk.sh/data/stackexchange_dataset.tar)
 
-# Setup
+## Setup
 
 ```
 git clone https://github.com/EleutherAI/stackexchange_dataset/
 cd stackexchange_dataset
 pip install -r requirements.txt
 ```
 
-# Usage
+## Usage
 
-To download *every* stackexchange dump & parse to text, simply run
+
+### List all available StackExchange dumps
+
+```
+python3 main.py --list
+```
+
+The sites are printed one per line, so the output can be filtered with `grep` and other batch utilities.
+
+### Download every StackExchange dump
+
+To download *every* stackexchange dump & parse it to text, simply run
 ```
 python3 main.py --names all
 ```
 
+### Download a single StackExchange dump
+
 To download only a single stackexchange, you can add the name as an optional argument. E.G:
 ```
 python3 main.py --names security.stackexchange
 ```
 
+### Download a list of StackExchange dumps
+
 To download a list of multiple stackexchanges, you can add the names separated by commas. E.G:
 ```
@@ -47,6 +62,19 @@ optional arguments:
                         *every* stackoverflow site
 ```
 
+### Proxy support
+
+If you need to go through a proxy, you can create an `.env` file and add the following:
+
+```
+HTTP_PROXY=http://proxy:port
+http_proxy=http://proxy:port
+HTTPS_PROXY=http://proxy:port
+https_proxy=http://proxy:port
+NO_PROXY=address to ignore,localhost
+no_proxy=address to ignore,localhost
+```
+
 
 # TODO:
 - [ ] should we add metadata to the text (i.e name of stackexchange & tags)?
diff --git a/main.py b/main.py
index 50a727a..a5b1c9a 100644
--- a/main.py
+++ b/main.py
@@ -1,56 +1,82 @@
-import argparse, traceback
+import argparse
+import os
+import traceback
+from itertools import repeat
 from multiprocessing import Pool, cpu_count
-from utils import *
+
+import dotenv
+from lm_dataformat import Archive, JSONArchive, TextArchive, LM_DATAFORMAT_FORMAT, TEXT_FORMAT, SUPPORTED_FORMATS, \
+    JSON_FORMAT
+
 from downloader import Stack_Exchange_Downloader
 from pairer import QA_Pairer
-import os
-from itertools import repeat
-from lm_dataformat import Archive
-import zipfile
 
+dotenv.load_dotenv(override=True)
 
-def download_and_process_single(name, out_format, min_score, max_responses):
+
+def download_and_process_single(name, out_format, min_score, max_responses, keep_sources=False):
     try:
         name = name.strip().lower()
         os.makedirs("dumps", exist_ok=True)
         s = Stack_Exchange_Downloader(name)
+        if name not in s.sites:
+            similar_entries = list(filter(lambda key: key.startswith(name) or key.endswith(name), s.sites.keys()))
+            print("StackExchange source not found. Perhaps you meant", similar_entries)
+            return
+
         path_to_xml = "dumps/{}/Posts.xml".format(name)
         if name != "stackoverflow":
             path_to_7z = "dumps/{}.7z".format(s.sites[name]["url"])
         else:
             path_to_7z = "dumps/stackoverflow.com-Posts.7z"
-        out_folder = "out".format(name)
+
+        out_folder = "out/{}".format(name)
         os.makedirs(out_folder, exist_ok=True)
         if not os.path.isfile(path_to_7z):
             # download 7z if it's not downloaded already
            s.download()
+
         if not os.path.isfile(path_to_xml):
             # extract 7z if it's not extracted already
             s.extract()
-        if out_format == "lm_dataformat":
+
+        if out_format == LM_DATAFORMAT_FORMAT:
             archiver = Archive(out_folder)
-        elif out_format == "zip":
-            archiver = zipfile.ZipFile('{}/{}.zip'.format(out_folder, name), 'a')
+        elif out_format == TEXT_FORMAT:
+            archiver = TextArchive(out_folder)
+        elif out_format == JSON_FORMAT:
+            archiver = JSONArchive(out_folder)
         else:
             archiver = None
-        qa = QA_Pairer(path_to_xml, name=name, out_format=out_format, archiver=archiver, min_score=min_score, max_responses=max_responses)
-        qa.main()
-        if out_format == "lm_dataformat":
-            archiver.commit(name)
-        elif out_format == "zip":
-            archiver.close()
-        try:
-            os.remove(path_to_7z)
-        except FileNotFoundError:
-            print('ERROR: FileNotFoundError: File {} not found'.format(s.sites[name]["url"]))
-        filelist = [f for f in os.listdir("dumps/{}".format(name)) if f.endswith(".xml")]
+
+        qa = QA_Pairer(path_to_xml, name=name, out_format=out_format, archiver=archiver, min_score=min_score,
+                       max_responses=max_responses)
+        qa.process()
+        archiver.commit(name)
+
+        if not keep_sources:
+            try:
+                os.remove(path_to_7z)
+            except FileNotFoundError:
+                print('ERROR: FileNotFoundError: File {} not found'.format(s.sites[name]["url"]))
+
+        directory_uncompressed = "dumps/{}".format(name)
+        filelist = [f for f in os.listdir(directory_uncompressed)
+                    if f.endswith(".xml")]
         for f in filelist:
-            os.remove(os.path.join("dumps/{}".format(name), f))
+            os.remove(os.path.join(directory_uncompressed, f))
+        os.removedirs(directory_uncompressed)
     except:
         traceback.print_exc()
 
 
 def main(args):
+    if args.list:
+        s = Stack_Exchange_Downloader("all")
+        print("List of all the sources of StackExchange: ")
+        print("- " + "\n- ".join(sorted(s.sites.keys())))
+        return
+
     names = args.names.split(',')
     if names[0].strip().lower() == "all":
         s = Stack_Exchange_Downloader("all")
@@ -60,31 +86,51 @@ def main(args):
         # bring stackoverflow to the front so it is always processed first, since it's the largest
         if "stackoverflow" in names:
             names.insert(0, names.pop(names.index("stackoverflow")))
+        # if args.no_zip:
+        #     print("Downloading everything requires the output to be compressed. Re-run *without* the option --no-zip.")
+        #     sys.exit(-1)
 
     print('Downloading and processing stackexchange dumps for {}'.format(names))
     # Download & Process
     # init pool with as many CPUs as available
     cpu_no = cpu_count() - 1
     p = Pool(cpu_no)
-    p.starmap(download_and_process_single, zip(names, repeat(args.out_format), repeat(args.min_score), repeat(args.max_responses)))
+    p.starmap(download_and_process_single,
+              zip(names, repeat(args.out_format), repeat(args.min_score), repeat(args.max_responses),
+                  repeat(args.keep_sources)))
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
        description='CLI for stackexchange_dataset - A tool for downloading & processing stackexchange dumps in xml form to a raw '
                    'question-answer pair text dataset for Language Models')
+
+    parser.add_argument('--list', help='list all available StackExchange sources',
+                        required=False, action="store_true")
+
     parser.add_argument('--names', help='names of stackexchanges to download, extract & parse, separated by commas. '
                                         'If "all", will download, extract & parse *every* stackoverflow site',
                         default="3dprinting.stackexchange,3dprinting.meta.stackexchange",
                         type=str)
-    parser.add_argument('--out_format', help='format of out file - if you are processing everything this will need to be '
-                                             'lm_dataformat, as you will run into number of files per directory limits.',
-                        default="zip",
+    parser.add_argument('--out_format',
+                        help='format of out file - if you are processing everything this will need to be '
+                             'lm_dataformat, as you will run into number of files per directory limits.',
+                        default=TEXT_FORMAT,
+                        choices=SUPPORTED_FORMATS,
                         type=str)
-    parser.add_argument('--min_score', help='minimum score of a response in order to be included in the dataset. Default 3.',
+    # parser.add_argument('--no-zip',
+    #                     help="Disable the compression of the output files. Writing plain files might end up in problems with the filesystem",
+    #                     action="store_true",
+    #                     required=False,
+    #                     default=False)
+    parser.add_argument('--min_score',
+                        help='minimum score of a response in order to be included in the dataset. Default 3.',
                         type=int, default=3)
-    parser.add_argument('--max_responses', help='maximum number of responses (sorted by score) to include for each question. '
-                                                'Default 3.', type=int, default=3)
+    parser.add_argument('--max_responses',
+                        help='maximum number of responses (sorted by score) to include for each question. '
+                             'Default 3.', type=int, default=3)
+    parser.add_argument('--keep-sources',
+                        help='Do not clean up the downloaded source 7z files.',
+                        action="store_true", default=False)
     args = parser.parse_args()
-    main(args)
-
+    main(args)
diff --git a/pairer.py b/pairer.py
index 880bee7..838e319 100644
--- a/pairer.py
+++ b/pairer.py
@@ -1,14 +1,18 @@
 import traceback
 import xml.etree.ElementTree as etree
 from collections import defaultdict
+
 from bs4 import BeautifulSoup
+from lm_dataformat import SUPPORTED_FORMATS, LM_DATAFORMAT_FORMAT, JSON_FORMAT, TEXT_FORMAT, TextArchive
 from tqdm import tqdm
+
 from utils import *
 
 
 class QA_Pairer():
 
-    def __init__(self, xml_path, name=None, out_folder="out", min_score=3, max_responses=3, out_format="txt", archiver=None):
+    def __init__(self, xml_path, name=None, out_folder="out", min_score=3, max_responses=3, out_format=TEXT_FORMAT,
+                 archiver=None):
         """Makes a text dataset from StackExchange dumps"""
         self.xml_path = xml_path
         if name is None:
@@ -22,14 +26,14 @@ def __init__(self, xml_path, name=None, out_folder="out", min_score=3, max_respo
         # min_score required to parse an answer
         self.min_score = min_score
         self.max_responses = max_responses
-        assert out_format in ["txt", "lm_dataformat", "zip"], "Out format not recognized"
+        assert out_format in SUPPORTED_FORMATS, "Out format not recognized"
         self.out_format = out_format
-        if out_format in ["lm_dataformat", "zip"]:
+        if out_format in SUPPORTED_FORMATS:
             assert archiver is not None
         self.ar = archiver
 
-    def main(self):
-        """iterates through SE xmls and:
+    def process(self):
+        """iterates through SE XMLs and:
 
         - stores PostTypeId="1" with AcceptedAnswerIds / Answers.
         - when an AcceptedAnswerId or Answer > min_score is reached, it should:
@@ -40,7 +44,8 @@ def main(self):
 
         """
         os.makedirs(self.out_folder, exist_ok=True)
-        for event, elem in tqdm(etree.iterparse(self.xml_path, events=('end',)), desc="Parsing {} XML file".format(self.name)):
+        for event, elem in tqdm(etree.iterparse(self.xml_path, events=('end',)),
+                                desc="Parsing {} XML file".format(self.name)):
             if elem.tag == "row":
                 try:
                     attribs = defaultdict(lambda: None, elem.attrib)
@@ -94,7 +99,8 @@ def add_answer(self, a_attribs):
             if a_attribs["Id"] is not None:
                 parent = self.questions[a_attribs["ParentId"]]
                 if parent is not None:
-                    self.questions[a_attribs["ParentId"]]["Answers"][a_attribs["Id"]] = trim_attribs(a_attribs, "answer")
+                    self.questions[a_attribs["ParentId"]]["Answers"][a_attribs["Id"]] = trim_attribs(a_attribs,
+                                                                                                     "answer")
                     self.questions[a_attribs["ParentId"]]["ParsedAnswers"] += 1
                 else:
                     self.questions[a_attribs["ParentId"]]["ParsedAnswers"] += 1
@@ -107,6 +113,13 @@ def check_complete(self, a_attribs):
         removes from dict and prints to file.
         """
         keys_to_del = []
+        qa_structure = {
+            "question": {
+                "title": "",
+                "body": ""
+            },
+            "answers": []
+        }
         parent = self.questions[a_attribs["ParentId"]]
         if a_attribs is not None and parent is not None:
             if parent["AnswerCount"] is not None and parent["ParsedAnswers"] is not None:
@@ -114,40 +127,28 @@ def check_complete(self, a_attribs):
                     keys_to_del.append(a_attribs["ParentId"])
                     if parent["Answers"] is not None and len(parent["Answers"]) > 0:
                         out_name = "{}_{}.txt".format(self.name, parent["Id"].zfill(10))
-                        out_str = ""
-                        out_str += 'Q:\n\n'
+                        question_structure = qa_structure['question']
                         if parent["Title"] is not None:
-                            out_str += '{}\n\n'.format(BeautifulSoup(parent["Title"], "html.parser").get_text())
+                            question_structure['title'] = parent["Title"]
                         if parent["Body"] is not None:
-                            out_str += '{}\n\n'.format(BeautifulSoup(parent["Body"], "html.parser").get_text())
+                            question_structure['body'] = BeautifulSoup(parent["Body"], "html.parser").get_text()
                         if parent["Answers"] is not None:
-                            key_score_dict = {}
+                            answers_structure_tmp = []
                             for k, a in parent["Answers"].items():
-                                key_score_dict[k] = int(a["Score"])
-                            key_score_dict = {k: v for k, v in sorted(key_score_dict.items(), key=lambda item: item[1], reverse=True)}
-                            count = 0
-                            for k in key_score_dict:
-                                if count >= self.max_responses:
-                                    break
-                                out_str += 'A:\n\n{}\n\n'.format(BeautifulSoup(parent["Answers"][k]["Body"], "html.parser").get_text())
-                                count += 1
-                        if self.out_format == "txt":
-                            with open("{}/{}".format(self.out_folder, out_name), 'w') as f:
-                                try:
-                                    f.write(filter_newlines(out_str))
-                                except:
-                                    f.write(filter_newlines(handle_unicode_errors(out_str)))
-                        elif self.out_format == "zip":
-                            try:
-                                self.ar.writestr(out_name, filter_newlines(out_str))
-                            except:
-                                self.ar.writestr(out_name, filter_newlines(handle_unicode_errors(out_str)))
-                        elif self.out_format == "lm_dataformat":
-                            try:
-                                self.ar.add_data(filter_newlines(out_str), meta={
-                                    'name': out_name})
-                            except:
-                                self.ar.add_data(filter_newlines(handle_unicode_errors(out_str)), meta={
-                                    'name': out_name})
+                                answers_structure_tmp.append({
+                                    "id": a['Id'],
+                                    "body": BeautifulSoup(a["Body"], "html.parser").get_text(),
+                                    "score": int(a["Score"])
+                                })
+                            qa_structure['answers'] = sorted(answers_structure_tmp, key=lambda item: item['score'],
+                                                             reverse=True)[0:self.max_responses]
+
+                        if self.out_format == TEXT_FORMAT:
+                            self.ar.add_data(TextArchive.to_text(qa_structure))
+                        elif self.out_format == JSON_FORMAT:
+                            self.ar.add_data(qa_structure)
+                        elif self.out_format == LM_DATAFORMAT_FORMAT:
+                            self.ar.add_data(TextArchive.to_text(qa_structure), meta={'name': out_name})
+
         for key in keys_to_del:
             self.questions.pop(key, None)
diff --git a/requirements.txt b/requirements.txt
index 9d3cf3e..cfb9ed1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,6 @@ bs4
 lxml
 py7zr
 tqdm
-lm-dataformat
+# lm-dataformat
+-e git+https://github.com/lfoppiano/lm_dataformat.git#egg=lm_dataformat
 jsonlines
\ No newline at end of file