From b2ab0c9b69d89eb53bc1114a459848c35ba94d57 Mon Sep 17 00:00:00 2001
From: Ralph Tang
Date: Tue, 21 Apr 2020 13:03:31 -0400
Subject: [PATCH] Implement highlighting rerankers (#1)

* Implement highlighting rerankers

- Add T5, transformer, and BM25 rerankers
- Add Kaggle dataset and evaluation framework

* Fix README instructions

- Add missing activation command

* Fix BM25 bug

- IDF not computed correctly

* Improve IDF computation for BM25 reranker

- Add option to compute IDF statistics from corpus

* Add LongBatchEncoder documentation
---
 README.md                                   |  19 +-
 data/kaggle-lit-review.json                 | 459 ++++++++++++++++++++
 environment.yml                             | 108 +++++
 pygaggle/__init__.py                        |   2 +-
 pygaggle/data/__init__.py                   |   2 +
 pygaggle/data/kaggle.py                     |  69 +++
 pygaggle/data/relevance.py                  |  34 ++
 pygaggle/lib/IdentityReranker.py            |  14 -
 pygaggle/lib/__init__.py                    |   9 -
 pygaggle/logger.py                          |   7 +
 pygaggle/model/__init__.py                  |   5 +
 pygaggle/model/decode.py                    |  29 ++
 pygaggle/model/encode.py                    |  96 ++++
 pygaggle/model/evaluate.py                  | 103 +++++
 pygaggle/model/serialize.py                 |  72 +++
 pygaggle/model/tokenize.py                  | 132 ++++++
 pygaggle/rerank/__init__.py                 |   4 +
 pygaggle/{rerank.py => rerank/base.py}      |  33 +-
 pygaggle/rerank/bm25.py                     |  49 +++
 pygaggle/rerank/similarity.py               |  28 ++
 pygaggle/rerank/transformer.py              |  75 ++++
 pygaggle/run/__init__.py                    |   1 +
 pygaggle/run/args.py                        |  41 ++
 pygaggle/run/evaluate_kaggle_highlighter.py |  95 ++++
 pygaggle/settings.py                        |  17 +
 scripts/evaluate-highlighters.sh            |   7 +
 scripts/update-index.sh                     |  14 +
 tests/test_base.py                          |   2 +-
 28 files changed, 1486 insertions(+), 40 deletions(-)
 create mode 100644 data/kaggle-lit-review.json
 create mode 100644 environment.yml
 create mode 100644 pygaggle/data/__init__.py
 create mode 100644 pygaggle/data/kaggle.py
 create mode 100644 pygaggle/data/relevance.py
 delete mode 100644 pygaggle/lib/IdentityReranker.py
 delete mode 100644 pygaggle/lib/__init__.py
 create mode 100644 pygaggle/logger.py
 create mode 100644 pygaggle/model/__init__.py
 create mode 100644 pygaggle/model/decode.py
 create mode 100644 pygaggle/model/encode.py
 create mode 100644 pygaggle/model/evaluate.py
 create mode 100644 pygaggle/model/serialize.py
 create mode 100644 pygaggle/model/tokenize.py
 create mode 100644 pygaggle/rerank/__init__.py
 rename pygaggle/{rerank.py => rerank/base.py} (73%)
 create mode 100644 pygaggle/rerank/bm25.py
 create mode 100644 pygaggle/rerank/similarity.py
 create mode 100644 pygaggle/rerank/transformer.py
 create mode 100644 pygaggle/run/__init__.py
 create mode 100644 pygaggle/run/args.py
 create mode 100644 pygaggle/run/evaluate_kaggle_highlighter.py
 create mode 100644 pygaggle/settings.py
 create mode 100644 scripts/evaluate-highlighters.sh
 create mode 100644 scripts/update-index.sh

diff --git a/README.md b/README.md
index eaa854c9..a46be9c7 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,18 @@
-# pygaggle
\ No newline at end of file
+# PyGaggle
+
+A gaggle of CORD-19 rerankers.
+
+## Installation
+
+1. `conda env create -f environment.yml && conda activate pygaggle`
+
+2. Install [PyTorch 1.4+](http://pytorch.org/).
+
+3. Download the index: `sh scripts/update-index.sh`
+
+4. Make sure you have an installation of Java 8+: `javac --version`
+
+
+## Evaluating Highlighters
+
+Run `sh scripts/evaluate-highlighters.sh`.
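+## Example Usage
+
+The rerankers added in this patch share the `Reranker` interface. Below is a minimal sketch using the BM25 reranker; the query, passages, and document IDs are illustrative only:
+
+```python
+from pygaggle.rerank import Query, Text, Bm25Reranker
+
+query = Query('What is the incubation period of the virus?')
+texts = [Text('The median incubation period was 5.1 days.', dict(docid='d1')),
+         Text('Viral RNA was detected in stool for up to three weeks.', dict(docid='d2')),
+         Text('Chest CT showed high sensitivity for diagnosis.', dict(docid='d3'))]
+
+# With no index_path, IDF statistics are estimated from the candidate texts themselves;
+# pass index_path to a Lucene index to use corpus-level term weights instead.
+reranker = Bm25Reranker()
+for text in reranker.rerank(query, texts):
+    print(f'{text.score:.3f}\t{text.text}')
+```
+
+`T5Reranker` and `TransformerReranker` follow the same `rerank(query, texts)` interface; see `pygaggle/run/evaluate_kaggle_highlighter.py` for how they are constructed.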
\ No newline at end of file diff --git a/data/kaggle-lit-review.json b/data/kaggle-lit-review.json new file mode 100644 index 00000000..a1fafc5c --- /dev/null +++ b/data/kaggle-lit-review.json @@ -0,0 +1,459 @@ +{ + "categories":[ + { + "name":"Incubation period", + "sub_categories":[ + { + "name":"What is the incubation period of the virus?", + "answers":[ + { + "id":"wuclekt6", + "title":"Longitudinal analysis of laboratory findings during the process of recovery for patients with COVID-19", + "exact_answer":"4 days (IQR, 2-7)" + }, + { + "id":"e3t1f0rt", + "title":"Epidemiological Characteristics of COVID-19; a Systemic Review and Meta-Analysis 1", + "exact_answer":"5.84 (99% CI: 4.83, 6.85) days" + }, + { + "id":"ragcpbl6", + "title":"Evolving epidemiology of novel coronavirus diseases 2019 and possible interruption of local transmission outside Hubei Province in China: a descriptive and modeling study", + "exact_answer":"5·2 days" + }, + { + "id":"n0uwy77g", + "title":"Clinical characteristics and durations of hospitalized patients with COVID-19 in Beijing: a retrospective cohort study", + "exact_answer":"4 (3-7) days" + }, + { + "id":"", + "title":"Early Transmission Dynamics in Wuhan, China, of Novel Coronavirus–Infected Pneumonia", + "exact_answer":"5.2 days (95% confidence interval [CI], 4.1 to 7.0)" + }, + { + "id":"x23ej29m", + "title":"Clinical features and obstetric and neonatal outcomes of pregnant patients with COVID-19 in Wuhan, China: a retrospective, single-centre, descriptive study", + "exact_answer":"5 days (range 2-9 days)" + }, + { + "id":"56zhxd6e", + "title":"Epidemiological parameters of coronavirus disease 2019: a pooled analysis of publicly reported individual data of 1155 cases from seven countries Summary Background", + "exact_answer":"7.44 days" + }, + { + "id":"zph6r4il", + "title":"Epidemiological, clinical and virological characteristics of 74 cases of coronavirus-infected disease 2019 (COVID-19) with gastrointestinal symptoms", + "exact_answer":"4 days (IQR 3-7 days)" + }, + { + "id":"djq0lvr2", + "title":"Is a 14-day quarantine period optimal for effectively controlling coronavirus disease 2019 (COVID-19)?", + "exact_answer":"The median incubation period of both male and female adults was similar (7-day) but significantly shorter than that (9-day) of child cases" + }, + { + "id":"n0vmb946", + "title":"The difference in the incubation period of 2019 novel coronavirus (SARS-CoV-2) infection between travelers to Hubei and non-travelers: The need of a longer quarantine period", + "exact_answer":"1.8 and 7.2 days" + }, + { + "id":"awgyxn3t", + "title":"Clinical Characteristics of 34 Children with Coronavirus Disease-2019 in the West of China: a Multiple-center Case Series", + "exact_answer":"10.50 (7.75 -25.25) days" + }, + { + "id":"0hnh4n9e", + "title":"Investigation of three clusters of COVID-19 in Singapore: implications for surveillance and response measures", + "exact_answer":"4 days (IQR 3-6)" + }, + { + "id":"it4ka7v0", + "title":"Estimation of incubation period distribution of COVID-19 using disease onset forward time: a novel cross-sectional and forward follow-up study", + "exact_answer":"median of incubation period is 8·13 days (95% confidence interval [CI]: 7·37-8·91), the mean is 8·62 days (95% CI: 8·02-9·28)" + }, + { + "id":"glq0lckz", + "title":"Clinical Characteristics of SARS-CoV-2 Pneumonia Compared to Controls in Chinese Han Population", + "exact_answer":"4 days (IQR, 2 to 7)" + }, + { + "id":"8anqfkmo", + "title":"The Incubation Period of 
Coronavirus Disease 2019 (COVID-19) From Publicly Reported Confirmed Cases: Estimation and Application", + "exact_answer":"5.1 days (CI, 4.5 to 5.8 days)" + }, + { + "id":"v3gww4iv", + "title":"Transmission of corona virus disease 2019 during the incubation period may lead to a quarantine loophole", + "exact_answer":"4.9 days (95% confidence interval [CI], 4.4 to 5.4) days" + }, + { + "id":"66ulqu11", + "title":"Transmission interval estimates suggest pre-symptomatic spread of COVID-19", + "exact_answer":"7.1 (6.13, 8.25) days for Singapore and 9 (7.92, 10.2) days for Tianjin" + }, + { + "id":"ti9b1etu", + "title":"Transmission and clinical characteristics of coronavirus disease 2019 in 104 outside-Wuhan patients, China", + "exact_answer":"6 days, ranged from 1 to 32 days" + }, + { + "id":"k3f7ohzg", + "title":"Characteristics of COVID-19 infection in Beijing", + "exact_answer":"6.7 days" + }, + { + "id":"jxtch47t", + "title":"Epidemiologic and Clinical Characteristics of 91 Hospitalized Patients with COVID-19 in Zhejiang, China: A retrospective, multi-centre case series", + "exact_answer":"6 (IQR, 3-8) days" + }, + { + "id":"dbzrd23n", + "title":"Title: A descriptive study of the impact of diseases control and prevention on the epidemics 1 dynamics and clinical features of SARS-CoV-2 outbreak in Shanghai, lessons learned for", + "exact_answer":"6.4 days (95% 175 CI 5.3 to 7.6)" + }, + { + "id":"j3avpu1y", + "title":"A familial cluster of pneumonia associated with the 2019 novel coronavirus indicating person-to-person transmission: a study of a family cluster", + "exact_answer":"3-6 days" + }, + { + "id":"1mxjklgx", + "title":"Epidemiological characteristics of 1212 COVID-19 patients in Henan, China. medRxiv", + "exact_answer":"average, mode and median incubation periods are 7.4, 4 and 7 days" + }, + { + "id":"ykofrn9i", + "title":"Incubation Period and Other Epidemiological Characteristics of 2019 Novel Coronavirus Infections with Right Truncation: A Statistical Analysis of Publicly Available Case Data", + "exact_answer":"5.6 days (95% CI: 4.4, 7.4)" + }, + { + "id":"u8goc7io", + "title":"Title: The incubation period of 2019-nCoV infections among travellers from Wuhan, China", + "exact_answer":"6.4 (5.6 -7.7, 95% CI) days" + } + ] + }, + { + "name":"Length of viral shedding after illness onset", + "answers":[ + { + "id":"bg0cw5s6", + "title":"Factors associated with prolonged viral shedding and impact of Lopinavir/Ritonavir treatment in patients with SARS-CoV-2 infection", + "exact_answer":"23 days (IQR, 18-32 days)" + }, + { + "id":"r5a46n9a", + "title":"Viral Kinetics and Antibody Responses in Patients with COVID-19", + "exact_answer":"12 (3-38), 19 (5-37), and 18 (7-26) days in nasopharyngeal swabs, sputum and stools, respectively" + }, + { + "id":"k36rymkv", + "title":"Clinical course and risk factors for mortality of adult inpatients with COVID-19 in Wuhan, China: a retrospective cohort study", + "exact_answer":"20·0 days (IQR 17·0–24·0)" + } + ] + }, + { + "name":"Incubation period across different age groups", + "answers":[ + { + "id":"giabjjnz", + "title":"Children are unlikely to have been the primary source of household SARS-CoV-2 infections", + "exact_answer":"7.74 d ± 3.22" + }, + { + "id":"djq0lvr2", + "title":"Is a 14-day quarantine period optimal for effectively controlling coronavirus disease 2019 (COVID-19)?", + "exact_answer":"median incubation period of both male and female adults was similar (7-day) but significantly shorter than that (9-day) of child cases" + }, + { + 
"id":"awgyxn3t", + "title":"Clinical Characteristics of 34 Children with Coronavirus Disease-2019 in the West of China: a Multiple-center Case Series", + "exact_answer":"10.50 (7.75 -25.25) days" + } + ] + } + ] + }, + { + "name":"Asymptomatic shedding", + "sub_categories":[ + { + "name":"Proportion of patients who were asymptomatic", + "answers":[ + { + "id":"bmsmegbs", + "title":"A considerable proportion of individuals with asymptomatic SARS-CoV-2 infection in Tibetan population", + "exact_answer":"21.7%" + }, + { + "id":"jjgfgqwg", + "title":"Modes of contact and risk of transmission in COVID-19 among close contacts", + "exact_answer":"6.2%" + }, + { + "id":"7w1bhaz6", + "title":"High incidence of asymptomatic SARS-CoV-2 infection, Chongqing, China", + "exact_answer":"19%" + }, + { + "id":"xsqgrd5l", + "title":"High transmissibility of COVID-19 near symptom onset", + "exact_answer":"there were 32 laboratory-confirmed COVID-19 patients, including five household/family clusters and four asymptomatic patients" + }, + { + "id":"6su2x8mk", + "title":"Non-severe vs severe symptomatic COVID-19: 104 cases from the outbreak on the cruise ship “Diamond Princess” in Japan", + "exact_answer":"76 and 28 patients were classified as non-severe (asymptomatic, mild)" + }, + { + "id":"rjm1dqk7", + "title":"Epidemiological characteristics of 2019 novel coronavirus family clustering in Zhejiang Province", + "exact_answer":"54 asymptomatic infected cases" + }, + { + "id":"56zhxd6e", + "title":"Epidemiological parameters of coronavirus disease 2019: a pooled analysis of publicly reported individual data of 1155 cases from seven countries", + "exact_answer":"49 (14.89%) were asymptomatic" + }, + { + "id":"atnz63pk", + "title":"Estimating the Asymptomatic Proportion of 2019 Novel Coronavirus onboard the Princess Cruises Ship, 2020", + "exact_answer":"17.9%" + }, + { + "id":"ofoqk100", + "title":"Clinical Characteristics of 24 Asymptomatic Infections with COVID-19 Screened among Close Contacts in Nanjing, China", + "exact_answer":"The remaining 7 (29.2%) cases showed normal CT image and had no symptoms during hospitalization." 
+ }, + { + "id":"k3f7ohzg", + "title":"Characteristics of COVID-19 infection in Beijing", + "exact_answer":"13 (5.0%) asymptomatic cases" + }, + { + "id":"f3h74j1n", + "title":"Estimation of the asymptomatic ratio of novel coronavirus infections (COVID-19)", + "exact_answer":"the asymptomatic ratio at 41.6%" + } + ] + }, + { + "name":"Proportion of pediatric patients who were asymptomatic", + "answers":[ + { + "id":"xsgxd5sy", + "title":"Epidemiological and Clinical Characteristics of Children with Coronavirus Disease 2019", + "exact_answer":"20 (27.03%) cases" + }, + { + "id":"dmrtsxik", + "title":"Articles Clinical and epidemiological features of 36 children with coronavirus disease 2019 (COVID-19) in Zhejiang, China: an observational cohort study", + "exact_answer":"ten (28%) patients" + }, + { + "id":"7w1bhaz6", + "title":"High incidence of asymptomatic SARS-CoV-2 infection, Chongqing, China", + "exact_answer":"(28.6%) in children group under 14, next in elder group over 70 (27.3%)" + }, + { + "id":"jvhrp51s", + "title":"Title: The clinical and epidemiological features and hints of 82 confirmed COVID-19 pediatric cases aged 0-16 in Wuhan, China", + "exact_answer":"8 (9.76%)" + }, + { + "id":"j58f1lwa", + "title":"Preliminary epidemiological analysis on children and adolescents with novel coronavirus disease 2019 outside Hubei Province, China: an observational study utilizing crowdsourced data", + "exact_answer":"2 (8.0%) cases" + }, + { + "id":"mar8zt2t", + "title":"The different clinical characteristics of corona virus disease cases between children and their families in China -the character of children with COVID-19", + "exact_answer":"six (66.7%) children" + } + ] + }, + { + "name":"Asymptomatic transmission during incubation", + "answers":[ + { + "id":"eflwztji", + "title":"Temporal dynamics in viral shedding and transmissibility of COVID-19", + "exact_answer":"44% of transmission prior to symptom onset" + }, + { + "id":"v3gww4iv", + "title":"Transmission of corona virus disease 2019 during the incubation period may lead to a quarantine loophole", + "exact_answer":"(73.0%) were infected before the symptom onset of the first-generation cases" + }, + { + "id":"56zhxd6e", + "title":"Epidemiological parameters of coronavirus disease 2019: a pooled analysis of publicly reported individual data of 1155 cases from seven countries Summary Background", + "exact_answer":"In 102 (43.78%) infector-infectee pairs, transmission occurred before infectors' symptom onsets" + }, + { + "id":"v3gww4iv", + "title":"Transmission of corona virus disease 2019 during the incubation period may lead to a quarantine loophole", + "exact_answer":"(73.0%) were infected before the symptom onset of the first-generation cases" + }, + { + "id":"st5vs6gq", + "title":"Title: The serial interval of COVID-19 from publicly reported confirmed cases Running Head: The serial interval of COVID-19", + "exact_answer":"12.6% of reports indicating pre-symptomatic transmission" + } + ] + } + ] + }, + { + "name":"Persistence of sources", + "sub_categories":[ + { + "name":"Length of viral shedding in stool", + "answers":[ + { + "id":"k86ljbxu", + "title":"Do children need a longer time to shed SARS- CoV-2 in stool than adults?", + "exact_answer":"100% positive on the third week after onset and 30% positive 29 days later" + }, + { + "id":"ouca1bol", + "title":"Evaluation of SARS-CoV-2 RNA shedding in clinical specimens and clinical characteristics of 10 patients with COVID-19 in Macau", + "exact_answer":"detected in feces till 14 
days after the onset of symptoms" + }, + { + "id":"r5a46n9a", + "title":"Viral Kinetics and Antibody Responses in Patients with COVID-19", + "exact_answer":"18 days (range, 7-26)" + }, + { + "id":"1fyag5x3", + "title":"Virus shedding patterns in nasopharyngeal and fecal specimens of COVID-19 patients 2 3", + "exact_answer":"22.0 days (IQR 15.5 to 23.5)" + } + ] + }, + { + "name":"Length of viral shedding from nasopharynx", + "answers":[ + { + "id":"1fyag5x3", + "title":"Virus shedding patterns in nasopharyngeal and fecal specimens of COVID-19 patients 2 3", + "exact_answer":"10.0 days (IQR 8.0 to 17.0)" + }, + { + "id":"r5a46n9a", + "title":"Viral Kinetics and Antibody Responses in Patients with COVID-19", + "exact_answer":"12 days (range, 3-38 )" + } + ] + }, + { + "name":"Length of viral shedding in urine", + "answers":[ + { + "id":"1fyag5x3", + "title":"Virus shedding patterns in nasopharyngeal and fecal specimens of COVID-19 patients 2 3", + "exact_answer":"all patient data were and urine samples were all negative, except for urine samples from two 53 severe cases at the latest available detection point (16 or 21 d.a.o)" + } + ] + }, + { + "name":"Length of viral shedding in blood", + "answers":[ + { + "id":"ac4aesoa", + "title":"Comparisons of nucleic acid conversion time of SARS-CoV-2 of different samples in ICU and non-ICU patients Conversion time of SARS-CoV-2 RT-PCR in ICU and non-ICU patients Letter to the Editor Comparisons of nucleic acid conversion time of SARS-CoV-2 of different samples in ICU and non-ICU patients", + "exact_answer":"10.17 ± 6.134 and 14.63 ± 5.878 days in non-ICU and ICU patients respectively" + } + ] + }, + { + "name":"Prevalence of viral shedding in blood", + "answers":[ + { + "id":"r5a46n9a", + "title":"Viral Kinetics and Antibody Responses in Patients with COVID-19", + "exact_answer":"12 plasmas (5.7%) from 9 patients (14.3%) were positive" + } + ] + } + ] + }, + { + "name":"Diagnostics", + "sub_categories":[ + { + "name":"Sensitivity and specificity of COVID-19 tests", + "answers":[ + { + "id":"9p7hqk1u", + "title":"Journal Pre-proof COVID-19 pneumonia: a review of typical CT findings and differential diagnosis COVID-19 pneumonia: a review of typical CT findings and differential diagnosis", + "exact_answer":"A large series based on 1014 patients reported a 97% sensitivity of chest CT for the diagnosis of COVID-19, while the mean time interval between initial negative and positive RT-PCR was approximately 5 days" + }, + { + "id":"aa7slcnc", + "title":"Highly accurate and sensitive diagnostic detection of SARS-CoV-2 by digital PCR", + "exact_answer":"The overall sensitivity, specificity and diagnostic accuracy of RT-dPCR were 90%, 100% and 93 %, respectively" + }, + { + "id":"na8odvj7", + "title":"Serological detection of 2019-nCoV respond to the epidemic: A useful complement to nucleic acid testing", + "exact_answer":"The areas under the ROC curves of IgM and IgG were 0.988 and 1.000, respectively." 
+ }, + { + "id":"", + "title":"Molecular immune pathogenesis and diagnosis of COVID-19", + "exact_answer":"RT-qPCR can only achieve a sensitivity of 50% to 79%, depending on the protocol used the sample type and number of clinical specimens collected" + }, + { + "id":"", + "title":"Molecular immune pathogenesis and diagnosis of COVID-19", + "exact_answer":"The sensitivity of SARS-CoV N-based IgG ELISA (94.7%) is significantly higher than that of SARS-CoV S-based IgG ELISA (58.9%)" + }, + { + "id":"py38bel4", + "title":"Clinical significance of IgM and IgG test for diagnosis of highly suspected COVID-19 infection", + "exact_answer":"The positive detection rate of combination of IgM and IgG for patients with COVID-19 negative and positive nucleic acid test was 72.73% and 87.50%." + }, + { + "id":"5skk3nj4", + "title":"Imaging manifestations and diagnostic value of chest CT of coronavirus disease 2019 (COVID-19) in the Xiaogan area", + "exact_answer":"the overall accuracy rate of CT examination in the present study was 97.3%" + }, + { + "id":"cv3qgno3", + "title":"Rapid Molecular Detection of SARS-CoV-2 (COVID-19) Virus RNA Using Colorimetric LAMP", + "exact_answer":"colorimetric LAMP assay showed 100% agreement with the RT-qPCR results across a range of C q values" + }, + { + "id":"", + "title":"Differential diagnosis of illness in patients under investigation for the novel coronavirus (SARS-CoV-2), Italy, February 2020", + "exact_answer":"Broad screening for respiratory pathogens revealed a high rate of influenza virus infections, accounting for 28.5% of all suspected cases of SARS-CoV-2 infection" + }, + { + "id":"8gncbgot", + "title":"Potential Rapid Diagnostics, Vaccine and Therapeutics for 2019 Novel Coronavirus (2019-nCoV): A Systematic Review", + "exact_answer":"E gene and RdRp gene assays produced the best result (5.2 and 3.8 copies per reaction at 95% detection probability, respectively)" + }, + { + "id":"chln5r8w", + "title":"Diagnosis of the Coronavirus disease (COVID-19): rRT-PCR or CT?", + "exact_answer":"Sensitivity of CT examinations was 97.2% at presentation, whereas first round rRT-PCR sensitivity was 84.6%" + }, + { + "id":"s7uqawbd", + "title":"Rapid colorimetric detection of COVID-19 coronavirus using a reverse tran- scriptional loop-mediated isothermal amplification (RT-LAMP) diagnostic plat- form: iLACO", + "exact_answer":"iLACO is very sensitive, and as low as 10 copies of ORF1ab gene were detected successfully" + }, + { + "id":"", + "title":"Detection of 2019 novel coronavirus (2019-nCoV) by real-time RT-PCR", + "exact_answer":"All assays were highly sensitive, with best results obtained for the E gene and RdRp gene assays (5.2 and 3.8 copies per reaction at 95% detection probability, respectively)" + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 00000000..1ffa08e4 --- /dev/null +++ b/environment.yml @@ -0,0 +1,108 @@ +name: pygaggle +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1 + - blas=1.0 + - ca-certificates=2020.1.1 + - certifi=2020.4.5.1 + - freetype=2.9.1 + - intel-openmp=2020.0 + - jpeg=9b + - libedit=3.1.20181209 + - libffi=3.2.1 + - libgcc-ng=9.1.0 + - libgfortran-ng=7.3.0 + - libpng=1.6.37 + - libstdcxx-ng=9.1.0 + - libtiff=4.1.0 + - mkl=2020.0 + - mkl-service=2.3.0 + - mkl_fft=1.0.15 + - mkl_random=1.1.0 + - ncurses=6.2 + - ninja=1.9.0 + - numpy-base=1.18.1 + - olefile=0.46 + - openssl=1.1.1f + - pillow=7.0.0 + - pip=20.0.2 + - python=3.7.7 + - readline=8.0 + - 
setuptools=46.1.3 + - six=1.14.0 + - sqlite=3.31.1 + - tk=8.6.8 + - wheel=0.34.2 + - xz=5.2.5 + - zlib=1.2.11 + - zstd=1.3.7 + - pip: + - absl-py==0.9.0 + - astor==0.8.1 + - blis==0.4.1 + - boto3==1.12.41 + - botocore==1.15.41 + - cachetools==4.1.0 + - catalogue==1.0.0 + - chardet==3.0.4 + - click==7.1.1 + - coloredlogs==14.0 + - cymem==2.0.3 + - cython==0.29.16 + - docutils==0.15.2 + - filelock==3.0.12 + - gast==0.2.2 + - google-auth==1.14.0 + - google-auth-oauthlib==0.4.1 + - google-pasta==0.2.0 + - grpcio==1.28.1 + - h5py==2.10.0 + - humanfriendly==8.2 + - idna==2.9 + - importlib-metadata==1.6.0 + - jmespath==0.9.5 + - joblib==0.14.1 + - keras-applications==1.0.8 + - keras-preprocessing==1.1.0 + - markdown==3.2.1 + - murmurhash==1.0.2 + - numpy==1.18.2 + - oauthlib==3.1.0 + - opt-einsum==3.2.1 + - plac==1.1.3 + - preshed==3.0.2 + - protobuf==3.11.3 + - pyasn1==0.4.8 + - pyasn1-modules==0.2.8 + - pydantic==1.5 + - pyjnius==1.2.1 + - pyserini==0.9.0.0 + - python-dateutil==2.8.1 + - regex==2020.4.4 + - requests==2.23.0 + - requests-oauthlib==1.3.0 + - rsa==4.0 + - s3transfer==0.3.3 + - sacremoses==0.0.41 + - scikit-learn==0.22.2.post1 + - scipy==1.4.1 + - sentencepiece==0.1.85 + - sklearn==0.0 + - spacy==2.2.4 + - srsly==1.0.2 + - tensorboard==2.1.1 + - tensorflow==2.1.0 + - tensorflow-estimator==2.1.0 + - tensorflow-gpu==2.1.0 + - tensorflow-text==2.1.1 + - termcolor==1.1.0 + - thinc==7.4.0 + - tokenizers==0.5.2 + - tqdm==4.45.0 + - transformers==2.7.0 + - urllib3==1.25.9 + - wasabi==0.6.0 + - werkzeug==1.0.1 + - wrapt==1.12.1 + - zipp==3.1.0 diff --git a/pygaggle/__init__.py b/pygaggle/__init__.py index 8b137891..4c98ddd1 100644 --- a/pygaggle/__init__.py +++ b/pygaggle/__init__.py @@ -1 +1 @@ - +from .logger import * diff --git a/pygaggle/data/__init__.py b/pygaggle/data/__init__.py new file mode 100644 index 00000000..8910717c --- /dev/null +++ b/pygaggle/data/__init__.py @@ -0,0 +1,2 @@ +from .kaggle import * +from .relevance import * diff --git a/pygaggle/data/kaggle.py b/pygaggle/data/kaggle.py new file mode 100644 index 00000000..46fbd17d --- /dev/null +++ b/pygaggle/data/kaggle.py @@ -0,0 +1,69 @@ +from collections import OrderedDict +from typing import List +import json +import logging + +from pydantic import BaseModel + +from .relevance import RelevanceExample, LuceneDocumentLoader +from pygaggle.model.tokenize import SpacySenticizer +from pygaggle.rerank import Query, Text + + +__all__ = ['MISSING_ID', 'LitReviewCategory', 'LitReviewAnswer', 'LitReviewDataset', 'LitReviewSubcategory'] + + +MISSING_ID = '' + + +class LitReviewAnswer(BaseModel): + id: str + title: str + exact_answer: str + + +class LitReviewSubcategory(BaseModel): + name: str + answers: List[LitReviewAnswer] + + +class LitReviewCategory(BaseModel): + name: str + sub_categories: List[LitReviewSubcategory] + + +class LitReviewDataset(BaseModel): + categories: List[LitReviewCategory] + + @classmethod + def from_file(cls, filename: str) -> 'LitReviewDataset': + with open(filename) as f: + return cls(**json.load(f)) + + @property + def query_answer_pairs(self): + return ((subcat.name, ans) for cat in self.categories + for subcat in cat.sub_categories + for ans in subcat.answers) + + def to_senticized_dataset(self, index_path: str) -> List[RelevanceExample]: + loader = LuceneDocumentLoader(index_path) + tokenizer = SpacySenticizer() + example_map = OrderedDict() + rel_map = OrderedDict() + for query, document in self.query_answer_pairs: + if document.id == MISSING_ID: + logging.warning(f'Skipping {document.title} 
(missing ID)') + continue + key = (query, document.id) + example_map.setdefault(key, tokenizer(loader.load_document(document.id))) + sents = example_map[key] + rel_map.setdefault(key, [False] * len(sents)) + for idx, s in enumerate(sents): + if document.exact_answer in s: + rel_map[key][idx] = True + for (_, doc_id), rels in rel_map.items(): + if not any(rels): + logging.warning(f'{doc_id} has no relevant answers') + return [RelevanceExample(Query(query), list(map(lambda s: Text(s, dict(docid=docid)), sents)), rels) + for ((query, docid), sents), (_, rels) in zip(example_map.items(), rel_map.items())] diff --git a/pygaggle/data/relevance.py b/pygaggle/data/relevance.py new file mode 100644 index 00000000..b53f5982 --- /dev/null +++ b/pygaggle/data/relevance.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from functools import lru_cache +from itertools import chain +from typing import List +import json +import re + +from pyserini.search import pysearch + +from pygaggle.rerank import Query, Text + + +__all__ = ['RelevanceExample', 'LuceneDocumentLoader'] + + +@dataclass +class RelevanceExample: + query: Query + documents: List[Text] + labels: List[bool] + + +class LuceneDocumentLoader: + double_space_pattern = re.compile(r'\s\s+') + + def __init__(self, index_path: str): + self.searcher = pysearch.SimpleSearcher(index_path) + + @lru_cache(maxsize=1024) + def load_document(self, id: str) -> str: + article = json.loads(self.searcher.doc(id).lucene_document().get('raw')) + ref_entries = article['ref_entries'].values() + text = '\n'.join(x['text'] for x in chain(article['abstract'], article['body_text'], ref_entries)) + return text diff --git a/pygaggle/lib/IdentityReranker.py b/pygaggle/lib/IdentityReranker.py deleted file mode 100644 index b2898144..00000000 --- a/pygaggle/lib/IdentityReranker.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import List - -from pygaggle.rerank import Reranker, Query, Text - - -class IdentityReranker(Reranker): - """A reranker that simply returns a clone of the input list of texts. 
- """ - - def rerank(self, query: Query, texts: List[Text]): - output = [] - for text in texts: - output.append(Text(text.contents, text.raw, text.score)) - return output diff --git a/pygaggle/lib/__init__.py b/pygaggle/lib/__init__.py deleted file mode 100644 index 7fbc0d81..00000000 --- a/pygaggle/lib/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ - -# We use __init__ to flatten the hierarchy so that we can do: -# > from pygaggle.lib import IdentityReranker -# -# which is less verbose than -# > from pygaggle.lib.IdentityReranker import IdentityReranker -# - -from pygaggle.lib.IdentityReranker import IdentityReranker diff --git a/pygaggle/logger.py b/pygaggle/logger.py new file mode 100644 index 00000000..0f979b02 --- /dev/null +++ b/pygaggle/logger.py @@ -0,0 +1,7 @@ +import coloredlogs + + +__all__ = [] + + +coloredlogs.install(level='INFO', fmt='%(asctime)s [%(levelname)s] %(module)s: %(message)s') diff --git a/pygaggle/model/__init__.py b/pygaggle/model/__init__.py new file mode 100644 index 00000000..f479547d --- /dev/null +++ b/pygaggle/model/__init__.py @@ -0,0 +1,5 @@ +from .decode import * +from .encode import * +from .evaluate import * +from .serialize import * +from .tokenize import * diff --git a/pygaggle/model/decode.py b/pygaggle/model/decode.py new file mode 100644 index 00000000..3b26f301 --- /dev/null +++ b/pygaggle/model/decode.py @@ -0,0 +1,29 @@ +from typing import Union, Tuple + +from transformers import PreTrainedModel +import torch + + +__all__ = ['greedy_decode'] + + +@torch.no_grad() +def greedy_decode(model: PreTrainedModel, + input_ids: torch.Tensor, + length: int, + attention_mask: torch.Tensor = None, + return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + decode_ids = torch.full((input_ids.size(0), 1), + model.config.decoder_start_token_id, + dtype=torch.long).to(input_ids.device) + past = model.get_encoder()(input_ids, attention_mask=attention_mask) + next_token_logits = None + for _ in range(length): + model_inputs = model.prepare_inputs_for_generation(decode_ids, past=past, attention_mask=attention_mask) + outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size) + next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size) + decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1) + past = outputs[1] + if return_last_logits: + return decode_ids, next_token_logits + return decode_ids diff --git a/pygaggle/model/encode.py b/pygaggle/model/encode.py new file mode 100644 index 00000000..ddfc9be4 --- /dev/null +++ b/pygaggle/model/encode.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +from typing import List + +from transformers import PreTrainedTokenizer +import torch +import torch.nn as nn + +from .tokenize import BatchTokenizer +from pygaggle.rerank import TextType + + +__all__ = ['LongBatchEncoder', 'EncoderOutputBatch', 'SingleEncoderOutput', 'SpecialTokensCleaner'] + + +@dataclass +class SingleEncoderOutput: + encoder_output: torch.Tensor + token_ids: torch.Tensor + text: TextType + + +@dataclass +class EncoderOutputBatch: + encoder_output: List[torch.Tensor] + token_ids: List[torch.Tensor] + texts: List[TextType] + + def as_single(self) -> 'SingleEncoderOutput': + return SingleEncoderOutput(self.encoder_output[0], self.token_ids[0], self.texts[0]) + + def __iter__(self): + return iter(SingleEncoderOutput(enc_out, token_ids, text) for enc_out, token_ids, text + in zip(self.encoder_output, self.token_ids, self.texts)) + + +class SpecialTokensCleaner: + def 
__init__(self, tokenizer: PreTrainedTokenizer): + self.special_ids = tokenizer.all_special_ids + + def clean(self, output: SingleEncoderOutput) -> SingleEncoderOutput: + indices = [idx for idx, tok in enumerate(output.token_ids.tolist()) if tok not in self.special_ids] + return SingleEncoderOutput(output.encoder_output[indices], output.token_ids[indices], output.text) + + +class LongBatchEncoder: + """Encodes batches of documents that are longer than the maximum sequence length by striding a window across + the sequence dimension. + + Parameters + ---------- + encoder : nn.Module + The encoder module, such as `BertModel`. + tokenizer : BatchTokenizer + The batch tokenizer to use. + max_seq_length : int + The maximum sequence length, typically 512. + """ + def __init__(self, + encoder: nn.Module, + tokenizer: BatchTokenizer, + max_seq_length: int = 512): + self.encoder = encoder + self.device = next(self.encoder.parameters()).device + self.tokenizer = tokenizer + self.msl = max_seq_length + + def encode_single(self, input: TextType) -> SingleEncoderOutput: + return self.encode([input]).as_single() + + def encode(self, batch_input: List[TextType]) -> EncoderOutputBatch: + batch_output = [] + batch_ids = [] + for ret in self.tokenizer.traverse(batch_input): + input_ids = ret.output['input_ids'] + lengths = list(map(len, input_ids)) + batch_ids.extend(map(torch.tensor, input_ids)) + input_ids = [(idx, x) for idx, x in enumerate(input_ids)] + max_len = min(max(lengths), self.msl) + encode_lst = [[] for _ in input_ids] + new_input_ids = [(idx, x[:max_len]) for idx, x in input_ids] + while new_input_ids: + attn_mask = [[1] * len(x[1]) + [0] * (max_len - len(x[1])) for x in new_input_ids] + attn_mask = torch.tensor(attn_mask).to(self.device) + nonpadded_input_ids = new_input_ids + new_input_ids = [x + [0] * (max_len - len(x[:max_len])) for _, x in new_input_ids] + new_input_ids = torch.tensor(new_input_ids).to(self.device) + outputs, _ = self.encoder(input_ids=new_input_ids, attention_mask=attn_mask) + for (idx, _), output in zip(nonpadded_input_ids, outputs): + encode_lst[idx].append(output) + + new_input_ids = [(idx, x[max_len:]) for idx, x in nonpadded_input_ids if len(x) > max_len] + max_len = min(max(map(lambda x: len(x[1]), new_input_ids), default=0), self.msl) + + encode_lst = list(map(torch.cat, encode_lst)) + batch_output.extend(encode_lst) + return EncoderOutputBatch(batch_output, batch_ids, batch_input) diff --git a/pygaggle/model/evaluate.py b/pygaggle/model/evaluate.py new file mode 100644 index 00000000..e0a39906 --- /dev/null +++ b/pygaggle/model/evaluate.py @@ -0,0 +1,103 @@ +from collections import OrderedDict +from typing import List +import abc + +from sklearn.metrics import recall_score +from tqdm import tqdm +import numpy as np + +from pygaggle.data import RelevanceExample +from pygaggle.rerank import Reranker + + +__all__ = ['RerankerEvaluator', 'metric_names'] +METRIC_MAP = OrderedDict() + + +class MetricAccumulator: + name: str = None + + def accumulate(self, scores: List[float], gold: List[RelevanceExample]): + return + + @abc.abstractmethod + def value(self): + return + + +class MeanAccumulator(MetricAccumulator): + def __init__(self): + self.scores = [] + + @property + def value(self): + return np.mean(self.scores) + + +def register_metric(name): + def wrap_fn(metric_cls): + METRIC_MAP[name] = metric_cls + metric_cls.name = name + return metric_cls + return wrap_fn + + +def metric_names(): + return list(METRIC_MAP.keys()) + + +def truncated_rels(scores: List[float], 
top_k: int) -> np.ndarray: + rel_idxs = sorted(list(enumerate(scores)), key=lambda x: x[1], reverse=True)[:top_k] + rel_idxs = [x[0] for x in rel_idxs] + score_rels = np.zeros(len(scores), dtype=int) + score_rels[rel_idxs] = 1 + return score_rels + + +@register_metric('recall') +class RecallAccumulator(MeanAccumulator): + top_k = None + + def accumulate(self, scores: List[float], gold: RelevanceExample): + score_rels = truncated_rels(scores, self.top_k) + gold_rels = np.array(gold.labels, dtype=int) + score = recall_score(gold_rels, score_rels, zero_division=1) + self.scores.append(score) + + +@register_metric('precision') +class PrecisionAccumulator(MeanAccumulator): + top_k = None + + def accumulate(self, scores: List[float], gold: RelevanceExample): + score_rels = truncated_rels(scores, self.top_k) + gold_rels = np.array(gold.labels, dtype=int) + self.scores.append((score_rels & gold_rels).sum() / score_rels.sum()) + + +@register_metric('recall@1') +class RecallAt1Metric(RecallAccumulator): + top_k = 1 + + +@register_metric('precision@1') +class PrecisionAt1Metric(PrecisionAccumulator): + top_k = 1 + + +class RerankerEvaluator: + def __init__(self, + reranker: Reranker, + metric_names: List[str], + use_tqdm: bool = True): + self.reranker = reranker + self.metrics = [METRIC_MAP[name] for name in metric_names] + self.use_tqdm = use_tqdm + + def evaluate(self, examples: List[RelevanceExample]) -> List[MetricAccumulator]: + metrics = [cls() for cls in self.metrics] + for example in tqdm(examples, disable=not self.use_tqdm): + scores = [x.score for x in self.reranker.rerank(example.query, example.documents)] + for metric in metrics: + metric.accumulate(scores, example) + return metrics diff --git a/pygaggle/model/serialize.py b/pygaggle/model/serialize.py new file mode 100644 index 00000000..e81daa91 --- /dev/null +++ b/pygaggle/model/serialize.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from pathlib import Path +import json +import logging +import os +import re + +from tensorflow.python.lib.io import file_io +from transformers import T5Config, T5ForConditionalGeneration +import torch + + +__all__ = ['CachedT5ModelLoader'] +TRANSFO_PREFIX = 'model.ckpt' + + +@dataclass +class CachedT5ModelLoader: + url: str + cache_path: Path + cache_key: str + model_type: str = 't5-base' + flush_cache: bool = False + + def __post_init__(self): + self.ckpt_url = os.path.join(self.url, 'checkpoint') + self.model_cache_dir = self.cache_path / self.cache_key + self.model_cache_dir.mkdir(exist_ok=True, parents=True) + assert file_io.file_exists(self.ckpt_url), 'checkpoint file missing' + self.ckpt_prefix = file_io.read_file_to_string(self.ckpt_url) + + def _fix_t5_model(self, model: T5ForConditionalGeneration): + with torch.no_grad(): # Make more similar to TensorFlow implementation + model.decoder.block[0].layer[1].EncDecAttention.relative_attention_bias.weight.data.zero_() + return model + + def load(self) -> T5ForConditionalGeneration: + try: + if not self.flush_cache: + return self._fix_t5_model(T5ForConditionalGeneration.from_pretrained(str(self.model_cache_dir), + from_tf=True, + force_download=False)) + except (RuntimeError, OSError): + logging.info('T5 model weights not in cache.') + m = re.search(r'model_checkpoint_path: "(.+?)"', self.ckpt_prefix) + assert m is not None, 'checkpoint file malformed' + + # Copy over checkpoint data + ckpt_patt = re.compile(rf'^{m.group(1)}\.(data-\d+-of-\d+|index|meta)$') + for name in file_io.list_directory(self.url): + if not ckpt_patt.match(name): + 
continue + url = os.path.join(self.url, name) + url_stat = file_io.stat(url) + cache_file_path = self.model_cache_dir / ckpt_patt.sub(rf'{TRANSFO_PREFIX}.\1', name) + try: + cs = os.stat(str(cache_file_path)) + if cs.st_size == url_stat.length and cs.st_mtime_ns > url_stat.mtime_nsec and not self.flush_cache: + logging.info(f'Skipping {name}...') + continue + except FileNotFoundError: + pass + logging.info(f'Caching {name}...') + file_io.copy(url, str(cache_file_path), overwrite=True) + + # Transformers expects a model config.json + config = T5Config.from_pretrained(self.model_type) + with open(str(self.model_cache_dir / 'config.json'), 'w') as f: + json.dump(config.__dict__, f, indent=4) + return self._fix_t5_model(T5ForConditionalGeneration.from_pretrained(str(self.model_cache_dir), + from_tf=True, + force_download=False)) diff --git a/pygaggle/model/tokenize.py b/pygaggle/model/tokenize.py new file mode 100644 index 00000000..4a23dca3 --- /dev/null +++ b/pygaggle/model/tokenize.py @@ -0,0 +1,132 @@ +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Mapping, Union, Iterable, Optional + +from spacy.lang.en import English +from transformers import PreTrainedTokenizer +import torch + +from pygaggle.rerank import Query, Text, TextType + + +__all__ = ['BatchTokenizer', + 'T5BatchTokenizer', + 'QueryDocumentBatch', + 'SimpleBatchTokenizer', + 'QueryDocumentBatchTokenizer', + 'SpacySenticizer', + 'SpacyWordTokenizer'] +TokenizerReturnType = Mapping[str, Union[torch.Tensor, List[int], List[List[int]], List[List[str]]]] + + +@dataclass +class TokenizerOutputBatch: + output: TokenizerReturnType + texts: List[TextType] + + def __len__(self): + return len(self.texts) + + +@dataclass +class QueryDocumentBatch: + query: Query + documents: List[Text] + output: Optional[TokenizerReturnType] = None + + def __len__(self): + return len(self.documents) + + +class TokenizerEncodeMixin: + tokenizer: PreTrainedTokenizer = None + tokenizer_kwargs = None + + def encode(self, strings: List[str]) -> TokenizerReturnType: + assert self.tokenizer and self.tokenizer_kwargs is not None, 'mixin used improperly' + ret = self.tokenizer.batch_encode_plus(strings, **self.tokenizer_kwargs) + ret['tokens'] = list(map(self.tokenizer.tokenize, strings)) + return ret + + +class BatchTokenizer(TokenizerEncodeMixin): + def __init__(self, + tokenizer: PreTrainedTokenizer, + batch_size: int, + **tokenizer_kwargs): + self.tokenizer = tokenizer + self.batch_size = batch_size + self.tokenizer_kwargs = tokenizer_kwargs + + def traverse(self, batch_input: List[TextType]) -> Iterable[TokenizerOutputBatch]: + for batch_idx in range(0, len(batch_input), self.batch_size): + inputs = batch_input[batch_idx:batch_idx + self.batch_size] + input_ids = self.encode([x.text for x in inputs]) + yield TokenizerOutputBatch(input_ids, inputs) + + +class AppendEosTokenizerMixin: + tokenizer: PreTrainedTokenizer = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def encode(self, strings: List[str]) -> TokenizerReturnType: + assert self.tokenizer, 'mixin used improperly' + return super().encode([f'{x} {self.tokenizer.eos_token}' for x in strings]) + + +class QueryDocumentBatchTokenizer(TokenizerEncodeMixin): + def __init__(self, + tokenizer: PreTrainedTokenizer, + batch_size: int, + pattern: str = '{query} {document}', + **tokenizer_kwargs): + self.tokenizer = tokenizer + self.batch_size = batch_size + self.tokenizer_kwargs = tokenizer_kwargs + self.pattern = pattern + + def 
traverse_query_document(self, batch_input: QueryDocumentBatch) -> Iterable[QueryDocumentBatch]: + query = batch_input.query + for batch_idx in range(0, len(batch_input), self.batch_size): + docs = batch_input.documents[batch_idx:batch_idx + self.batch_size] + outputs = self.encode([self.pattern.format(query=query.text, document=doc.text) for doc in docs]) + yield QueryDocumentBatch(query, docs, outputs) + + +class T5BatchTokenizer(AppendEosTokenizerMixin, QueryDocumentBatchTokenizer): + def __init__(self, *args, **kwargs): + kwargs['pattern'] = 'Query: {query} Document: {document} Relevant:' + kwargs['return_attention_mask'] = True + kwargs['pad_to_max_length'] = True + kwargs['return_tensors'] = 'pt' + super().__init__(*args, **kwargs) + + +class SimpleBatchTokenizer(BatchTokenizer): + def __init__(self, *args, **kwargs): + kwargs['return_attention_mask'] = True + kwargs['pad_to_max_length'] = True + super().__init__(*args, **kwargs) + + +class SpacyWordTokenizer: + nlp = English() + tokenizer = nlp.Defaults.create_tokenizer(nlp) + + @lru_cache(maxsize=1024) + def __call__(self, text: str) -> List[str]: + return list(x.text for x in self.tokenizer(text)) + + +class SpacySenticizer: + nlp = English() + nlp.add_pipe(nlp.create_pipe('sentencizer')) + + def __init__(self, max_paragraph_length: int = None): + self.max_paragraph_length = max_paragraph_length + + @lru_cache(maxsize=1024) + def __call__(self, document: str) -> List[str]: + return [s.string for s in self.nlp(document[:self.max_paragraph_length]).sents] diff --git a/pygaggle/rerank/__init__.py b/pygaggle/rerank/__init__.py new file mode 100644 index 00000000..cb01d169 --- /dev/null +++ b/pygaggle/rerank/__init__.py @@ -0,0 +1,4 @@ +from .base import * +from .similarity import * +from .bm25 import * +from .transformer import * diff --git a/pygaggle/rerank.py b/pygaggle/rerank/base.py similarity index 73% rename from pygaggle/rerank.py rename to pygaggle/rerank/base.py index 2ca3444d..6c47bd92 100644 --- a/pygaggle/rerank.py +++ b/pygaggle/rerank/base.py @@ -1,8 +1,15 @@ -from typing import List +from typing import List, Union, Optional, Mapping, Any +import abc from pyserini.pyclass import JSimpleSearcherResult +__all__ = ['Query', 'Text', 'Reranker', 'to_texts', 'TextType'] + + +TextType = Union['Query', 'Text'] + + class Query: """Class representing a query. A query contains the query text itself and potentially other metadata. @@ -21,17 +28,18 @@ class Text: Parameters ---------- - contents : str + text : str The text to be reranked. - raw : str - The raw representation of the text to be ranked. For example, the ```raw``` might be a JSON object containing - the ```contents``` as well as additional metadata data and other annotations. - score : float + raw : Mapping[str, Any] + Additional metadata and other annotations. + score : Optional[float] The score of the text. For example, the score might be the BM25 score from an initial retrieval stage. """ - def __init__(self, contents: str, raw: str, score: float): - self.contents = contents + def __init__(self, text: str, raw: Mapping[str, Any] = None, score: Optional[float] = 0): + self.text = text + if raw is None: + raw = dict() self.raw = raw self.score = score @@ -40,11 +48,8 @@ class Reranker: """Class representing a reranker. A reranker takes a list texts and returns a list of texts non-destructively (i.e., does not alter the original input list of texts). 
""" - - def __init__(self): - pass - - def rerank(self, query: Query, texts: List[Text]): + @abc.abstractmethod + def rerank(self, query: Query, texts: List[Text]) -> List[Text]: """Reranks a list of texts with respect to a query. Parameters @@ -62,7 +67,7 @@ def rerank(self, query: Query, texts: List[Text]): pass -def to_texts(hits: List[JSimpleSearcherResult]): +def to_texts(hits: List[JSimpleSearcherResult]) -> List[Text]: """Converts hits from Pyserini into a list of texts. Parameters diff --git a/pygaggle/rerank/bm25.py b/pygaggle/rerank/bm25.py new file mode 100644 index 00000000..8460305d --- /dev/null +++ b/pygaggle/rerank/bm25.py @@ -0,0 +1,49 @@ +from collections import Counter +from copy import deepcopy +from typing import List +import math + +from pyserini.analysis.pyanalysis import get_lucene_analyzer, Analyzer +from pyserini.index.pyutils import IndexReaderUtils +import numpy as np + +from pygaggle.rerank import Reranker, Query, Text + + +__all__ = ['Bm25Reranker'] + + +class Bm25Reranker(Reranker): + def __init__(self, + k1: float = 1.6, + b: float = 0.75, + index_path: str = None): + self.k1 = k1 + self.b = b + self.use_corpus_estimator = False + self.analyzer = Analyzer(get_lucene_analyzer()) + if index_path: + self.use_corpus_estimator = True + self.index_utils = IndexReaderUtils(index_path) + + def rerank(self, query: Query, texts: List[Text]) -> List[Text]: + query_words = self.analyzer.analyze(query.text) + sentences = list(map(self.analyzer.analyze, (t.text for t in texts))) + + query_words_set = set(query_words) + sentence_sets = list(map(set, sentences)) + if not self.use_corpus_estimator: + idfs = {w: math.log(len(sentence_sets) / (1 + sum(int(w in sent) for sent in sentence_sets))) + for w in query_words_set} + mean_len = np.mean(list(map(len, sentences))) + d_len = len(sentences) + + texts = deepcopy(texts) + for sent_words, text in zip(sentences, texts): + tf = Counter(filter(query_words.__contains__, sent_words)) + if self.use_corpus_estimator: + idfs = {w: self.index_utils.compute_bm25_term_weight(text.raw['docid'], w) for w in tf} + score = sum(idfs[w] * tf[w] * (self.k1 + 1) / + (tf[w] + self.k1 * (1 - self.b + self.b * (d_len / mean_len))) for w in tf) + text.score = score + return texts diff --git a/pygaggle/rerank/similarity.py b/pygaggle/rerank/similarity.py new file mode 100644 index 00000000..729232f0 --- /dev/null +++ b/pygaggle/rerank/similarity.py @@ -0,0 +1,28 @@ +import abc + +import torch + +from pygaggle.model import SingleEncoderOutput + + +__all__ = ['SimilarityMatrixProvider', 'InnerProductMatrixProvider'] + + +class SimilarityMatrixProvider: + @abc.abstractmethod + def compute_matrix(self, + encoded_query: SingleEncoderOutput, + encoded_document: SingleEncoderOutput) -> torch.Tensor: + pass + + +class InnerProductMatrixProvider(SimilarityMatrixProvider): + @torch.no_grad() + def compute_matrix(self, encoded_query: SingleEncoderOutput, encoded_document: SingleEncoderOutput) -> torch.Tensor: + query_repr = encoded_query.encoder_output + doc_repr = encoded_document.encoder_output + matrix = torch.einsum('mh,nh->mn', query_repr, doc_repr) + dnorm = doc_repr.norm(p=2, dim=1).unsqueeze(0) + qnorm = query_repr.norm(p=2, dim=1).unsqueeze(1) + matrix = (matrix / (qnorm + 1e-7)) / (dnorm + 1e-7) + return matrix diff --git a/pygaggle/rerank/transformer.py b/pygaggle/rerank/transformer.py new file mode 100644 index 00000000..0314ed19 --- /dev/null +++ b/pygaggle/rerank/transformer.py @@ -0,0 +1,75 @@ +from copy import deepcopy +from typing import List + 
+from transformers import T5ForConditionalGeneration, PreTrainedModel +import torch + +from pygaggle.model import greedy_decode, QueryDocumentBatchTokenizer, BatchTokenizer,\ + QueryDocumentBatch, LongBatchEncoder, SpecialTokensCleaner +from pygaggle.rerank import Reranker, Query, Text, SimilarityMatrixProvider + + +__all__ = ['T5Reranker', 'TransformerReranker'] + + +class T5Reranker(Reranker): + def __init__(self, model: T5ForConditionalGeneration, tokenizer: QueryDocumentBatchTokenizer): + self.model = model + self.tokenizer = tokenizer + self.device = next(self.model.parameters(), None).device + + def rerank(self, query: Query, texts: List[Text]) -> List[Text]: + texts = deepcopy(texts) + batch_input = QueryDocumentBatch(query=query, documents=texts) + for batch in self.tokenizer.traverse_query_document(batch_input): + input_ids = batch.output['input_ids'] + attn_mask = batch.output['attention_mask'] + _, batch_scores = greedy_decode(self.model, + input_ids.to(self.device), + length=2, + attention_mask=attn_mask.to(self.device), + return_last_logits=True) + + # 6136 and 1176 are the indexes of the tokens false and true in T5. + batch_scores = batch_scores[:, [6136, 1176]] + batch_log_probs = torch.nn.functional.log_softmax(batch_scores, dim=1) + batch_log_probs = batch_log_probs[:, 1].tolist() + for doc, score in zip(batch.documents, batch_log_probs): + doc.score = score + return texts + + +class TransformerReranker(Reranker): + methods = dict(max=lambda x: x.max().item(), + mean=lambda x: x.mean().item(), + absmean=lambda x: x.abs().mean().item(), + absmax=lambda x: x.abs().max().item()) + + def __init__(self, + model: PreTrainedModel, + tokenizer: BatchTokenizer, + sim_matrix_provider: SimilarityMatrixProvider, + method: str = 'max', + clean_special: bool = True): + assert method in self.methods, 'inappropriate scoring method' + self.model = model + self.tokenizer = tokenizer + self.encoder = LongBatchEncoder(model, tokenizer) + self.sim_matrix_provider = sim_matrix_provider + self.method = method + self.clean_special = clean_special + self.cleaner = SpecialTokensCleaner(tokenizer.tokenizer) + self.device = next(self.model.parameters(), None).device + + @torch.no_grad() + def rerank(self, query: Query, texts: List[Text]) -> List[Text]: + encoded_query = self.encoder.encode_single(query) + encoded_documents = self.encoder.encode(texts) + texts = deepcopy(texts) + for enc_doc, text in zip(encoded_documents, texts): + if self.clean_special: + enc_doc = self.cleaner.clean(enc_doc) + matrix = self.sim_matrix_provider.compute_matrix(encoded_query, enc_doc) + score = self.methods[self.method](matrix) + text.score = score + return texts diff --git a/pygaggle/run/__init__.py b/pygaggle/run/__init__.py new file mode 100644 index 00000000..43148b79 --- /dev/null +++ b/pygaggle/run/__init__.py @@ -0,0 +1 @@ +from .args import * diff --git a/pygaggle/run/args.py b/pygaggle/run/args.py new file mode 100644 index 00000000..fb2787ad --- /dev/null +++ b/pygaggle/run/args.py @@ -0,0 +1,41 @@ +from typing import Any, Callable, Optional, Sequence +import argparse + + +__all__ = ['ArgumentParserBuilder', 'opt'] + + +def _make_parser_setter(option, key): + def fn(value): + option.kwargs[key] = value + return option + return fn + + +class ArgumentParserOption: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def __iter__(self): + return iter((self.args, self.kwargs)) + + def __getattr__(self, item: str): + if item == 'kwargs': + return self.kwargs + if item == 'args': + 
return self.args + return _make_parser_setter(self, item) + + +opt = ArgumentParserOption + + +class ArgumentParserBuilder(object): + def __init__(self, **init_kwargs): + self.parser = argparse.ArgumentParser(**init_kwargs) + + def add_opts(self, *options): + for args, kwargs in options: + self.parser.add_argument(*args, **kwargs) + return self.parser diff --git a/pygaggle/run/evaluate_kaggle_highlighter.py b/pygaggle/run/evaluate_kaggle_highlighter.py new file mode 100644 index 00000000..52d6505d --- /dev/null +++ b/pygaggle/run/evaluate_kaggle_highlighter.py @@ -0,0 +1,95 @@ +from typing import Optional, List +from pathlib import Path +import logging + +from pydantic import BaseModel, validator +from transformers import AutoModel, AutoTokenizer +import torch + +from .args import ArgumentParserBuilder, opt +from pygaggle.rerank import TransformerReranker, InnerProductMatrixProvider, Reranker, T5Reranker, Bm25Reranker +from pygaggle.model import SimpleBatchTokenizer, CachedT5ModelLoader, T5BatchTokenizer, RerankerEvaluator, metric_names +from pygaggle.data import LitReviewDataset +from pygaggle.settings import Settings + + +SETTINGS = Settings() +METHOD_CHOICES = ('transformer', 'bm25', 't5') + + +class KaggleEvaluationOptions(BaseModel): + dataset: Path + method: str + batch_size: int + device: str + metrics: List[str] + model_name: Optional[str] + + @validator('dataset') + def dataset_exists(cls, v: Path): + assert v.exists(), 'dataset must exist' + return v + + @validator('model_name') + def model_name_sane(cls, v: Optional[str], values, **kwargs): + method = values['method'] + if method == 'transformer' and v is None: + raise ValueError('transformer name must be specified') + elif method == 't5': + return SETTINGS.t5_model_type + if v == 'biobert': + return 'monologg/biobert_v1.1_pubmed' + return v + + +def construct_t5(options: KaggleEvaluationOptions) -> Reranker: + loader = CachedT5ModelLoader(SETTINGS.t5_model_dir, + SETTINGS.cache_dir, + 'ranker', + SETTINGS.t5_model_type, + SETTINGS.flush_cache) + device = torch.device(options.device) + model = loader.load().to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(options.model_name) + tokenizer = T5BatchTokenizer(tokenizer, options.batch_size) + return T5Reranker(model, tokenizer) + + +def construct_transformer(options: KaggleEvaluationOptions) -> Reranker: + device = torch.device(options.device) + model = AutoModel.from_pretrained(options.model_name).to(device).eval() + tokenizer = SimpleBatchTokenizer(AutoTokenizer.from_pretrained(options.model_name), options.batch_size) + provider = InnerProductMatrixProvider() + return TransformerReranker(model, tokenizer, provider) + + +def construct_bm25(_: KaggleEvaluationOptions) -> Reranker: + return Bm25Reranker(index_path=SETTINGS.cord19_index_path) + + +def main(): + apb = ArgumentParserBuilder() + apb.add_opts(opt('--dataset', type=Path, default='data/kaggle-lit-review.json'), + opt('--method', required=True, type=str, choices=METHOD_CHOICES), + opt('--model-name', type=str), + opt('--batch-size', '-bsz', type=int, default=96), + opt('--device', type=str, default='cuda:0'), + opt('--metrics', type=str, nargs='+', default=metric_names(), choices=metric_names())) + args = apb.parser.parse_args() + + options = KaggleEvaluationOptions(**vars(args)) + ds = LitReviewDataset.from_file(str(options.dataset)) + examples = ds.to_senticized_dataset(SETTINGS.cord19_index_path) + construct_map = dict(transformer=construct_transformer, bm25=construct_bm25, t5=construct_t5) + reranker = 
construct_map[options.method](options) + evaluator = RerankerEvaluator(reranker, options.metrics) + width = max(map(len, args.metrics)) + 1 + stdout = [] + for metric in evaluator.evaluate(examples): + logging.info(f'{metric.name:<{width}}{metric.value:.5}') + stdout.append(f'{metric.name.title()}\t{metric.value}') + print('\n'.join(stdout)) + + +if __name__ == '__main__': + main() diff --git a/pygaggle/settings.py b/pygaggle/settings.py new file mode 100644 index 00000000..53c3e30b --- /dev/null +++ b/pygaggle/settings.py @@ -0,0 +1,17 @@ +from pathlib import Path +import os + +from pydantic import BaseSettings + + +class Settings(BaseSettings): + cord19_index_path: str = 'data/lucene-index-covid-paragraph' + + # T5 model settings + t5_model_dir: str = 'gs://neuralresearcher_data/covid/data/model_exp304' + t5_model_type: str = 't5-base' + t5_max_length: int = 512 + + # Cache settings + cache_dir: Path = Path(os.getenv('XDG_CACHE_HOME', str(Path.home() / '.cache'))) / 'covidex' + flush_cache: bool = False diff --git a/scripts/evaluate-highlighters.sh b/scripts/evaluate-highlighters.sh new file mode 100644 index 00000000..c7599d1b --- /dev/null +++ b/scripts/evaluate-highlighters.sh @@ -0,0 +1,7 @@ +mkdir -p results +python -um pygaggle.run.evaluate_kaggle_highlighter --method bm25 > results/bm25.log +python -um pygaggle.run.evaluate_kaggle_highlighter --method t5 > results/t5.log +python -um pygaggle.run.evaluate_kaggle_highlighter --method transformer --model-name biobert > results/biobert.log +python -um pygaggle.run.evaluate_kaggle_highlighter --method transformer --model-name allenai/scibert_scivocab_cased > results/scibert.log +python -um pygaggle.run.evaluate_kaggle_highlighter --method transformer --model-name bert-base-cased > results/bert.log +for name in results/*; do echo $name; cat $name; echo; done diff --git a/scripts/update-index.sh b/scripts/update-index.sh new file mode 100644 index 00000000..70a46562 --- /dev/null +++ b/scripts/update-index.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +echo "Updating Anserini index..." + +INDEX_NAME=lucene-index-covid-paragraph-2020-04-10 +INDEX_URL=https://www.dropbox.com/s/ivk87journyajw3/lucene-index-covid-paragraph-2020-04-10.tar.gz + +wget ${INDEX_URL} +tar xvfz ${INDEX_NAME}.tar.gz && rm ${INDEX_NAME}.tar.gz + +rm -rf data/lucene-index-covid-paragraph +mv ${INDEX_NAME} data/lucene-index-covid-paragraph + +echo "Successfully updated Anserini index at data/${INDEX_NAME}" diff --git a/tests/test_base.py b/tests/test_base.py index 5b143100..b85d49c3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -10,7 +10,7 @@ from pyserini.search import pysearch from pygaggle.rerank import to_texts, Text, Query, Reranker -from pygaggle.lib import IdentityReranker +from pygaggle.rerank import IdentityReranker class TestSearch(unittest.TestCase):