浅浅尝试多尺度Loss优化Qwen-Reranker效果

实验探索:浅浅尝试多尺度Loss优化Qwen-Reranker效果

一、背景与动机

在 RAG(检索增强生成)系统中,Reranker 扮演着精排的角色,对检索结果进行二次排序。传统的 Reranker 训练通常采用单一的 Binary Cross Entropy Loss,即PointWiseLoss,或者ListWiseLoss,但在实际场景中,我们往往面临两个挑战:

  1. 正负样本不平衡:正样本(相关文档)通常远少于负样本
  2. 排序一致性:我们不仅希望模型正确分类,更希望同一 query 下的相关文档得分高于不相关文档
  3. 局部与全局问题:很多时候每个Batch的局部最优拉到全局并不一定最优,尤其在数据质量不够高且数据量很大时候

本文尝试通过多尺度 Loss 设计来同时解决这两个问题。

二、核心实现

2.1 整体架构

基于 HuggingFace Trainer 实现自定义训练器,支持四种 Loss 模式:

Loss 类型 用途 特点
pointwise 基础分类 标准 Cross Entropy
focal 样本不平衡 自动降权简单样本
listwise 排序一致性 Batch 内组间排序
global_consistent 综合方案 Focal + Listwise 组合

2.2 训练器实现

class RerankerTrainer(Trainer):
def __init__(
self,
yes_token_id,
no_token_id,
loss_type="pointwise",
temperature=0.05,
focal_alpha=0.25,
focal_gamma=2.0,
listwise_weight=0.1,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.yes_token_id = yes_token_id
self.no_token_id = no_token_id
self.loss_type = loss_type
self.temperature = temperature
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
self.listwise_weight = listwise_weight
self._step_accuracy_sum = 0.0
self._step_accuracy_count = 0

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
labels = inputs.pop("labels") # [batch]

# 1. Forward pass
outputs = model(**inputs)
logits = outputs.logits # [batch, seq_len, vocab_size]

# 2. 提取分类用的 Logits (基于 "yes"/"no" token)
last_logits = logits[:, -1, :] # 取最后一个 token
yes_logits = last_logits[:, self.yes_token_id]
no_logits = last_logits[:, self.no_token_id]

# 构造二分类 logits: [batch, 2]
# index 0 为 'no'(非同款), index 1 为 'yes'(同款)
binary_logits = torch.stack([no_logits, yes_logits], dim=1)

# 3. 根据 loss_type 计算 loss
if self.loss_type == "pointwise":
loss = self._pointwise_loss(binary_logits, labels)
elif self.loss_type == "focal":
loss = self._focal_loss(binary_logits, labels)
elif self.loss_type == "listwise":
loss = self._listwise_loss(binary_logits, labels)
elif self.loss_type == "global_consistent":
loss = self._global_consistent_loss(binary_logits, labels)
else:
loss = self._pointwise_loss(binary_logits, labels)

# 4. 计算准确率并记录
with torch.no_grad():
preds = torch.argmax(binary_logits, dim=1)
correct = (preds == labels).sum().item()
accuracy = correct / labels.size(0)

self._step_accuracy_sum += accuracy
self._step_accuracy_count += 1

if self.state.global_step % self.args.logging_steps == 0 and self._step_accuracy_count > 0:
avg_accuracy = self._step_accuracy_sum / self._step_accuracy_count
self.log({"train_accuracy": avg_accuracy})
self._step_accuracy_sum = 0.0
self._step_accuracy_count = 0

return (loss, outputs) if return_outputs else loss

def _pointwise_loss(self, binary_logits, labels):
"""标准 Cross Entropy Loss"""
return F.cross_entropy(binary_logits, labels)

def _focal_loss(self, binary_logits, labels):
"""Focal Loss:处理样本不平衡"""
ce_loss = F.cross_entropy(binary_logits, labels, reduction='none')
pt = torch.exp(-ce_loss) # 预测正确的概率
focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss
return focal_loss.mean()

def _listwise_loss(self, binary_logits, labels):
"""Listwise Loss:处理 Batch 内排序一致性"""
loss_listwise = torch.tensor(0.0, device=binary_logits.device)
pos_indices = torch.nonzero(labels == 1).squeeze(-1)
group_count = 0

if len(pos_indices) > 0:
for idx in pos_indices:
# 找到当前组的范围(假设数据按组排列:正样本在前,后面是负样本)
start = idx.item()
# 寻找下一组的开始
next_pos = torch.nonzero(labels[start+1:] == 1)
end = start + 1 + next_pos[0].item() if len(next_pos) > 0 else labels.size(0)

group_logits = binary_logits[start:end, 1] # 只取 'yes' 的得分进行组内对比

if group_logits.size(0) > 1:
# Listwise 目标:组内第一个(正样本)得分最高
target = torch.tensor([0], device=binary_logits.device)
loss_listwise = loss_listwise + F.cross_entropy(
(group_logits / self.temperature).unsqueeze(0), target
)
group_count += 1

if group_count > 0:
return loss_listwise / group_count
else:
# Fallback to pointwise
return F.cross_entropy(binary_logits, labels)

def _global_consistent_loss(self, binary_logits, labels):
"""全局一致性 Loss:Focal + Listwise"""
# Loss A: Focal Loss (处理正负样本失衡)
loss_pointwise = self._focal_loss(binary_logits, labels)

# Loss B: Listwise Loss (处理 Batch 内排序一致性)
loss_listwise = self._listwise_loss(binary_logits, labels)

return loss_pointwise + self.listwise_weight * loss_listwise

三、Loss 设计思路

3.1 Focal Loss:应对样本不平衡

Focal Loss 通过动态调整样本权重,让模型更关注"难分类"的样本:

FL(pt)=αt(1pt)γlog(pt)\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)

  • α\alpha:正样本权重,平衡正负样本比例
  • γ\gamma:聚焦参数,控制难易样本的权重差异

3.2 Listwise Loss:优化排序一致性

传统 pointwise loss 只关心单个样本的分类正确性,但 Reranker 的核心目标是排序。Listwise Loss 确保:

  • 同一 query 下,正样本得分 > 负样本得分
  • 通过 temperature 参数控制排序 margin 的锐度

3.3 Global Consistent Loss:鱼和熊掌兼得

组合 Focal + Listwise,同时优化分类准确率和排序一致性:

Lglobal=Lfocal+λLlistwiseL_{global} = L_{focal} + \lambda \cdot L_{listwise}

其中 λ\lambda 控制 listwise 项的权重。

四、使用建议

  1. 数据组织:确保 batch 内数据按 (query, positive_doc, negative_docs...) 排列
  2. 超参调优
    • focal_alpha=0.25, focal_gamma=2.0 是经典配置
    • listwise_weight=0.1 起,根据排序指标调整
    • temperature 越小排序越激进,但可能不稳定

五、总结

这次尝试的核心收获:

  • 多尺度 Loss 设计能有效兼顾不同训练目标
  • Focal Loss 对样本不平衡场景有显著帮助
  • Listwise Loss 提升了排序一致性,但依赖数据组织方式

后续计划:在 BEIR 等基准上系统对比各 Loss 的效果差异。