实验探索:浅浅尝试多尺度Loss优化Qwen-Reranker效果
一、背景与动机
在 RAG(检索增强生成)系统中,Reranker 扮演着精排的角色,对检索结果进行二次排序。传统的 Reranker 训练通常采用单一的 Binary Cross Entropy Loss,即PointWiseLoss,或者ListWiseLoss,但在实际场景中,我们往往面临两个挑战:
- 正负样本不平衡:正样本(相关文档)通常远少于负样本
- 排序一致性:我们不仅希望模型正确分类,更希望同一 query 下的相关文档得分高于不相关文档
- 局部与全局问题:很多时候每个Batch的局部最优拉到全局并不一定最优,尤其在数据质量不够高且数据量很大时候
本文尝试通过多尺度 Loss 设计来同时解决这两个问题。
二、核心实现
2.1 整体架构
基于 HuggingFace Trainer 实现自定义训练器,支持四种 Loss 模式:
| Loss 类型 |
用途 |
特点 |
pointwise |
基础分类 |
标准 Cross Entropy |
focal |
样本不平衡 |
自动降权简单样本 |
listwise |
排序一致性 |
Batch 内组间排序 |
global_consistent |
综合方案 |
Focal + Listwise 组合 |
2.2 训练器实现
class RerankerTrainer(Trainer): def __init__( self, yes_token_id, no_token_id, loss_type="pointwise", temperature=0.05, focal_alpha=0.25, focal_gamma=2.0, listwise_weight=0.1, *args, **kwargs ): super().__init__(*args, **kwargs) self.yes_token_id = yes_token_id self.no_token_id = no_token_id self.loss_type = loss_type self.temperature = temperature self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma self.listwise_weight = listwise_weight self._step_accuracy_sum = 0.0 self._step_accuracy_count = 0
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): labels = inputs.pop("labels")
outputs = model(**inputs) logits = outputs.logits
last_logits = logits[:, -1, :] yes_logits = last_logits[:, self.yes_token_id] no_logits = last_logits[:, self.no_token_id]
binary_logits = torch.stack([no_logits, yes_logits], dim=1)
if self.loss_type == "pointwise": loss = self._pointwise_loss(binary_logits, labels) elif self.loss_type == "focal": loss = self._focal_loss(binary_logits, labels) elif self.loss_type == "listwise": loss = self._listwise_loss(binary_logits, labels) elif self.loss_type == "global_consistent": loss = self._global_consistent_loss(binary_logits, labels) else: loss = self._pointwise_loss(binary_logits, labels)
with torch.no_grad(): preds = torch.argmax(binary_logits, dim=1) correct = (preds == labels).sum().item() accuracy = correct / labels.size(0)
self._step_accuracy_sum += accuracy self._step_accuracy_count += 1
if self.state.global_step % self.args.logging_steps == 0 and self._step_accuracy_count > 0: avg_accuracy = self._step_accuracy_sum / self._step_accuracy_count self.log({"train_accuracy": avg_accuracy}) self._step_accuracy_sum = 0.0 self._step_accuracy_count = 0
return (loss, outputs) if return_outputs else loss
def _pointwise_loss(self, binary_logits, labels): """标准 Cross Entropy Loss""" return F.cross_entropy(binary_logits, labels)
def _focal_loss(self, binary_logits, labels): """Focal Loss:处理样本不平衡""" ce_loss = F.cross_entropy(binary_logits, labels, reduction='none') pt = torch.exp(-ce_loss) focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss return focal_loss.mean()
def _listwise_loss(self, binary_logits, labels): """Listwise Loss:处理 Batch 内排序一致性""" loss_listwise = torch.tensor(0.0, device=binary_logits.device) pos_indices = torch.nonzero(labels == 1).squeeze(-1) group_count = 0
if len(pos_indices) > 0: for idx in pos_indices: start = idx.item() next_pos = torch.nonzero(labels[start+1:] == 1) end = start + 1 + next_pos[0].item() if len(next_pos) > 0 else labels.size(0)
group_logits = binary_logits[start:end, 1]
if group_logits.size(0) > 1: target = torch.tensor([0], device=binary_logits.device) loss_listwise = loss_listwise + F.cross_entropy( (group_logits / self.temperature).unsqueeze(0), target ) group_count += 1
if group_count > 0: return loss_listwise / group_count else: return F.cross_entropy(binary_logits, labels)
def _global_consistent_loss(self, binary_logits, labels): """全局一致性 Loss:Focal + Listwise""" loss_pointwise = self._focal_loss(binary_logits, labels)
loss_listwise = self._listwise_loss(binary_logits, labels)
return loss_pointwise + self.listwise_weight * loss_listwise
|
三、Loss 设计思路
3.1 Focal Loss:应对样本不平衡
Focal Loss 通过动态调整样本权重,让模型更关注"难分类"的样本:
FL(pt)=−αt(1−pt)γlog(pt)
- α:正样本权重,平衡正负样本比例
- γ:聚焦参数,控制难易样本的权重差异
3.2 Listwise Loss:优化排序一致性
传统 pointwise loss 只关心单个样本的分类正确性,但 Reranker 的核心目标是排序。Listwise Loss 确保:
- 同一 query 下,正样本得分 > 负样本得分
- 通过 temperature 参数控制排序 margin 的锐度
3.3 Global Consistent Loss:鱼和熊掌兼得
组合 Focal + Listwise,同时优化分类准确率和排序一致性:
Lglobal=Lfocal+λ⋅Llistwise
其中 λ 控制 listwise 项的权重。
四、使用建议
- 数据组织:确保 batch 内数据按
(query, positive_doc, negative_docs...) 排列
- 超参调优:
focal_alpha=0.25, focal_gamma=2.0 是经典配置
listwise_weight=0.1 起,根据排序指标调整
temperature 越小排序越激进,但可能不稳定
五、总结
这次尝试的核心收获:
- 多尺度 Loss 设计能有效兼顾不同训练目标
- Focal Loss 对样本不平衡场景有显著帮助
- Listwise Loss 提升了排序一致性,但依赖数据组织方式
后续计划:在 BEIR 等基准上系统对比各 Loss 的效果差异。


Chasing
A record of Life and Work
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 Chasing's BLOG!