Skip to content

Commit

Permalink
Implement highlighting rerankers (#1)
Browse files Browse the repository at this point in the history
* Implement highlighting rerankers

- Add T5, transformer, and BM25 rerankers
- Add Kaggle dataset and evaluation framework

* Fix README instructions

- Add missing activation command

* Fix BM25 bug

- IDF not computed correctly

* Improve IDF computation for BM25 reranker

- Add option to compute IDF statistics from corpus

* Add LongBatchEncoder documentation
  • Loading branch information
daemon authored Apr 21, 2020
1 parent 4e5df7c commit b2ab0c9
Show file tree
Hide file tree
Showing 28 changed files with 1,486 additions and 40 deletions.
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
# pygaggle
# PyGaggle

A gaggle of CORD-19 rerankers.

## Installation

1. `conda env create -f environment.yml && conda activate pygaggle`

2. Install [PyTorch 1.4+](http://pytorch.org/).

3. Download the index: `sh scripts/update-index.sh`

4. Make sure you have an installation of Java 8+: `javac --version`


## Evaluating Highlighters

Run `sh scripts/evaluate-highlighters.sh`.
459 changes: 459 additions & 0 deletions data/kaggle-lit-review.json

Large diffs are not rendered by default.

108 changes: 108 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
name: pygaggle
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1
- blas=1.0
- ca-certificates=2020.1.1
- certifi=2020.4.5.1
- freetype=2.9.1
- intel-openmp=2020.0
- jpeg=9b
- libedit=3.1.20181209
- libffi=3.2.1
- libgcc-ng=9.1.0
- libgfortran-ng=7.3.0
- libpng=1.6.37
- libstdcxx-ng=9.1.0
- libtiff=4.1.0
- mkl=2020.0
- mkl-service=2.3.0
- mkl_fft=1.0.15
- mkl_random=1.1.0
- ncurses=6.2
- ninja=1.9.0
- numpy-base=1.18.1
- olefile=0.46
- openssl=1.1.1f
- pillow=7.0.0
- pip=20.0.2
- python=3.7.7
- readline=8.0
- setuptools=46.1.3
- six=1.14.0
- sqlite=3.31.1
- tk=8.6.8
- wheel=0.34.2
- xz=5.2.5
- zlib=1.2.11
- zstd=1.3.7
- pip:
- absl-py==0.9.0
- astor==0.8.1
- blis==0.4.1
- boto3==1.12.41
- botocore==1.15.41
- cachetools==4.1.0
- catalogue==1.0.0
- chardet==3.0.4
- click==7.1.1
- coloredlogs==14.0
- cymem==2.0.3
- cython==0.29.16
- docutils==0.15.2
- filelock==3.0.12
- gast==0.2.2
- google-auth==1.14.0
- google-auth-oauthlib==0.4.1
- google-pasta==0.2.0
- grpcio==1.28.1
- h5py==2.10.0
- humanfriendly==8.2
- idna==2.9
- importlib-metadata==1.6.0
- jmespath==0.9.5
- joblib==0.14.1
- keras-applications==1.0.8
- keras-preprocessing==1.1.0
- markdown==3.2.1
- murmurhash==1.0.2
- numpy==1.18.2
- oauthlib==3.1.0
- opt-einsum==3.2.1
- plac==1.1.3
- preshed==3.0.2
- protobuf==3.11.3
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydantic==1.5
- pyjnius==1.2.1
- pyserini==0.9.0.0
- python-dateutil==2.8.1
- regex==2020.4.4
- requests==2.23.0
- requests-oauthlib==1.3.0
- rsa==4.0
- s3transfer==0.3.3
- sacremoses==0.0.41
- scikit-learn==0.22.2.post1
- scipy==1.4.1
- sentencepiece==0.1.85
- sklearn==0.0
- spacy==2.2.4
- srsly==1.0.2
- tensorboard==2.1.1
- tensorflow==2.1.0
- tensorflow-estimator==2.1.0
- tensorflow-gpu==2.1.0
- tensorflow-text==2.1.1
- termcolor==1.1.0
- thinc==7.4.0
- tokenizers==0.5.2
- tqdm==4.45.0
- transformers==2.7.0
- urllib3==1.25.9
- wasabi==0.6.0
- werkzeug==1.0.1
- wrapt==1.12.1
- zipp==3.1.0
2 changes: 1 addition & 1 deletion pygaggle/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@

from .logger import *
2 changes: 2 additions & 0 deletions pygaggle/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .kaggle import *
from .relevance import *
69 changes: 69 additions & 0 deletions pygaggle/data/kaggle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections import OrderedDict
from typing import List
import json
import logging

from pydantic import BaseModel

from .relevance import RelevanceExample, LuceneDocumentLoader
from pygaggle.model.tokenize import SpacySenticizer
from pygaggle.rerank import Query, Text


__all__ = ['MISSING_ID', 'LitReviewCategory', 'LitReviewAnswer', 'LitReviewDataset', 'LitReviewSubcategory']


MISSING_ID = '<missing>'


class LitReviewAnswer(BaseModel):
id: str
title: str
exact_answer: str


class LitReviewSubcategory(BaseModel):
name: str
answers: List[LitReviewAnswer]


class LitReviewCategory(BaseModel):
name: str
sub_categories: List[LitReviewSubcategory]


class LitReviewDataset(BaseModel):
categories: List[LitReviewCategory]

@classmethod
def from_file(cls, filename: str) -> 'LitReviewDataset':
with open(filename) as f:
return cls(**json.load(f))

@property
def query_answer_pairs(self):
return ((subcat.name, ans) for cat in self.categories
for subcat in cat.sub_categories
for ans in subcat.answers)

def to_senticized_dataset(self, index_path: str) -> List[RelevanceExample]:
loader = LuceneDocumentLoader(index_path)
tokenizer = SpacySenticizer()
example_map = OrderedDict()
rel_map = OrderedDict()
for query, document in self.query_answer_pairs:
if document.id == MISSING_ID:
logging.warning(f'Skipping {document.title} (missing ID)')
continue
key = (query, document.id)
example_map.setdefault(key, tokenizer(loader.load_document(document.id)))
sents = example_map[key]
rel_map.setdefault(key, [False] * len(sents))
for idx, s in enumerate(sents):
if document.exact_answer in s:
rel_map[key][idx] = True
for (_, doc_id), rels in rel_map.items():
if not any(rels):
logging.warning(f'{doc_id} has no relevant answers')
return [RelevanceExample(Query(query), list(map(lambda s: Text(s, dict(docid=docid)), sents)), rels)
for ((query, docid), sents), (_, rels) in zip(example_map.items(), rel_map.items())]
34 changes: 34 additions & 0 deletions pygaggle/data/relevance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from functools import lru_cache
from itertools import chain
from typing import List
import json
import re

from pyserini.search import pysearch

from pygaggle.rerank import Query, Text


__all__ = ['RelevanceExample', 'LuceneDocumentLoader']


@dataclass
class RelevanceExample:
query: Query
documents: List[Text]
labels: List[bool]


class LuceneDocumentLoader:
double_space_pattern = re.compile(r'\s\s+')

def __init__(self, index_path: str):
self.searcher = pysearch.SimpleSearcher(index_path)

@lru_cache(maxsize=1024)
def load_document(self, id: str) -> str:
article = json.loads(self.searcher.doc(id).lucene_document().get('raw'))
ref_entries = article['ref_entries'].values()
text = '\n'.join(x['text'] for x in chain(article['abstract'], article['body_text'], ref_entries))
return text
14 changes: 0 additions & 14 deletions pygaggle/lib/IdentityReranker.py

This file was deleted.

9 changes: 0 additions & 9 deletions pygaggle/lib/__init__.py

This file was deleted.

7 changes: 7 additions & 0 deletions pygaggle/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import coloredlogs


__all__ = []


coloredlogs.install(level='INFO', fmt='%(asctime)s [%(levelname)s] %(module)s: %(message)s')
5 changes: 5 additions & 0 deletions pygaggle/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .decode import *
from .encode import *
from .evaluate import *
from .serialize import *
from .tokenize import *
29 changes: 29 additions & 0 deletions pygaggle/model/decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Union, Tuple

from transformers import PreTrainedModel
import torch


__all__ = ['greedy_decode']


@torch.no_grad()
def greedy_decode(model: PreTrainedModel,
input_ids: torch.Tensor,
length: int,
attention_mask: torch.Tensor = None,
return_last_logits: bool = True) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
decode_ids = torch.full((input_ids.size(0), 1),
model.config.decoder_start_token_id,
dtype=torch.long).to(input_ids.device)
past = model.get_encoder()(input_ids, attention_mask=attention_mask)
next_token_logits = None
for _ in range(length):
model_inputs = model.prepare_inputs_for_generation(decode_ids, past=past, attention_mask=attention_mask)
outputs = model(**model_inputs) # (batch_size, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size, vocab_size)
decode_ids = torch.cat([decode_ids, next_token_logits.max(1)[1].unsqueeze(-1)], dim=-1)
past = outputs[1]
if return_last_logits:
return decode_ids, next_token_logits
return decode_ids
Loading

0 comments on commit b2ab0c9

Please sign in to comment.