
XC-Cache

Cross-Attending to Cached Context for Efficient LLM Inference (paper: https://aclanthology.org/2024.findings-emnlp.896/)

Set up the environment

We use a conda environment for this project. Follow the steps below to set it up:

Prerequisites

Ensure you have python3 and conda installed on your system.

Installation Steps

  1. Create and activate the conda environment:

    conda create --yes --name xccache python=3.10
    conda activate xccache
  2. Install the required Python packages:

    python3 -m pip install -r requirements.txt

Repository organization

High-level principles:

  • All imports are absolute; no relative imports.
  • Specific functions/classes/etc. are next to where they are needed; general ones are closer to the root.
  • A file's name should indicate its content.
  • Group objects by theme, import dependencies and how they are used.
  • Some folders reflect the hierarchy of other folders; pieces of the same "model" may be at different places in the repo, but under the same folder name.
  • The bulk of the documentation lives with the appropriate code. There should not be a need for Markdown files within the xc_cache folder: use file-level docstrings instead. (Markdown files are acceptable in scripts/, especially if they document more than one script in the same folder.)

Specific folders:

  • entrypoint.py: A minimalistic script that defers to scripts in scripts/.
  • xc_cache/: Our code lives here. This is to be treated as a library; scripts go in scripts/.
    • xc_cache/data/: Dataloader, pre-processing operations, etc. (Not the data itself!)
    • xc_cache/model/: Neural modules, etc. (Not training code nor saved weights!) Subfolders are used to group main "chunks" of the project.
    • xc_cache/train/: Trainers, model hyperparameters, etc. The subfolders here reflect those of xc_cache/model/.
    • xc_cache/inference/: Inference-specific code, including common interface to different language models.
    • xc_cache/utils/: Repo-wide utilities. Specific utilities may live closer to the code using them.
  • scripts/: Python scripts parsing options and calling actual code in xc_cache/.
    • scripts/config: General purpose configuration files to be used by scripts. Specific config folders may appear deeper in the hierarchy.

Datasets

We use Hugging Face to load our datasets. The training data is a mix of publicly available datasets. For details, refer to our paper, where we list the datasets used for training.

Training Dataset Format

The training dataset is expected to include the following fields:

  • sample_idx: Unique identifier for the sample.
  • titles_list: List of titles associated with the sample.
  • contexts_list: List of contexts for the sample.
  • question: The question to be answered.
  • answer: The ground truth answer.
  • useful_contexts: A list of binary values (0 or 1) indicating which contexts in contexts_list are useful for answering the question.
  • dataset: Identifier for the dataset source.

Note:
The useful_contexts field is particularly important for datasets like HotpotQA, where contexts_list contains distractor contexts. This field helps specify which contexts are necessary for answering the question.
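
For illustration, a single training sample might look like the following minimal sketch. Only the field names come from this README; all values (including the nq identifier) are made up:

    # A hypothetical training sample illustrating the expected fields.
    # Field names come from this README; all values are made up.
    sample = {
        "sample_idx": "nq-000042",
        "titles_list": ["France", "Berlin"],
        "contexts_list": [
            "France is a country in Western Europe. Its capital is Paris.",
            "Berlin is the capital and largest city of Germany.",
        ],
        "question": "What is the capital of France?",
        "answer": "Paris",
        # One flag per context: only the first context is needed here,
        # the second one is a distractor (as in HotpotQA).
        "useful_contexts": [1, 0],
        "dataset": "nq",
    }

    assert len(sample["useful_contexts"]) == len(sample["contexts_list"])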

Inference Dataset Format

Our model evaluation scripts generate an inference dataset as part of the evaluation process. Each sample in the inference dataset is a copy of the corresponding sample from the test/validation set, augmented with an additional field:

  • answer_pred: The predicted answer generated by the model.

The inference dataset helps assess model performance and analyze discrepancies between predictions and ground truth answers.
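
As a minimal sketch, assuming the inference dataset is saved in Hugging Face datasets format (the path below is hypothetical), predictions can be inspected against the ground truth like this:

    # Minimal sketch (assumes the inference dataset was saved with Hugging Face
    # datasets' save_to_disk; the path below is hypothetical).
    from datasets import load_from_disk

    inference_ds = load_from_disk("output/my_inference_dataset")

    for sample in inference_ds.select(range(3)):
        print("question:    ", sample["question"])
        print("ground truth:", sample["answer"])
        print("prediction:  ", sample["answer_pred"])
        print("-" * 40)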

XC-Cache training instructions

  1. Make sure the Hugging Face variables are appropriately set (HF_HUB_TOKEN_PATH should point to a file containing your Hugging Face token); we use Weights & Biases to monitor the experiment logs:

    HF_TOKEN=$(head -n 1 $HF_HUB_TOKEN_PATH)
    huggingface-cli login
    wandb login

  2. Select one of the experiments from EXP_GROUPS, and train its cross-attention model:

    python3 entrypoint.py --module cross_attn_finetune --exp_group {$EXPERIMENT_NAME} --exp_id {$SOME_ID} --savedir_base {$SAVE_FOLDER} --data_path {$HF_TRAINING_DATA}

  3. The model checkpoints will be saved in {$SAVE_FOLDER}/{$SOME_ID}.

Evaluation

Evaluation applies a model to a given dataset and saves the generated answers in a new HF dataset (an inference dataset) that contains the original data plus the predicted answers. This process does not compute evaluation metrics; see the Metric Computation section below for how to generate scores from an inference dataset.

python3 entrypoint.py --module qa_evaluation --model_ckpt {$CHECKPOINT} --dataset {$HF_TESTING_DATA} --model_output_path {$OUTPUT_FOLDER} --to_device cuda
python3 entrypoint.py --module qa_evaluation --model_ckpt {$CHECKPOINT} --dataset {$HF_TESTING_DATA} --model_output_path {$OUTPUT_FOLDER} --aggregate

The inference dataset will then be saved in {$OUTPUT_FOLDER}/{$HF_TESTING_DATA}-test.

Baselines

To evaluate FiD, first try running it on a subset of 10 samples, as follows:

python3 entrypoint.py --module qa_evaluation --task_format=fid --subset_size 10 --model_path={$YOUR_MODEL_PATH} --dataset={$HF_TESTING_DATA} --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_output_path={$TEMP_OUTPUT_FOLDER} --max_new_tokens 5 --to_device cuda

To evaluate on the whole dataset, first run the model on the dataset shards, then aggregate the outputs:

deepspeed --num_gpus 1 entrypoint.py --module qa_evaluation --task_format fid --model_path {$YOUR_MODEL_PATH} --dataset_name=long_nq_dedup --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_max_length 200 --model_output_path fid-tmp --dataset_num_shards=100 --ds_config=./scripts/config/deepspeed_inference.json
deepspeed --num_gpus 1 entrypoint.py --module qa_evaluation --task_format fid --model_path {$YOUR_MODEL_PATH} --dataset_name=long_nq_dedup --dataset_split=test_answered --task_answer=newline --task_context=only_gold_long --post_cleanup=sept2dec --model_max_length 200 --model_output_path fid-tmp --dataset_num_shards=100 --ds_config=./scripts/config/deepspeed_inference.json --aggregate

This code works with the official pretrained FiD models, base and large; follow the instructions in the official FiD repository to download them.

Metric Computation

Once you have generated and saved an inference dataset by applying a model to your test data, you can use it to compute metrics. We have two classes of metrics: performance metrics and faithfulness metrics. As a quick start, run the following to compute the k-precision and ROUGE scores for your dataset.

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH}  --score_dir {$SCORE_SAVE_PATH}  --rouge --kprecision

This will compute the ROUGE and k-precision scores for the given data and save the results in $SCORE_SAVE_PATH.

Where are the scores stored?

The scores are saved inside score_dir in a single JSON file whose name starts with eval_scores_{model_info}. If no score_dir is specified, the score JSON is stored in the same directory as the dataset, i.e., in inf_dataset_path. If the --store_individual_scores option is set, the score of each sample for a given metric is also saved inside a subdirectory named after that metric.
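
As an illustrative sketch, the aggregated scores can then be read back like this; the directory and the eval_scores_*.json glob pattern are assumptions based on the naming convention above:

    # Sketch: read the aggregated scores written by qa_compute_metrics.
    # The directory and the eval_scores_* pattern are assumptions based on the
    # naming convention described above.
    import glob
    import json

    score_dir = "scores"  # whatever was passed as --score_dir
    for path in glob.glob(f"{score_dir}/eval_scores_*.json"):
        with open(path) as f:
            print(path, json.load(f))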

Compute all metrics

To compute all possible metrics (both performance and faithfulness metrics) EXCEPT the LLM-based ones, you may run:

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --score_dir {$SCORE_SAVE_PATH}  --all_metrics --all_faith_metrics

Equivalently, the switch --acl_metrics will compute all the metrics reported in our [EMNLP paper](https://aclanthology.org/2024.findings-emnlp.896/).

Postprocess the answers

It is sometimes useful to postprocess the generated answers, e.g., to remove EOS tokens. For this, use the --answer_processing flag. Options are postproc_X or postproc_tulu2; a rough illustration of this kind of post-processing is sketched after the command below.

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH}  --dataset_name nq --all_metrics --answer_processing postproc_X
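
As a rough illustration only (the EOS markers below are assumptions, and the actual postproc_X / postproc_tulu2 behavior is defined in the repository code), this kind of post-processing amounts to something like:

    # Illustrative sketch only: strip EOS markers and surrounding whitespace
    # from a generated answer. The markers below are assumptions; the actual
    # postproc_X / postproc_tulu2 behavior is defined in the repository code.
    def postprocess_answer(text: str, eos_markers=("</s>", "<|endoftext|>")) -> str:
        for marker in eos_markers:
            text = text.replace(marker, "")
        return text.strip()

    print(postprocess_answer("Paris.</s>"))  # -> Paris.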

Filter samples by dataset

It is possible to filter the inference dataset by dataset name and compute metrics on only that dataset. For this, use the --dataset_name flag. The options are nq, hotpotqa, and topiocqa.

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH}  --dataset_name nq 

Filter samples by index

It is possible to select a subset of the data samples by specifying their indices. To do so, save the list of indices in a text file (one index per line), for example sample_list.txt, and pass its path via the --dataset_filter_idx_file option. The file will be loaded as an array and used to filter the data samples before metric computation. An example of producing such a file is shown after the command below.

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH}  --dataset_name nq --dataset_filter_idx_file sample_list.txt
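
For example, the index file can be produced with a few lines of Python; the indices below are arbitrary:

    # Write the indices of the samples to keep, one per line, in the format
    # expected by --dataset_filter_idx_file. The indices are arbitrary examples.
    indices = [0, 17, 42, 256]

    with open("sample_list.txt", "w") as f:
        f.writelines(f"{idx}\n" for idx in indices)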

Best-score evaluation

For samples that have multiple annotations for the same question/context pair, it is possible to run metric computation by comparing the predicted answer to all possible ground-truth answers and taking the maximum score (this is how NQ is evaluated in most papers). To run such an evaluation, use the --best_score switch; a conceptual sketch of this procedure follows below.

python3 entrypoint.py --module qa_compute_metrics --inf_dataset_path {$YOUR_INFERENCE_DATASET_PATH} --dataset_name nq --best_score

The script will first load the dataset, group by question/context, and evaluate by taking the best score for each sample.
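
Conceptually (a sketch of the idea, not the repository's implementation), best-score evaluation amounts to scoring the prediction against every ground-truth answer in a group and keeping the maximum; exact match stands in for any metric here:

    # Conceptual sketch, not the repository's implementation: score the
    # prediction against every ground-truth answer in a question/context group
    # and keep the maximum. exact_match is a stand-in for any metric.
    def exact_match(pred: str, ref: str) -> float:
        return float(pred.strip().lower() == ref.strip().lower())

    def best_score(pred: str, references: list[str]) -> float:
        return max(exact_match(pred, ref) for ref in references)

    print(best_score("Paris", ["paris", "Paris, France"]))  # -> 1.0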
