Cross-Attending to Cached Context for Efficient LLM Inference ([paper](https://aclanthology.org/2024.findings-emnlp.896/))
We use a conda environment for this project. Follow the steps below to set it up:
- Ensure you have python3 and conda installed on your system.
- Create and activate the conda environment:
conda create --yes --name xccache python=3.10
conda activate xccache
- Install the required Python packages:
python3 -m pip install -r requirements.txt
High-level principles:
- All imports are absolute; no relative imports.
- Specific functions/classes/etc. are next to where they are needed; general ones are closer to the root.
- A file's name should inform of its content.
- Group objects by theme, import dependencies, and how they are used.
- Some folders reflect the hierarchy of other folders; pieces of the same "model" may be at different places in the repo, but under the same folder name.
- The bulk of the documentation lives with the appropriate code. There should not be a need for markdown files within the xc_cache folder: use file-level docstrings instead. (Markdown files are acceptable in scripts/, especially if they document more than one script in the same folder.)
Specific folders:
- entrypoint.py: A minimalistic script that defers to scripts in scripts/.
- xc_cache/: Our code lives here. This is to be treated as a library; scripts go in scripts/.
- xc_cache/data/: Dataloaders, pre-processing operations, etc. (Not the data itself!)
- xc_cache/model/: Neural modules, etc. (Not training code nor saved weights!) Subfolders are used to group the main "chunks" of the project.
- xc_cache/train/: Trainers, model hyperparameters, etc. The subfolders here reflect those of xc_cache/model/.
- xc_cache/inference/: Inference-specific code, including a common interface to different language models.
- xc_cache/utils/: Repo-wide utilities. Specific utilities may live closer to the code using them.
- scripts/: Python scripts parsing options and calling actual code in xc_cache/.
- scripts/config/: General-purpose configuration files to be used by scripts. Specific config folders may appear deeper in the hierarchy.
We use Hugging Face to load our datasets. The training data is a mix of publicly available datasets. For details, refer to our paper, where we list the datasets used for training.
The training dataset is expected to include the following fields:
- sample_idx: Unique identifier for the sample.
- titles_list: List of titles associated with the sample.
- contexts_list: List of contexts for the sample.
- question: The question to be answered.
- answer: The ground truth answer.
- useful_contexts: A list of binary values (0 or 1) indicating which contexts in contexts_list are useful for answering the question.
- dataset: Identifier for the dataset source.

Note: The useful_contexts field is particularly important for datasets like HotpotQA, where contexts_list contains distractor contexts. This field helps specify which contexts are necessary for answering the question.
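For concreteness, here is a minimal sketch of what a record following this schema could look like. The field names come from the list above; the values and the use of datasets.Dataset.from_list are illustrative assumptions, not part of our pipeline.

```python
# Illustrative only: a hand-written record following the field layout above.
# The values are made up; only the field names come from this README.
from datasets import Dataset

example_record = {
    "sample_idx": "hotpotqa_000123",            # unique identifier for the sample
    "titles_list": ["Title A", "Title B"],      # one title per context
    "contexts_list": ["A useful passage ...", "A distractor passage ..."],
    "question": "What does the first passage describe?",
    "answer": "A short ground-truth answer.",
    "useful_contexts": [1, 0],                  # 1 = needed to answer, 0 = distractor
    "dataset": "hotpotqa",                      # identifier for the dataset source
}

# Wrap it into a Hugging Face dataset to inspect the resulting schema.
train_data = Dataset.from_list([example_record])
print(train_data.features)
```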
Our model evaluation scripts generate an inference dataset as part of the evaluation process. Each sample in the inference dataset is a copy of the corresponding sample from the test/validation set, augmented with an additional field:
- answer_pred: The predicted answer generated by the model.
The inference dataset helps assess model performance and analyze discrepancies between predictions and ground truth answers.
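As a quick sanity check, an inference dataset can be inspected as sketched below; the path is a placeholder and we assume the dataset was saved to disk as a single split.

```python
# Minimal sketch: eyeball predictions against ground truth in an inference dataset.
# The path is a placeholder; we assume the dataset loads as a single split.
from datasets import load_from_disk

inference_data = load_from_disk("path/to/inference_dataset")

for sample in inference_data.select(range(3)):
    print("Q:   ", sample["question"])
    print("GT:  ", sample["answer"])
    print("Pred:", sample["answer_pred"])
    print("-" * 40)
```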
- Make sure the Hugging Face variables are appropriately set; we use Weights & Biases to monitor the experiment logs. With HF_HUB_TOKEN_PATH pointing to a file containing your Hugging Face token, run:
HF_TOKEN=$(head -n 1 $HF_HUB_TOKEN_PATH)
huggingface-cli login
wandb login
- Select one of the experiments from EXP_GROUPS, and train its cross-attention model:
python3 entrypoint.py --module cross_attn_finetune --exp_group {$EXPERIMENT_NAME} --exp_id {$SOME_ID} --savedir_base {$SAVE_FOLDER} --data_path {$HF_TRAINING_DATA}
- The model checkpoints will be saved in {$SAVE_FOLDER}/{$SOME_ID}.
Evaluation applies a model to a given dataset and saves the generated answers for the samples in a new HF dataset (an inference dataset) that contains the original data plus the predicted answer. This process does not compute evaluation metrics; see the section on metric computation below for how to generate scores given an inference dataset.
python3 entrypoint.py --module qa_evaluation --model_ckpt {$CHECKPOINT} --dataset {$HF_TESTING_DATA} --model_output_path {$OUTPUT_FOLDER} --to_device cuda
python3 entrypoint.py --module qa_evaluation --model_ckpt {$CHECKPOINT} --dataset {$HF_TESTING_DATA} --model_output_path {$OUTPUT_FOLDER} --aggregate
The inference dataset will then be saved in {$OUTPUT_FOLDER}/{$HF_TESTING_DATA}-test.
To evaluate FiD, first try running it on a subset of 10 samples, as follows:
python3 entrypoint.py --module qa_evaluation --task_format=fid --subset_size 10 --model_path={$YOUR_MODEL_PATH} --dataset={$HF_TESTING_DATA} --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_output_path={$TEMP_OUTPUT_FOLDER} --max_new_tokens 5 --to_device cuda
To evaluate on the whole dataset, first run the model on the dataset shards, then aggregate the outputs:
deepspeed --num_gpus 1 entrypoint.py --module qa_evaluation --task_format fid --model_path {$YOUR_MODEL_PATH} --dataset_name=long_nq_dedup --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_max_length 200 --model_output_path fid-tmp --dataset_num_shards=100 --ds_config=./scripts/config/deepspeed_inference.json
deepspeed --num_gpus 1 entrypoint.py --module qa_evaluation --task_format fid --model_path {$YOUR_MODEL_PATH} --dataset_name=long_nq_dedup --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_max_length 200 --model_output_path fid-tmp --dataset_num_shards=100 --ds_config=./scripts/config/deepspeed_inference.json --aggregate
This code works with the official pretrained FiD models (base and large), available here; follow the corresponding instructions to download them.
Once you have an inference dataset generated and saved by applying a model to your test data, you may now use it to compute some metrics. We have two classes of metrics: performance and faithfulness metrics. As a quick start, run the following to compute the k-precision and rouge score for your dataset.
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --score_dir {$SCORE_SAVE_PATH} --rouge --kprecision
This will compute the rouge and k-precision scores for the given data and save the results in $SCORE_SAVE_PATH.
The scores are saved inside score_dir in a single JSON file whose name starts with eval_scores_{model_info}. If no score_dir is specified, the score JSON will be stored in the same directory as the dataset, i.e., in inf_dataset_path. If the --store_individual_scores option is set, the score of each sample for the given metric will also be saved inside a subdirectory named after the metric.
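As an illustration, the saved scores can be read back as sketched below. The glob pattern follows the eval_scores_{model_info} naming described above; the assumption that each file is a flat JSON dictionary of metric names to values is ours, so adapt the parsing to the actual file contents.

```python
# Sketch: read back saved score files. We assume (not guaranteed) that each
# eval_scores_* file is a JSON dictionary mapping metric names to values.
import glob
import json
import os

score_dir = "path/to/score_dir"  # placeholder for $SCORE_SAVE_PATH
for path in glob.glob(os.path.join(score_dir, "eval_scores_*")):
    with open(path) as f:
        scores = json.load(f)
    print(os.path.basename(path), "->", scores)
```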
To compute all possible metrics (both faithfulness metrics and inference metrics) EXCEPT the llm-based ones, you may run:
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --score_dir {$SCORE_SAVE_PATH} --all_metrics --all_faith_metrics
Equivalently, the switch --acl_metrics will compute all the metrics reported in our [EMNLP paper](https://aclanthology.org/2024.findings-emnlp.896/).
It is sometimes useful to postprocess the generated answers, e.g., to remove EOS characters. For this, use the answer_processing flag. Options are postproc_X or postproc_tulu2.
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --dataset_name nq --all_metrics --answer_processing postproc_X
It is possible to filter the inference dataset by dataset name, so that metrics are computed only on the given dataset. For this, use the dataset_name flag. The options are nq, hotpotqa, and topiocqa.
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --dataset_name nq
It is possible to select a subset of the data samples by specifying their indices. For this, save the list of indices in a text file (one per line), for example sample_list.txt, and use the dataset_filter_idx_file option to pass the path to that file. The text file will be loaded as an array and used to filter the data samples before metric computation.
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --dataset_name nq --dataset_filter_idx_file sample_list.txt
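For reference, such an index file can be produced with a few lines of Python; the file format (one integer index per line) is taken from the description above, while the indices themselves are arbitrary examples.

```python
# Sketch: write a filter file with one sample index per line, as expected by
# the dataset_filter_idx_file option. The indices below are arbitrary examples.
selected_indices = [0, 17, 42, 256]

with open("sample_list.txt", "w") as f:
    f.write("\n".join(str(i) for i in selected_indices) + "\n")
```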
For samples that have multiple annotations for the same question/context pair, it is possible to run metric computation by comparing the predicted answer to all possible ground-truth answers and taking the max (this is how NQ is evaluated in most papers). To run such an evaluation, use the best_score switch:
python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --dataset_name nq --best_score
The script will first load the dataset, group by question/context, and evaluate by taking the best score for each sample.
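To make the "best score" idea concrete, the sketch below scores a prediction against every annotated answer and keeps the maximum. Exact match is used only to keep the example short; the actual metrics (rouge, k-precision, etc.) are the ones computed by qa_compute_metrics.

```python
# Illustrative sketch of max-over-references scoring; not the actual implementation.
def normalize(text: str) -> str:
    return " ".join(text.lower().split())

def best_exact_match(prediction: str, references: list[str]) -> float:
    # Score the prediction against each ground-truth answer and keep the best.
    return max(float(normalize(prediction) == normalize(ref)) for ref in references)

print(best_exact_match("Paris", ["paris", "Paris, France"]))  # prints 1.0
```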