Implement highlighting rerankers (#1)
* Implement highlighting rerankers
  - Add T5, transformer, and BM25 rerankers
  - Add Kaggle dataset and evaluation framework
* Fix README instructions
  - Add missing activation command
* Fix BM25 bug
  - IDF not computed correctly
* Improve IDF computation for BM25 reranker
  - Add option to compute IDF statistics from corpus
* Add LongBatchEncoder documentation
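The BM25 reranker's own diff is among those not rendered below. As background for the IDF fixes noted above, here is a minimal sketch of computing IDF statistics from a corpus; it illustrates the standard smoothed BM25 formula and is not the commit's code, and `compute_idf` is a hypothetical helper name.

```python
import math
from collections import Counter
from typing import Dict, Iterable, List


def compute_idf(corpus: Iterable[List[str]]) -> Dict[str, float]:
    """Smoothed BM25 IDF from corpus document frequencies (hypothetical sketch)."""
    df = Counter()
    num_docs = 0
    for terms in corpus:
        num_docs += 1
        df.update(set(terms))  # count each term at most once per document
    return {term: math.log(1 + (num_docs - n + 0.5) / (n + 0.5))
            for term, n in df.items()}
```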
Showing 28 changed files with 1,486 additions and 40 deletions.
`README.md`:

```diff
@@ -1 +1,18 @@
-# pygaggle
+# PyGaggle
+
+A gaggle of CORD-19 rerankers.
+
+## Installation
+
+1. `conda env create -f environment.yml && conda activate pygaggle`
+
+2. Install [PyTorch 1.4+](http://pytorch.org/).
+
+3. Download the index: `sh scripts/update-index.sh`
+
+4. Make sure you have an installation of Java 8+: `javac --version`
+
+
+## Evaluating Highlighters
+
+Run `sh scripts/evaluate-highlighters.sh`.
```
Large diffs are not rendered by default.
`environment.yml` (new file):

```yaml
name: pygaggle
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1
  - blas=1.0
  - ca-certificates=2020.1.1
  - certifi=2020.4.5.1
  - freetype=2.9.1
  - intel-openmp=2020.0
  - jpeg=9b
  - libedit=3.1.20181209
  - libffi=3.2.1
  - libgcc-ng=9.1.0
  - libgfortran-ng=7.3.0
  - libpng=1.6.37
  - libstdcxx-ng=9.1.0
  - libtiff=4.1.0
  - mkl=2020.0
  - mkl-service=2.3.0
  - mkl_fft=1.0.15
  - mkl_random=1.1.0
  - ncurses=6.2
  - ninja=1.9.0
  - numpy-base=1.18.1
  - olefile=0.46
  - openssl=1.1.1f
  - pillow=7.0.0
  - pip=20.0.2
  - python=3.7.7
  - readline=8.0
  - setuptools=46.1.3
  - six=1.14.0
  - sqlite=3.31.1
  - tk=8.6.8
  - wheel=0.34.2
  - xz=5.2.5
  - zlib=1.2.11
  - zstd=1.3.7
  - pip:
    - absl-py==0.9.0
    - astor==0.8.1
    - blis==0.4.1
    - boto3==1.12.41
    - botocore==1.15.41
    - cachetools==4.1.0
    - catalogue==1.0.0
    - chardet==3.0.4
    - click==7.1.1
    - coloredlogs==14.0
    - cymem==2.0.3
    - cython==0.29.16
    - docutils==0.15.2
    - filelock==3.0.12
    - gast==0.2.2
    - google-auth==1.14.0
    - google-auth-oauthlib==0.4.1
    - google-pasta==0.2.0
    - grpcio==1.28.1
    - h5py==2.10.0
    - humanfriendly==8.2
    - idna==2.9
    - importlib-metadata==1.6.0
    - jmespath==0.9.5
    - joblib==0.14.1
    - keras-applications==1.0.8
    - keras-preprocessing==1.1.0
    - markdown==3.2.1
    - murmurhash==1.0.2
    - numpy==1.18.2
    - oauthlib==3.1.0
    - opt-einsum==3.2.1
    - plac==1.1.3
    - preshed==3.0.2
    - protobuf==3.11.3
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pydantic==1.5
    - pyjnius==1.2.1
    - pyserini==0.9.0.0
    - python-dateutil==2.8.1
    - regex==2020.4.4
    - requests==2.23.0
    - requests-oauthlib==1.3.0
    - rsa==4.0
    - s3transfer==0.3.3
    - sacremoses==0.0.41
    - scikit-learn==0.22.2.post1
    - scipy==1.4.1
    - sentencepiece==0.1.85
    - sklearn==0.0
    - spacy==2.2.4
    - srsly==1.0.2
    - tensorboard==2.1.1
    - tensorflow==2.1.0
    - tensorflow-estimator==2.1.0
    - tensorflow-gpu==2.1.0
    - tensorflow-text==2.1.1
    - termcolor==1.1.0
    - thinc==7.4.0
    - tokenizers==0.5.2
    - tqdm==4.45.0
    - transformers==2.7.0
    - urllib3==1.25.9
    - wasabi==0.6.0
    - werkzeug==1.0.1
    - wrapt==1.12.1
    - zipp==3.1.0
```
`pygaggle/__init__.py`:

```python
from .logger import *
```
`pygaggle/data/__init__.py` (new file):

```python
from .kaggle import *
from .relevance import *
```
`pygaggle/data/kaggle.py` (new file):

```python
from collections import OrderedDict
from typing import List
import json
import logging

from pydantic import BaseModel

from .relevance import RelevanceExample, LuceneDocumentLoader
from pygaggle.model.tokenize import SpacySenticizer
from pygaggle.rerank import Query, Text


__all__ = ['MISSING_ID', 'LitReviewCategory', 'LitReviewAnswer', 'LitReviewDataset', 'LitReviewSubcategory']


MISSING_ID = '<missing>'


class LitReviewAnswer(BaseModel):
    id: str
    title: str
    exact_answer: str


class LitReviewSubcategory(BaseModel):
    name: str
    answers: List[LitReviewAnswer]


class LitReviewCategory(BaseModel):
    name: str
    sub_categories: List[LitReviewSubcategory]


class LitReviewDataset(BaseModel):
    categories: List[LitReviewCategory]

    @classmethod
    def from_file(cls, filename: str) -> 'LitReviewDataset':
        with open(filename) as f:
            return cls(**json.load(f))

    @property
    def query_answer_pairs(self):
        return ((subcat.name, ans) for cat in self.categories
                for subcat in cat.sub_categories
                for ans in subcat.answers)

    def to_senticized_dataset(self, index_path: str) -> List[RelevanceExample]:
        """Split each answer document into sentences, labeling a sentence
        relevant if it contains the exact answer string."""
        loader = LuceneDocumentLoader(index_path)
        tokenizer = SpacySenticizer()
        example_map = OrderedDict()  # (query, docid) -> sentences
        rel_map = OrderedDict()  # (query, docid) -> per-sentence relevance labels
        for query, document in self.query_answer_pairs:
            if document.id == MISSING_ID:
                logging.warning(f'Skipping {document.title} (missing ID)')
                continue
            key = (query, document.id)
            example_map.setdefault(key, tokenizer(loader.load_document(document.id)))
            sents = example_map[key]
            rel_map.setdefault(key, [False] * len(sents))
            for idx, s in enumerate(sents):
                if document.exact_answer in s:
                    rel_map[key][idx] = True
        for (_, doc_id), rels in rel_map.items():
            if not any(rels):
                logging.warning(f'{doc_id} has no relevant answers')
        return [RelevanceExample(Query(query), list(map(lambda s: Text(s, dict(docid=docid)), sents)), rels)
                for ((query, docid), sents), (_, rels) in zip(example_map.items(), rel_map.items())]
```
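A minimal usage sketch for `LitReviewDataset`; the JSON path and index directory below are hypothetical placeholders, not files shipped by this commit.

```python
from pygaggle.data import LitReviewDataset

# Hypothetical paths; substitute your Kaggle lit-review JSON and Lucene index.
dataset = LitReviewDataset.from_file('data/kaggle-lit-review.json')
examples = dataset.to_senticized_dataset('indexes/cord19-index')
for ex in examples[:3]:
    print(ex.query, f'{sum(ex.labels)}/{len(ex.labels)} relevant sentences')
```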
`pygaggle/data/relevance.py` (new file):

```python
from dataclasses import dataclass
from functools import lru_cache
from itertools import chain
from typing import List
import json
import re

from pyserini.search import pysearch

from pygaggle.rerank import Query, Text


__all__ = ['RelevanceExample', 'LuceneDocumentLoader']


@dataclass
class RelevanceExample:
    query: Query
    documents: List[Text]
    labels: List[bool]


class LuceneDocumentLoader:
    double_space_pattern = re.compile(r'\s\s+')  # whitespace-normalization pattern (not yet used here)

    def __init__(self, index_path: str):
        self.searcher = pysearch.SimpleSearcher(index_path)

    @lru_cache(maxsize=1024)
    def load_document(self, id: str) -> str:
        # Reconstruct the full article text from the raw JSON stored in the index.
        article = json.loads(self.searcher.doc(id).lucene_document().get('raw'))
        ref_entries = article['ref_entries'].values()
        text = '\n'.join(x['text'] for x in chain(article['abstract'], article['body_text'], ref_entries))
        return text
```
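A usage sketch for `LuceneDocumentLoader`; the index path and document id are placeholders. Because `load_document` is wrapped in `lru_cache`, repeated loads of the same id hit the cache rather than Lucene.

```python
from pygaggle.data import LuceneDocumentLoader

loader = LuceneDocumentLoader('indexes/cord19-index')  # hypothetical path
doc_id = '...'  # a CORD-19 document id present in the index
text = loader.load_document(doc_id)
print(text[:200])
```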
Two files were deleted in this commit (diffs not shown).
`pygaggle/logger.py` (new file):

```python
import coloredlogs


__all__ = []


coloredlogs.install(level='INFO', fmt='%(asctime)s [%(levelname)s] %(module)s: %(message)s')
```
`pygaggle/model/__init__.py` (new file):

```python
from .decode import *
from .encode import *
from .evaluate import *
from .serialize import *
from .tokenize import *
```
`pygaggle/model/decode.py` (new file):

```python
from typing import Union, Tuple

from transformers import PreTrainedModel
import torch


__all__ = ['greedy_decode']


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
                  input_ids: torch.Tensor,
                  length: int,
                  attention_mask: torch.Tensor = None,
                  return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    decode_ids = torch.full((input_ids.size(0), 1),
                            model.config.decoder_start_token_id,
                            dtype=torch.long).to(input_ids.device)
    past = model.get_encoder()(input_ids, attention_mask=attention_mask)
    next_token_logits = None
    for _ in range(length):
        model_inputs = model.prepare_inputs_for_generation(decode_ids, past=past, attention_mask=attention_mask)
        outputs = model(**model_inputs)  # (batch_size, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size, vocab_size)
        decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1)
        past = outputs[1]
    if return_last_logits:
        return decode_ids, next_token_logits
    return decode_ids
```
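A sketch of calling `greedy_decode` in the style of a T5 reranker, which decodes one step and reads relevance off the final logits. The checkpoint name and input text are placeholders, and the exact T5 class name in the pinned transformers==2.7.0 may differ from the `T5ForConditionalGeneration` assumed here.

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration  # class name assumed

tokenizer = T5Tokenizer.from_pretrained('t5-base')  # placeholder checkpoint
model = T5ForConditionalGeneration.from_pretrained('t5-base')
batch = tokenizer.batch_encode_plus(['Query: q Document: d Relevant:'],
                                    return_tensors='pt')
# Decode a single step; the returned logits can be sliced at the "true"/"false"
# token ids to produce a relevance score.
decode_ids, last_logits = greedy_decode(model, batch['input_ids'], length=1)
```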