-
Notifications
You must be signed in to change notification settings - Fork 711
[feat] support fine-tuning of reranker models #4671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5fdb749
support reranker training
0russwest0 f284d8b
Merge branch 'main' into support_reranker
0russwest0 0bf5d9d
Enhance Reranker support with new training scripts and loss functions…
0russwest0 2708fcb
Merge branch 'main' into support_reranker
0russwest0 692bd21
pre-commit fix
0russwest0 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Reranker训练 | ||
|
||
SWIFT已经支持Reranker模型的训练,目前已经支持的模型有: | ||
|
||
1. modernbert reranker模型 | ||
- [ModelScope](https://www.modelscope.cn/models/iic/gte-reranker-modernbert-base) [Hugging Face](https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base) | ||
2. qwen3-reranker模型 | ||
- 0.6B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-0.6B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B) | ||
- 4B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-4B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-4B) | ||
- 8B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-8B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-8B) | ||
|
||
## 实现方式 | ||
|
||
目前SWIFT支持两种Reranker模型的实现方式,二者在架构和损失函数计算上有显著差异: | ||
|
||
### 1. 分类式Reranker(Classification Reranker) | ||
|
||
**适用模型:** modernbert reranker模型(如gte-reranker-modernbert-base) | ||
|
||
**核心原理:** | ||
- 基于序列分类架构,在预训练模型基础上添加分类头 | ||
- 输入:query-document对,输出:单个相关性分数 | ||
|
||
|
||
### 2. 生成式Reranker(Generative Reranker) | ||
|
||
**适用模型:** qwen3-reranker模型(0.6B/4B/8B) | ||
|
||
**核心原理:** | ||
- 基于生成式语言模型架构(CausalLM) | ||
- 输入:query-document对,输出:特定token的概率(如"yes"/"no") | ||
- 通过对比最后位置特定token的logits进行分类 | ||
|
||
## 损失函数类型 | ||
|
||
SWIFT支持多种损失函数来训练Reranker模型: | ||
|
||
### Pointwise损失函数 | ||
Pointwise方法将排序问题转化为二分类问题,独立处理每个query-document对: | ||
|
||
- **核心思想:** 对每个query-document对进行二分类,判断文档是否与查询相关 | ||
- **损失函数:** 二分类交叉熵 | ||
- **适用场景:** 简单高效,适合大规模数据训练 | ||
|
||
环境变量配置: | ||
- `GENERATIVE_RERANKER_POSITIVE_TOKEN`:正例token(默认:"yes") | ||
- `GENERATIVE_RERANKER_NEGATIVE_TOKEN`:负例token(默认:"no") | ||
|
||
### Listwise损失函数 | ||
Listwise方法将排序问题转化为多分类问题,从多个候选文档中选择正例: | ||
|
||
- **核心思想:** 对每个query的候选文档组(1个正例 + n个负例)进行多分类,识别正例文档 | ||
- **损失函数:** 多分类交叉熵 | ||
- **适用场景:** 学习文档间的相对排序关系,更符合信息检索的实际需求 | ||
|
||
环境变量配置: | ||
- `LISTWISE_RERANKER_TEMPERATURE`:softmax温度参数(默认:1.0) | ||
- `LISTWISE_RERANKER_MIN_GROUP_SIZE`:最小组大小(默认:2) | ||
- `LISTWISE_GENERATIVE_RERANKER_TEMPERATURE`:listwise温度参数(默认:1.0) | ||
- `LISTWISE_GENERATIVE_RERANKER_MIN_GROUP_SIZE`:最小组大小(默认:2) | ||
|
||
**Listwise vs Pointwise:** | ||
- **Pointwise:** 独立判断相关性,训练简单,但忽略了文档间的相对关系 | ||
- **Listwise:** 学习相对排序,性能更优,更适合排序任务的本质需求 | ||
|
||
## 评估指标 | ||
|
||
SWIFT为Reranker训练提供了专业的信息检索评估指标: | ||
|
||
### MRR (Mean Reciprocal Rank) | ||
- **定义:** 所有查询的倒数排名的平均值 | ||
- **计算方式:** MRR = (1/|Q|) × Σ(1/rank_i),其中rank_i是第i个查询的正例文档排名 | ||
- **取值范围:** [0, 1],越大越好 | ||
- **适用场景:** 关注正例文档在排序结果中的位置 | ||
|
||
### NDCG (Normalized Discounted Cumulative Gain) | ||
- **定义:** 标准化折扣累积增益 | ||
- **计算方式:** NDCG = DCG / IDCG,考虑了排序位置对相关性的影响 | ||
- **取值范围:** [0, 1],越大越好 | ||
- **适用场景:** 综合评估排序质量,对top位置的相关性更敏感 | ||
|
||
**指标计算说明:** | ||
- 指标基于query分组计算,每个query组以正例文档开始,后跟负例文档 | ||
- 数据格式:`[1,0,0,1,0,0,0]` 表示2个query:query1=[1,0,0],query2=[1,0,0,0] | ||
- 自动识别query边界并分别计算每个query的指标,最后取平均值 | ||
|
||
loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/loss.py)找到。 | ||
|
||
## 数据集格式 | ||
|
||
```json lines | ||
{"query": "query", "positive": ["relevant_doc1", "relevant_doc2", ...], "negative": ["irrelevant_doc1", "irrelevant_doc2", ...]} | ||
``` | ||
|
||
> 参考[MTEB/scidocs-reranking](https://www.modelscope.cn/datasets/MTEB/scidocs-reranking) | ||
|
||
## 脚手架 | ||
|
||
SWIFT提供了两个脚手架训练脚本: | ||
|
||
- [Pointwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh) | ||
- [Pointwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh) | ||
- [Listwise分类式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) | ||
- [Listwise生成式Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Reranker Training | ||
|
||
SWIFT supports Reranker model training. Currently supported models include: | ||
|
||
1. modernbert reranker model | ||
- [ModelScope](https://www.modelscope.cn/models/iic/gte-reranker-modernbert-base) [Hugging Face](https://huggingface.co/Alibaba-NLP/gte-reranker-modernbert-base) | ||
2. qwen3-reranker model | ||
- 0.6B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-0.6B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B) | ||
- 4B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-4B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-4B) | ||
- 8B: [ModelScope](https://www.modelscope.cn/models/Qwen/Qwen3-Reranker-8B) [Hugging Face](https://huggingface.co/Qwen/Qwen3-Reranker-8B) | ||
|
||
## Implementation Methods | ||
|
||
SWIFT currently supports two implementation methods for Reranker models, which have significant differences in architecture and loss function computation: | ||
|
||
### 1. Classification Reranker | ||
|
||
**Applicable Models:** modernbert reranker models (e.g., gte-reranker-modernbert-base) | ||
|
||
**Core Principles:** | ||
- Based on sequence classification architecture, adding a classification head on top of pre-trained models | ||
- Input: query-document pairs, Output: single relevance score | ||
|
||
### 2. Generative Reranker | ||
|
||
**Applicable Models:** qwen3-reranker models (0.6B/4B/8B) | ||
|
||
**Core Principles:** | ||
- Based on generative language model architecture (CausalLM) | ||
- Input: query-document pairs, Output: probability of specific tokens (e.g., "yes"/"no") | ||
- Classification is performed by comparing logits of specific tokens at the final position | ||
|
||
## Loss Function Types | ||
|
||
SWIFT supports multiple loss functions for training Reranker models: | ||
|
||
### Pointwise Loss Functions | ||
Pointwise methods transform the ranking problem into a binary classification problem, processing each query-document pair independently: | ||
|
||
- **Core Idea:** Binary classification for each query-document pair to determine document relevance to the query | ||
- **Loss Function:** Binary cross-entropy | ||
- **Use Cases:** Simple and efficient, suitable for large-scale data training | ||
|
||
Environment variable configuration: | ||
- `GENERATIVE_RERANKER_POSITIVE_TOKEN`: Positive token (default: "yes") | ||
- `GENERATIVE_RERANKER_NEGATIVE_TOKEN`: Negative token (default: "no") | ||
|
||
### Listwise Loss Functions | ||
Listwise methods transform the ranking problem into a multi-classification problem, selecting positive examples from multiple candidate documents: | ||
|
||
- **Core Idea:** Multi-classification for each query's candidate document group (1 positive + n negative examples) to identify positive documents | ||
- **Loss Function:** Multi-class cross-entropy | ||
- **Use Cases:** Learning relative ranking relationships between documents, better aligned with the actual needs of information retrieval | ||
|
||
Environment variable configuration: | ||
- `LISTWISE_RERANKER_TEMPERATURE`: Softmax temperature parameter (default: 1.0) | ||
- `LISTWISE_RERANKER_MIN_GROUP_SIZE`: Minimum group size (default: 2) | ||
- `LISTWISE_GENERATIVE_RERANKER_TEMPERATURE`: Listwise temperature parameter (default: 1.0) | ||
- `LISTWISE_GENERATIVE_RERANKER_MIN_GROUP_SIZE`: Minimum group size (default: 2) | ||
|
||
**Listwise vs Pointwise:** | ||
- **Pointwise:** Independent relevance judgment, simple training, but ignores relative relationships between documents | ||
- **Listwise:** Learning relative ranking, better performance, more suitable for the essential needs of ranking tasks | ||
|
||
## Evaluation Metrics | ||
|
||
SWIFT provides professional information retrieval evaluation metrics for Reranker training: | ||
|
||
### MRR (Mean Reciprocal Rank) | ||
- **Definition:** Average of reciprocal ranks across all queries | ||
- **Calculation:** MRR = (1/|Q|) × Σ(1/rank_i), where rank_i is the rank of the positive document for the i-th query | ||
- **Range:** [0, 1], higher is better | ||
- **Use Cases:** Focus on the position of positive documents in ranking results | ||
|
||
### NDCG (Normalized Discounted Cumulative Gain) | ||
- **Definition:** Normalized discounted cumulative gain | ||
- **Calculation:** NDCG = DCG / IDCG, considering the impact of ranking position on relevance | ||
- **Range:** [0, 1], higher is better | ||
- **Use Cases:** Comprehensive evaluation of ranking quality, more sensitive to relevance at top positions | ||
|
||
**Metric Calculation Notes:** | ||
- Metrics are calculated based on query grouping, with each query group starting with a positive document followed by negative documents | ||
- Data format: `[1,0,0,1,0,0,0]` represents 2 queries: query1=[1,0,0], query2=[1,0,0,0] | ||
- Automatically identifies query boundaries and calculates metrics for each query separately, then takes the average | ||
|
||
The loss function source code can be found [here](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/loss.py). | ||
|
||
## Dataset Format | ||
|
||
```json lines | ||
{"query": "query", "positive": ["relevant_doc1", "relevant_doc2", ...], "negative": ["irrelevant_doc1", "irrelevant_doc2", ...]} | ||
``` | ||
|
||
> Reference: [MTEB/scidocs-reranking](https://www.modelscope.cn/datasets/MTEB/scidocs-reranking) | ||
|
||
## Training Scripts | ||
|
||
SWIFT provides four training script templates: | ||
|
||
- [Pointwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker.sh) | ||
- [Pointwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker.sh) | ||
- [Listwise Classification Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_reranker_listwise.sh) | ||
- [Listwise Generative Reranker](https://github.com/tastelikefeet/swift/blob/main/examples/train/reranker/train_generative_reranker_listwise.sh) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
中文readme加一下