LLM 训练方法实现:SFT、GRPO、DAPO 与 On-Policy Distillation

本文记录了几种主流 LLM 微调框架与范式的实现方式,包括监督微调(SFT)、组相对策略优化(GRPO)、动态优势策略优化(DAPO)以及在线策略蒸馏(On-Policy Distillation)。

LLM 训练方法实现:SFT、GRPO、DAPO 与 On-Policy Distillation

一、引言

随着大语言模型的发展,如何高效地微调和优化模型成为研究热点。本文将系统性地介绍四种核心训练范式:

  • SFT (Supervised Fine-Tuning):监督微调,最基础的指令对齐方法
  • GRPO (Group Relative Policy Optimization):组相对策略优化,一种高效强化学习微调方法
  • DAPO (Decoupled Clip and Dynamic sAmpling Policy Optimization):解耦裁剪与动态采样策略优化(本文其余部分沿用"动态优势策略优化"的叫法),是在 GRPO 基础上针对训练稳定性与采样效率的改进
  • On-Policy Distillation:在线策略蒸馏,结合蒸馏与策略优化的训练范式

二、SFT:监督微调

监督微调是 LLM 对齐的基础方法,通过在标注数据上进行下一词预测训练,使模型学习期望的输出格式和风格。

#!/usr/bin/env python3
"""
SFT (Supervised Fine-Tuning) Training Script
for Qwen3.5-0.8B Multimodal Model

Usage:
Single GPU: python scripts/run_sft.py
Multi GPU with DeepSpeed: deepspeed --num_gpus=2 scripts/run_sft.py --deepspeed configs/ds_config_zero2.json

Reference:
- TRL SFTTrainer: https://huggingface.co/docs/trl/sft_trainer
- Qwen2-VL Training: https://github.com/QwenLM/Qwen2-VL
"""

import os
import argparse
import json
from pathlib import Path
from typing import Optional, List, Dict

import torch
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
)
from datasets import load_dataset, Dataset
from trl import SFTTrainer, SFTConfig
from PIL import Image


# ============================================================================
# Configuration
# ============================================================================

# Default checkpoint directory; used as the default of --model_path.
MODEL_PATH = "/workspace/models/Qwen3.5-0.8B"
# Default directory for checkpoints and the final model; default of --output_dir.
OUTPUT_DIR = "/workspace/train_codes/output/sft"
# Directory searched for "<dataset_name>.json" when the Hub download fails.
DATA_DIR = "/workspace/train_codes/data"

# Dataset identifier used when --dataset_name is not given.
DEFAULT_DATASET = "tatsu-lab/alpaca"


# ============================================================================
# Data Processing Functions
# ============================================================================

def format_alpaca_example(example: Dict) -> str:
    """Render one Alpaca-style record as a single prompt-plus-answer string.

    The record is read through the keys ``instruction``, ``input`` and
    ``output``; each missing key is treated as an empty string. The
    ``### Input:`` section is emitted only when ``input`` is truthy.

    Args:
        example: Mapping holding the Alpaca fields.

    Returns:
        The instruction (and optional input) sections followed by
        ``### Response:`` and the reference output.
    """
    instruction = example.get("instruction", "")
    input_text = example.get("input", "")
    output = example.get("output", "")

    sections = [f"### Instruction:\n{instruction}\n\n"]
    if input_text:
        sections.append(f"### Input:\n{input_text}\n\n")
    sections.append("### Response:\n")

    return "".join(sections) + output


def format_sharegpt_example(example: Dict) -> str:
    """Render a ShareGPT/LLaVA-style conversation with ``<|im_start|>`` markers.

    Each entry of ``example["conversations"]`` is read through its ``from``
    and ``value`` keys. ``human`` turns become ``user`` turns, ``gpt`` turns
    become ``assistant`` turns, and any other role is dropped.

    Args:
        example: Mapping with an optional ``conversations`` list.

    Returns:
        The concatenated turns, or an empty string when nothing was emitted.
    """
    pieces = []
    for turn in example.get("conversations", []):
        role = turn.get("from", "")
        content = turn.get("value", "")

        if role == "human":
            speaker = "user"
        elif role == "gpt":
            speaker = "assistant"
        else:
            # Unknown roles (e.g. "system") are not part of the output.
            continue

        pieces.append(f"<|im_start|>{speaker}\n{content}<|im_end|>\n")

    return "".join(pieces)


def format_multimodal_example(example: Dict, processor) -> str:
    """Turn a conversation that may reference an image into chat-formatted text.

    ``human`` turns become ``user`` messages whose content is a list of typed
    parts; when the turn contains ``<image>`` and the example has a truthy
    ``image`` value, an image part is prepended and the placeholder is removed
    from the text. ``gpt`` turns become ``assistant`` messages with plain
    string content. Other roles are ignored.

    Args:
        example: Mapping with optional ``conversations`` and ``image`` keys.
        processor: Object exposing ``apply_chat_template``; when falsy, a
            manual ``<|im_start|>``/``<|im_end|>`` rendering is used instead.

    Returns:
        The rendered conversation text.
    """
    image_path = example.get("image", None)

    messages = []
    for turn in example.get("conversations", []):
        role = turn.get("from", "")
        content = turn.get("value", "")

        if role == "human":
            if "<image>" in content and image_path:
                parts = [
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": content.replace("<image>", "").strip()},
                ]
            else:
                parts = [{"type": "text", "text": content}]
            messages.append({"role": "user", "content": parts})
        elif role == "gpt":
            messages.append({"role": "assistant", "content": content})

    if processor:
        return processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )

    # Fallback rendering: only the text of each part is kept (image parts
    # contribute an empty string, which the join turns into a leading space).
    rendered = []
    for message in messages:
        if message["role"] == "user":
            joined = " ".join(
                [
                    part.get("text", "") if isinstance(part, dict) else str(part)
                    for part in message["content"]
                ]
            )
            rendered.append(f"<|im_start|>user\n{joined}<|im_end|>\n")
        else:
            rendered.append(f"<|im_start|>assistant\n{message['content']}<|im_end|>\n")

    return "".join(rendered)


# ============================================================================
# Dataset Loading
# ============================================================================

def load_training_dataset(dataset_name: str, max_samples: int, use_multimodal: bool = False) -> Dataset:
    """Load the ``train`` split of a dataset, falling back to a local JSON file.

    The dataset is first requested through ``load_dataset``. If that raises,
    ``DATA_DIR/<dataset_name>.json`` is read instead and converted with
    ``Dataset.from_list``.

    Args:
        dataset_name: Identifier passed to ``load_dataset``; also the stem of
            the local fallback file.
        max_samples: When positive, keep only the first ``max_samples`` rows.
        use_multimodal: When true, ``trust_remote_code=True`` is passed to
            ``load_dataset``.

    Returns:
        The (possibly truncated) dataset.

    Raises:
        ValueError: If the primary load fails and no local fallback file
            exists. The original loading error is attached as ``__cause__``.
    """
    print(f"Loading dataset: {dataset_name}")

    try:
        if use_multimodal:
            dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
        else:
            dataset = load_dataset(dataset_name, split="train")
    except Exception as e:
        # Deliberately broad: any failure of the primary source triggers the
        # local fallback rather than aborting.
        print(f"Failed to load from HuggingFace: {e}")
        local_path = Path(DATA_DIR) / f"{dataset_name}.json"
        if not local_path.exists():
            raise ValueError(f"Dataset not found: {dataset_name}") from e
        # Explicit UTF-8 so non-ASCII text decodes identically on every
        # platform instead of depending on the locale's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        dataset = Dataset.from_list(data)

    if max_samples > 0 and len(dataset) > max_samples:
        dataset = dataset.select(range(max_samples))

    print(f"Dataset loaded: {len(dataset)} samples")
    return dataset


# ============================================================================
# Main Training Function
# ============================================================================

def _parse_sft_args() -> argparse.Namespace:
    """Define the SFT command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="SFT Training Script")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR)
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATASET)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--deepspeed", type=str, default=None)
    parser.add_argument("--use_multimodal", action="store_true")
    parser.add_argument("--use_lora", action="store_true")
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _print_sft_config(args: argparse.Namespace) -> None:
    """Print a banner summarising the parsed run configuration."""
    print("=" * 60)
    print("SFT Training Configuration")
    print("=" * 60)
    print(f"Model: {args.model_path}")
    print(f"Output: {args.output_dir}")
    print(f"Dataset: {args.dataset_name}")
    print(f"Max samples: {args.max_samples}")
    print(f"Max length: {args.max_length}")
    print(f"Learning rate: {args.learning_rate}")
    print(f"Epochs: {args.num_epochs}")
    print(f"Batch size: {args.batch_size}")
    print(f"Gradient accumulation: {args.gradient_accumulation_steps}")
    print(f"DeepSpeed: {args.deepspeed}")
    print(f"Use LoRA: {args.use_lora}")
    print("=" * 60)


def _load_sft_processor(args: argparse.Namespace):
    """Return the multimodal processor, or None if not requested or not loadable."""
    if not args.use_multimodal:
        return None
    try:
        return AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
    except Exception as e:
        # Best effort: formatting falls back to manual chat markers without it.
        print(f"Warning: Could not load processor: {e}")
        return None


def _apply_sft_lora(model, args: argparse.Namespace):
    """Wrap ``model`` with LoRA adapters on the attention projections."""
    # Imported lazily so peft is only required when --use_lora is passed.
    from peft import LoraConfig, get_peft_model

    print("Applying LoRA...")
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def _select_sft_formatting_func(args: argparse.Namespace, processor):
    """Pick the example-to-text formatter matching the run options."""
    if args.use_multimodal:
        def format_with_processor(example):
            return format_multimodal_example(example, processor)

        return format_with_processor
    if "alpaca" in args.dataset_name.lower():
        return format_alpaca_example
    return format_sharegpt_example


def main():
    """Run supervised fine-tuning end to end.

    Parses the command line, loads the model, tokenizer and (optionally) the
    multimodal processor, optionally applies LoRA, loads the dataset, trains
    with ``SFTTrainer`` and saves the model and tokenizer to ``--output_dir``.
    """
    args = _parse_sft_args()
    _print_sft_config(args)

    # Set seed
    torch.manual_seed(args.seed)

    # Load model and tokenizer
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        # No automatic device map when a DeepSpeed config is supplied.
        device_map="auto" if args.deepspeed is None else None,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        trust_remote_code=True,
    )

    # Padding needs a pad token; reuse EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    processor = _load_sft_processor(args)

    if args.use_lora:
        model = _apply_sft_lora(model, args)

    dataset = load_training_dataset(args.dataset_name, args.max_samples, args.use_multimodal)
    formatting_func = _select_sft_formatting_func(args, processor)

    # Training configuration
    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.num_epochs,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        max_length=args.max_length,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=3,
        bf16=True,
        deepspeed=args.deepspeed,
        gradient_checkpointing=True,
        optim="adamw_torch",
        seed=args.seed,
        report_to="none",
        packing=False,
    )

    # Create trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        formatting_func=formatting_func,
        processing_class=tokenizer,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save final model
    print("Saving model...")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    print(f"Training complete! Model saved to: {args.output_dir}")


if __name__ == "__main__":
    main()

三、GRPO:组相对策略优化

GRPO 是 DeepSeek 提出的一种高效强化学习微调方法,通过组内相对比较来计算优势值,避免了传统 PPO 中需要额外训练价值模型的计算开销。

#!/usr/bin/env python3
"""
GRPO (Group Relative Policy Optimization) Training Script - CORRECTED VERSION
for Qwen3.5-0.8B Model

GRPO is proposed by DeepSeek in DeepSeekMath and DeepSeek-R1 papers.

Key corrections from original implementation:
1. ratio = exp(policy - old_policy), NOT policy - ref
2. old_policy log probs must be saved during sampling (detached)
3. Proper per-token log prob computation with length normalization
4. Correct KL divergence estimator (k3 approximation)

Reference:
- DeepSeekMath Paper: https://arxiv.org/abs/2402.03300
- DeepSeek-R1 Paper: https://arxiv.org/abs/2501.12948
- OpenRLHF: https://github.com/OpenRLHF/OpenRLHF
"""

import os
import argparse
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)
from datasets import load_dataset, Dataset
from accelerate import Accelerator
from tqdm import tqdm


# ============================================================================
# Configuration
# ============================================================================

MODEL_PATH = "/workspace/models/Qwen3.5-0.8B"
OUTPUT_DIR = "/workspace/train_codes/output/grpo"
DATA_DIR = "/workspace/train_codes/data"
DEFAULT_DATASET = "tatsu-lab/alpaca"


@dataclass
class GRPOConfig:
    """Hyper-parameters and paths for a GRPO training run.

    Every field has a default, so ``GRPOConfig()`` is a complete
    configuration; the module-level constants above supply the default paths
    and dataset name.
    """

    # --- Paths and data -----------------------------------------------------
    model_path: str = MODEL_PATH
    output_dir: str = OUTPUT_DIR
    dataset_name: str = DEFAULT_DATASET
    max_samples: int = 1000
    max_prompt_length: int = 512
    max_response_length: int = 512

    # --- GRPO-specific parameters ---------------------------------------------
    # Samples per prompt for advantage estimation
    group_size: int = 8
    # KL divergence coefficient
    beta: float = 0.1
    # PPO clipping coefficient
    clip_coef: float = 0.2
    # Number of policy updates per batch of samples
    num_updates_per_sample: int = 1

    # --- Optimisation ---------------------------------------------------------
    learning_rate: float = 1e-6
    num_epochs: int = 1
    batch_size: int = 4
    gradient_accumulation_steps: int = 8
    weight_decay: float = 0.01

    # --- Sampling -------------------------------------------------------------
    temperature: float = 1.0
    top_p: float = 0.95

    # --- Miscellaneous --------------------------------------------------------
    deepspeed: Optional[str] = None
    seed: int = 42
    logging_steps: int = 10
    save_steps: int = 100


# ============================================================================
# GRPO Core Algorithm - CORRECTED
# ============================================================================

def compute_grpo_advantage(rewards: torch.Tensor, epsilon: float = 1e-8) -> torch.Tensor:
    """Standardise rewards within each group to obtain GRPO advantages.

    Along the last dimension (one group of samples for the same prompt):
    ``advantage = (r - mean(r)) / (std(r) + epsilon)``. Using the group
    itself as the baseline removes the need for a learned value network.

    Args:
        rewards: Floating-point rewards whose last dimension indexes the
            samples of one group.
        epsilon: Added to the standard deviation to avoid division by zero
            when all rewards in a group are equal.

    Returns:
        Tensor of the same shape as ``rewards``. Groups with fewer than two
        samples get all-zero advantages.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    centered = rewards - mean

    # The (unbiased) standard deviation of a single sample is NaN, which would
    # propagate into the loss. A lone sample carries no relative signal, so
    # its advantage is defined as zero.
    group_size = rewards.shape[-1] if rewards.dim() > 0 else 1
    if group_size < 2:
        return torch.zeros_like(centered)

    std = rewards.std(dim=-1, keepdim=True)
    return centered / (std + epsilon)


def compute_per_token_log_probs(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """Return the length-normalised log-probability of each response.

    Despite the name, the result is one scalar per sequence: the mean of the
    token log-probabilities over the positions selected by ``response_mask``.

    Steps:
        1. Run the model and drop the last logit (it predicts a token that is
           not in ``input_ids``), so ``logits[:, i]`` scores ``input_ids[:, i + 1]``.
        2. Take the log-softmax in float32 and gather the log-probability of
           each actual next token.
        3. Shift ``response_mask`` by one to line up with the targets, then
           average the masked log-probabilities per sequence.

    Args:
        model: Callable returning an object with a ``logits`` attribute of
            shape [batch, seq_len, vocab].
        input_ids: Full sequences [batch, seq_len].
        attention_mask: Attention mask [batch, seq_len].
        response_mask: 1 at response-token positions, 0 elsewhere
            [batch, seq_len].

    Returns:
        float32 tensor [batch]. A sequence with no response tokens yields 0.
    """
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits[:, :-1, :]  # [batch, seq_len-1, vocab]

    # Target tokens (shifted by 1)
    target_ids = input_ids[:, 1:]  # [batch, seq_len-1]

    # Normalise in float32: with bf16/fp16 weights a half-precision
    # log-softmax rounds the log-probs coarsely, and that error is amplified
    # when differences of log-probs are exponentiated into ratios.
    # Note: the [batch, seq_len-1, vocab] result is float32, so it uses more
    # memory than a half-precision log-softmax would.
    log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    token_log_probs = torch.gather(
        log_probs, dim=-1, index=target_ids.unsqueeze(-1)
    ).squeeze(-1)  # [batch, seq_len-1]

    # Position j of the shifted mask says whether target token j+1 belongs to
    # the response.
    response_mask_shifted = response_mask[:, 1:].to(token_log_probs.dtype)

    # clamp(min=1) keeps sequences without response tokens from dividing by 0.
    response_lengths = response_mask_shifted.sum(dim=-1).clamp(min=1)
    log_prob_sum = (token_log_probs * response_mask_shifted).sum(dim=-1)

    return log_prob_sum / response_lengths


def compute_kl_divergence_k3(
    policy_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
) -> torch.Tensor:
    """Element-wise "k3" estimate of KL(policy || reference).

    With ``d = log q - log p`` (reference minus policy), the estimate is
    ``exp(d) - d - 1``. It is zero where the two log-probabilities agree and
    non-negative everywhere else.

    Args:
        policy_log_probs: Log-probabilities under the policy (p).
        ref_log_probs: Log-probabilities under the reference (q); must
            broadcast against ``policy_log_probs``.

    Returns:
        Tensor of KL estimates with the broadcast shape of the inputs.
    """
    delta = ref_log_probs - policy_log_probs
    return delta.exp() - delta - 1


def compute_grpo_loss_corrected(
    policy_log_probs: torch.Tensor,
    old_policy_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    beta: float = 0.1,
    clip_coef: float = 0.2,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Clipped-surrogate policy loss plus a KL penalty towards the reference.

    ``loss = -mean(min(ratio * A, clip(ratio) * A)) + beta * mean(KL)``

    The importance ratio is ``exp(policy - old_policy)``: it compares the
    current policy with the policy that produced the samples, not with the
    reference model. The reference model only enters through the KL term.

    Args:
        policy_log_probs: Current policy log-probs (carry gradients).
        old_policy_log_probs: Log-probs recorded at sampling time (detached).
        ref_log_probs: Reference-model log-probs (detached).
        advantages: Group-relative advantages, broadcastable to the log-probs.
        beta: Weight of the KL penalty.
        clip_coef: Half-width of the clipping interval around 1.

    Returns:
        ``(total_loss, metrics)`` where ``metrics`` holds Python floats for
        ``total_loss``, ``policy_loss``, ``kl_loss``, ``ratio_mean``,
        ``ratio_std`` and ``kl_mean``.
    """
    # Importance-sampling weight of the current policy w.r.t. the sampler.
    ratio = (policy_log_probs - old_policy_log_probs).exp()
    clipped_ratio = ratio.clamp(1 - clip_coef, 1 + clip_coef)

    # Pessimistic (element-wise minimum) surrogate, negated for minimisation.
    unclipped_objective = ratio * advantages
    clipped_objective = clipped_ratio * advantages
    policy_loss = -torch.min(unclipped_objective, clipped_objective).mean()

    # Penalise drift from the reference model.
    kl = compute_kl_divergence_k3(policy_log_probs, ref_log_probs)
    kl_loss = beta * kl.mean()

    total_loss = policy_loss + kl_loss

    # std() of a single element is undefined, so report 0.0 in that case.
    ratio_std = ratio.std().item() if ratio.numel() > 1 else 0.0
    metrics = {
        "total_loss": total_loss.item(),
        "policy_loss": policy_loss.item(),
        "kl_loss": kl_loss.item(),
        "ratio_mean": ratio.mean().item(),
        "ratio_std": ratio_std,
        "kl_mean": kl.mean().item(),
    }

    return total_loss, metrics


# ============================================================================
# Reward Functions
# ============================================================================

class FormatRewardFunction:
    """Heuristic reward that scores the surface form of a response."""

    def compute(self, prompt: str, response: str) -> float:
        """Score ``response`` by summing four independent checks.

        * +0.3 if the stripped response ends with ``.``, ``?`` or ``!``
        * +0.3 if the raw response is 20 to 500 characters long (inclusive)
        * +0.2 if more than 70% of its whitespace-separated words are unique
        * +0.2 if the stripped response is non-empty

        Args:
            prompt: The prompt that produced the response (not used by the
                scoring rules).
            response: The generated text to score.

        Returns:
            The accumulated reward.
        """
        stripped = response.strip()
        reward = 0.0

        # Proper ending
        if stripped.endswith((".", "?", "!")):
            reward += 0.3

        # Reasonable length (measured on the unstripped text)
        if 20 <= len(response) <= 500:
            reward += 0.3

        # No repetition
        words = response.split()
        if words and len(set(words)) / len(words) > 0.7:
            reward += 0.2

        # Not empty
        if stripped:
            reward += 0.2

        return reward


# ============================================================================
# GRPO Trainer - CORRECTED
# ============================================================================

class GRPOTrainer:
    """Minimal GRPO training loop around a causal language model.

    For every prompt, ``group_size`` responses are sampled from the current
    policy. Their rewards are standardised within the group to form
    advantages, and the policy is updated with a clipped surrogate loss plus
    a KL penalty towards a frozen copy of the initial model.

    Implementation notes:
        * Training tensors are built directly from the generated token ids.
          The sampled text is never re-tokenised, so the response mask always
          lines up with the tokens the policy actually produced.
        * Log-probs of the sampling policy are recorded right after sampling
          and used as the denominator of the importance ratio.
        * Steps whose loss is NaN/Inf are skipped and are not logged or
          checkpointed.
    """

    def __init__(self, config: GRPOConfig):
        """Load models, tokenizer, reward function and optimizer.

        Args:
            config: Training configuration.
        """
        self.config = config
        self.accelerator = Accelerator()

        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        # Load policy model
        print("Loading policy model...")
        self.policy_model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        # Load reference model (frozen, for KL constraint)
        print("Loading reference model...")
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_path, trust_remote_code=True
        )
        # Padding needs a pad token; reuse EOS when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Generation config
        self.generation_config = GenerationConfig(
            max_new_tokens=config.max_response_length,
            temperature=config.temperature,
            top_p=config.top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        # Reward function
        self.reward_function = FormatRewardFunction()

        # Optimizer
        self.optimizer = AdamW(
            self.policy_model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Only the trainable model and its optimizer go through prepare();
        # the frozen reference model is just moved to the device.
        self.policy_model, self.optimizer = self.accelerator.prepare(
            self.policy_model, self.optimizer
        )
        self.ref_model.to(self.accelerator.device)

        self.global_step = 0

    def generate_samples_and_save_old_log_probs(
        self,
        prompts: List[str],
    ) -> Tuple[List[str], List[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample ``group_size`` responses per prompt and record their log-probs.

        Args:
            prompts: Non-empty list of prompt strings.

        Returns:
            A tuple ``(prompts_expanded, responses, input_ids, attention_mask,
            response_mask, old_policy_log_probs)``:

            * prompts_expanded: each prompt repeated once per sample
            * responses: decoded responses (special tokens removed)
            * input_ids: right-padded prompt+response token ids [batch, seq_len]
            * attention_mask: 1 on real tokens, 0 on padding [batch, seq_len]
            * response_mask: 1 on generated tokens only [batch, seq_len]
            * old_policy_log_probs: detached per-sequence log-probs [batch]

        Raises:
            ValueError: If ``prompts`` is empty.
        """
        if len(prompts) == 0:
            raise ValueError("prompts must not be empty")

        device = self.accelerator.device
        pad_token_id = self.tokenizer.pad_token_id
        # generate() is defined on the underlying model; a wrapper returned by
        # accelerator.prepare() (e.g. for distributed runs) may not expose it.
        generator = self.accelerator.unwrap_model(self.policy_model)

        prompts_expanded = []
        responses = []
        sequences = []  # 1-D id tensors: prompt tokens followed by sampled tokens
        prompt_token_lengths = []

        for prompt in prompts:
            prompt_tokens = self.tokenizer(
                prompt, return_tensors="pt", truncation=True,
                max_length=self.config.max_prompt_length,
                add_special_tokens=False
            )
            gen_inputs = {k: v.to(device) for k, v in prompt_tokens.items()}
            prompt_token_length = gen_inputs["input_ids"].shape[1]

            for _ in range(self.config.group_size):
                # Every sample starts from the same prompt.
                with torch.no_grad():
                    gen_outputs = generator.generate(
                        **gen_inputs,
                        generation_config=self.generation_config,
                    )

                sequence = gen_outputs[0]
                response_token_ids = sequence[prompt_token_length:]
                response = self.tokenizer.decode(
                    response_token_ids,
                    skip_special_tokens=True,
                )

                sequences.append(sequence)
                prompt_token_lengths.append(prompt_token_length)
                prompts_expanded.append(prompt)
                responses.append(response)

        # Right-pad the exact generated ids into a batch. Decoding and
        # re-tokenising "prompt + response" could change token boundaries and
        # add/remove special tokens, which would shift the response mask.
        max_seq_len = max(seq.shape[0] for seq in sequences)
        input_ids = torch.full(
            (len(sequences), max_seq_len), pad_token_id, dtype=torch.long, device=device
        )
        attention_mask = torch.zeros_like(input_ids)
        response_mask = torch.zeros_like(input_ids)
        for row, (seq, prompt_len) in enumerate(zip(sequences, prompt_token_lengths)):
            seq_len = seq.shape[0]
            input_ids[row, :seq_len] = seq
            attention_mask[row, :seq_len] = 1
            # Everything after the prompt was sampled (including a final EOS).
            response_mask[row, prompt_len:seq_len] = 1

        # Log-probs under the sampling policy; no graph is kept because they
        # act as constants in the importance ratio.
        with torch.no_grad():
            old_policy_log_probs = compute_per_token_log_probs(
                self.policy_model,
                input_ids,
                attention_mask,
                response_mask,
            )
        old_policy_log_probs = old_policy_log_probs.detach().to(device)

        return (
            prompts_expanded,
            responses,
            input_ids,
            attention_mask,
            response_mask,
            old_policy_log_probs,
        )

    def compute_rewards(self, prompts: List[str], responses: List[str]) -> torch.Tensor:
        """Score every (prompt, response) pair with the reward function.

        Returns:
            float32 tensor [len(responses)] on the accelerator device.
        """
        rewards = [
            self.reward_function.compute(prompt, response)
            for prompt, response in zip(prompts, responses)
        ]
        return torch.tensor(rewards, dtype=torch.float32, device=self.accelerator.device)

    def train_step(self, prompts: List[str]) -> Dict[str, float]:
        """Run one GRPO update on a batch of prompts.

        Steps:
            1. Sample responses and record the sampling policy's log-probs.
            2. Compute rewards and group-relative advantages.
            3. Recompute log-probs under the current policy (with gradients)
               and the reference model (without).
            4. Compute the GRPO loss.
            5. Back-propagate, clip gradients to norm 1.0 and step.

        Returns:
            Loss and reward metrics. If the loss is NaN/Inf the update is
            skipped, ``global_step`` is not advanced, and the dict contains
            only ``total_loss`` (0.0) and an ``error`` message.
        """
        # Step 1: Generate samples and save old log probs
        (
            prompts_expanded,
            responses,
            input_ids,
            attention_mask,
            response_mask,
            old_policy_log_probs,
        ) = self.generate_samples_and_save_old_log_probs(prompts)

        # Step 2: Compute rewards and advantages
        rewards = self.compute_rewards(prompts_expanded, responses)

        # Samples are laid out prompt-major, so each row is one group.
        num_prompts = len(prompts)
        rewards_grouped = rewards.view(num_prompts, self.config.group_size)
        advantages = compute_grpo_advantage(rewards_grouped)
        advantages_flat = advantages.view(-1)

        # Step 3: Compute current policy log probs (with grad)
        policy_log_probs = compute_per_token_log_probs(
            self.policy_model,
            input_ids,
            attention_mask,
            response_mask,
        )

        # Compute reference log probs (no grad)
        with torch.no_grad():
            ref_log_probs = compute_per_token_log_probs(
                self.ref_model,
                input_ids,
                attention_mask,
                response_mask,
            )

        # Step 4: Compute GRPO loss (ratio is taken against the old policy)
        loss, metrics = compute_grpo_loss_corrected(
            policy_log_probs,
            old_policy_log_probs,
            ref_log_probs,
            advantages_flat,
            beta=self.config.beta,
            clip_coef=self.config.clip_coef,
        )

        # A non-finite loss would corrupt the weights; drop this batch.
        if not torch.isfinite(loss):
            print("Warning: NaN or Inf loss detected, skipping update")
            self.optimizer.zero_grad()
            return {"total_loss": 0.0, "error": "NaN/Inf loss"}

        # Step 5: Update policy with gradient clipping
        self.accelerator.backward(loss)
        self.accelerator.clip_grad_norm_(self.policy_model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.global_step += 1

        # Add reward metrics
        metrics["mean_reward"] = rewards.mean().item()
        metrics["std_reward"] = rewards.std().item() if rewards.numel() > 1 else 0.0

        return metrics

    @staticmethod
    def _extract_prompts(batch) -> List[str]:
        """Pull the prompt strings out of a collated batch.

        The formatted ``prompt`` column is preferred over the raw
        ``instruction`` column; ``text`` is the last resort (a missing
        ``text`` key raises ``KeyError``).
        """
        key = next((k for k in ("prompt", "instruction") if k in batch), "text")
        prompts = batch[key]
        if isinstance(prompts, torch.Tensor):
            prompts = prompts.tolist()
        return prompts

    def train(self, dataset: Dataset):
        """Iterate over ``dataset`` for ``num_epochs`` epochs, updating the policy.

        Logs every ``logging_steps`` successful updates, saves a checkpoint
        every ``save_steps`` successful updates, and saves the final model to
        ``output_dir`` at the end.
        """
        print(f"Starting GRPO training on {len(dataset)} samples...")
        print(f"Group size: {self.config.group_size}")
        print(f"Beta (KL coefficient): {self.config.beta}")
        print(f"Clip coefficient: {self.config.clip_coef}")

        dataloader = DataLoader(
            dataset, batch_size=self.config.batch_size, shuffle=True
        )

        num_batches = len(dataloader) * self.config.num_epochs

        for epoch in range(self.config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            for batch in tqdm(dataloader, desc="Training"):
                prompts = self._extract_prompts(batch)
                metrics = self.train_step(prompts)

                # A skipped step did not advance global_step and carries no
                # loss metrics: logging it would raise KeyError and the
                # checkpoint for the current step would be written again.
                if "error" in metrics:
                    continue

                if self.global_step % self.config.logging_steps == 0:
                    print(
                        f"Step {self.global_step}/{num_batches}: "
                        f"loss={metrics['total_loss']:.4f}, "
                        f"policy_loss={metrics['policy_loss']:.4f}, "
                        f"kl_loss={metrics['kl_loss']:.4f}, "
                        f"ratio={metrics['ratio_mean']:.3f}±{metrics['ratio_std']:.3f}, "
                        f"reward={metrics['mean_reward']:.4f}"
                    )

                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(self.global_step)

        self.save_checkpoint(self.global_step, final=True)
        print(f"Training complete! Model saved to: {self.config.output_dir}")

    def save_checkpoint(self, step: int, final: bool = False):
        """Save the policy model and tokenizer.

        Args:
            step: Step number used in the checkpoint directory name.
            final: When true, save directly into ``output_dir`` instead of
                ``output_dir/checkpoint-<step>``.
        """
        if final:
            save_dir = self.config.output_dir
        else:
            save_dir = os.path.join(self.config.output_dir, f"checkpoint-{step}")

        unwrapped_model = self.accelerator.unwrap_model(self.policy_model)
        unwrapped_model.save_pretrained(save_dir)
        self.tokenizer.save_pretrained(save_dir)
        print(f"Checkpoint saved to: {save_dir}")


# ============================================================================
# Dataset Preparation
# ============================================================================

def prepare_dataset(config: GRPOConfig) -> Dataset:
    """Load the training split and add a ``prompt`` column to every row.

    The dataset is first requested through ``load_dataset``. If that raises,
    ``DATA_DIR/<dataset_name>.json`` is read instead. The result is truncated
    to ``config.max_samples`` rows when that value is positive.

    Args:
        config: Supplies ``dataset_name`` and ``max_samples``.

    Returns:
        The dataset with an added ``prompt`` column.

    Raises:
        ValueError: If the primary load fails and no local fallback file
            exists. The original loading error is attached as ``__cause__``.
    """
    print(f"Loading dataset: {config.dataset_name}")

    try:
        dataset = load_dataset(config.dataset_name, split="train")
    except Exception as e:
        # Deliberately broad: any failure of the primary source triggers the
        # local fallback rather than aborting.
        print(f"Failed to load dataset: {e}")
        local_path = Path(DATA_DIR) / f"{config.dataset_name}.json"
        if not local_path.exists():
            raise ValueError(f"Dataset not found: {config.dataset_name}") from e
        # Explicit UTF-8 so non-ASCII text decodes identically on every
        # platform instead of depending on the locale's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        dataset = Dataset.from_list(data)

    if config.max_samples > 0 and len(dataset) > config.max_samples:
        dataset = dataset.select(range(config.max_samples))

    def format_prompt(example):
        """Build the prompt text for one row."""
        if "instruction" in example:
            instruction = example["instruction"]
            input_text = example.get("input", "")
            if input_text:
                return {"prompt": f"Instruction: {instruction}\nInput: {input_text}\nResponse:"}
            return {"prompt": f"Instruction: {instruction}\nResponse:"}
        if "prompt" in example:
            return {"prompt": example["prompt"]}
        # No recognised field: fall back to the row's string representation.
        return {"prompt": str(example)}

    dataset = dataset.map(format_prompt)
    print(f"Dataset prepared: {len(dataset)} samples")
    return dataset


# ============================================================================
# Main
# ============================================================================

def main():
    """Parse the command line, build the config and run GRPO training."""
    # (flag, type, default) for every option; all are plain typed values.
    arg_specs = (
        ("--model_path", str, MODEL_PATH),
        ("--output_dir", str, OUTPUT_DIR),
        ("--dataset_name", str, DEFAULT_DATASET),
        ("--max_samples", int, 1000),
        ("--group_size", int, 8),
        ("--beta", float, 0.1),
        ("--clip_coef", float, 0.2),
        ("--learning_rate", float, 1e-6),
        ("--num_epochs", int, 1),
        ("--batch_size", int, 4),
        ("--deepspeed", str, None),
        ("--seed", int, 42),
    )

    parser = argparse.ArgumentParser(description="GRPO Training Script (Corrected)")
    for flag, arg_type, default in arg_specs:
        parser.add_argument(flag, type=arg_type, default=default)
    args = parser.parse_args()

    # Each option name is also a GRPOConfig field name, so the parsed
    # namespace maps one-to-one onto the config's keyword arguments.
    config = GRPOConfig(**vars(args))

    rule = "=" * 60
    banner = (
        rule,
        "GRPO Training Configuration (CORRECTED VERSION)",
        rule,
        f"Model: {config.model_path}",
        f"Dataset: {config.dataset_name}",
        f"Group size: {config.group_size}",
        f"Beta (KL): {config.beta}",
        f"Clip coef: {config.clip_coef}",
        rule,
    )
    for line in banner:
        print(line)

    dataset = prepare_dataset(config)
    trainer = GRPOTrainer(config)
    trainer.train(dataset)


if __name__ == "__main__":
    main()

四、DAPO:动态优势策略优化

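DAPO 在传统策略优化基础上引入动态调整机制,能够更稳定地进行强化学习训练。需要注意:下面的脚本是一个简化变体——按其文档字符串所述,它采用动态采样温度、提示增强和课程学习;这与 DAPO 原论文(Decoupled Clip and Dynamic sAmpling Policy Optimization)提出的 Clip-Higher、动态采样、token 级损失和超长奖励整形等技术并不等同。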

#!/usr/bin/env python3
"""
DAPO (Dynamic Actor Policy Optimization) Training Script
for Qwen3.5-0.8B Model

DAPO is an RLHF variant that combines:
- Dynamic sampling strategy (adjusting temperature based on reward variance)
- Data augmentation for improved robustness
- Optional dual-critic for better value estimation

Key Features:
1. Dynamic Temperature: Adjusts sampling temperature based on current model capability
2. Augmented Data: Uses prompt augmentation to improve training diversity
3. Progressive Learning: Gradually increases task difficulty

Usage:
Single GPU: python scripts/run_dapo.py
Multi GPU with DeepSpeed: deepspeed --num_gpus=2 scripts/run_dapo.py --deepspeed configs/ds_config_zero3.json

Reference:
- Related to Dynamic Sampling in RLHF
- Inspired by Curriculum Learning and Data Augmentation
"""

import os
import argparse
import json
import re
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)
from datasets import load_dataset, Dataset
from accelerate import Accelerator
from tqdm import tqdm


# ============================================================================
# Configuration
# ============================================================================

MODEL_PATH = "/workspace/models/Qwen3.5-0.8B"
OUTPUT_DIR = "/workspace/train_codes/output/dapo"
DATA_DIR = "/workspace/train_codes/data"

DEFAULT_DATASET = "tatsu-lab/alpaca"


@dataclass
class DAPOConfig:
    """Hyper-parameters and paths for a DAPO training run.

    Every field has a default, so ``DAPOConfig()`` is a complete
    configuration; the module-level constants above supply the default paths
    and dataset name. How each value is consumed is defined by the trainer
    code, not by this class.
    """

    # --- Paths and data -----------------------------------------------------
    model_path: str = MODEL_PATH
    output_dir: str = OUTPUT_DIR
    dataset_name: str = DEFAULT_DATASET
    max_samples: int = 1000
    max_prompt_length: int = 512
    max_response_length: int = 512

    # --- DAPO-specific parameters ---------------------------------------------
    # Number of samples per prompt
    num_samples_per_prompt: int = 4
    # Initial sampling temperature
    initial_temperature: float = 1.0
    # Minimum temperature
    min_temperature: float = 0.5
    # Maximum temperature
    max_temperature: float = 1.5
    # Temperature decay factor
    temperature_decay: float = 0.95
    # High reward threshold
    reward_threshold_high: float = 0.8
    # Low reward threshold
    reward_threshold_low: float = 0.3

    # --- PPO-style parameters -------------------------------------------------
    # KL divergence coefficient
    beta: float = 0.1
    # PPO clipping coefficient
    clip_coef: float = 0.2
    # Discount factor for value estimation
    gamma: float = 0.99

    # --- Data augmentation ----------------------------------------------------
    # Enable prompt augmentation
    use_augmentation: bool = True
    # Ratio of augmented data
    augmentation_ratio: float = 0.3

    # --- Optimisation ---------------------------------------------------------
    learning_rate: float = 1e-6
    # Learning rate for critic
    critic_learning_rate: float = 5e-6
    num_epochs: int = 1
    batch_size: int = 4
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01

    # --- Sampling -------------------------------------------------------------
    top_p: float = 0.95

    # --- DeepSpeed ------------------------------------------------------------
    deepspeed: Optional[str] = None

    # --- Miscellaneous --------------------------------------------------------
    seed: int = 42
    logging_steps: int = 10
    save_steps: int = 100

    # --- Curriculum learning --------------------------------------------------
    # Enable curriculum learning
    use_curriculum: bool = True
    # Start with easier samples
    curriculum_start_ratio: float = 0.3


# ============================================================================
# Data Augmentation Functions
# ============================================================================

class PromptAugmenter:
    """
    Produce simple rule-based variants of a prompt to diversify training data.

    Strategies implemented:
    1. Context addition: short prompts get a copy prefixed with a fixed context line
    2. Question reformulation: prompts containing "?" are wrapped in paraphrase templates
    """

    # Prompts shorter than this many characters receive the context-prefixed variant.
    _SHORT_PROMPT_CHARS = 100
    # Only the first N paraphrase templates are applied to a question.
    _MAX_PARAPHRASES = 2

    def __init__(self):
        # Each template receives the question text with all "?" removed.
        self.paraphrase_templates = [
            lambda x: f"Can you explain {x}?",
            lambda x: f"Please describe {x} in detail.",
            lambda x: f"What do you know about {x}?",
            lambda x: f"I'd like to understand {x}.",
        ]

    def augment(self, prompt: str) -> List[str]:
        """
        Generate augmented versions of the prompt.

        Args:
            prompt: Original prompt

        Returns:
            augmented_prompts: The original prompt first, followed by any variants
        """
        augmented = [prompt]  # Always include original

        if len(prompt) < self._SHORT_PROMPT_CHARS:
            augmented.append(f"Context: This is a learning exercise.\n{prompt}")

        if "?" in prompt:
            # The core text does not depend on the template, so compute it once.
            core = prompt.replace("?", "").strip()
            # A prompt made only of "?" has nothing to paraphrase; skip it rather
            # than emit text such as "Can you explain ?".
            if core:
                augmented.extend(
                    template(core)
                    for template in self.paraphrase_templates[: self._MAX_PARAPHRASES]
                )

        return augmented


# ============================================================================
# Dynamic Temperature Controller
# ============================================================================

class DynamicTemperatureController:
    """
    Dynamically adjust sampling temperature based on reward feedback.

    Strategy (driven by the mean of a sliding window of recent rewards):
    - Mean reward above ``reward_threshold_high`` → multiply the temperature by
      ``temperature_decay`` (more exploitation), never below ``min_temperature``
    - Mean reward below ``reward_threshold_low`` → divide the temperature by
      ``temperature_decay`` (more exploration), never above ``max_temperature``
    - Otherwise the temperature is left unchanged
    """

    # Number of most recent rewards kept for the running mean.
    _HISTORY_WINDOW = 100
    # The temperature is only adjusted once the window holds more than this many rewards.
    _MIN_HISTORY = 10

    def __init__(self, config: "DAPOConfig"):
        self.config = config
        self.current_temperature = config.initial_temperature
        self.reward_history = []

    def update(self, rewards: List[float]) -> float:
        """
        Update temperature based on recent rewards.

        Args:
            rewards: Rewards observed since the previous call

        Returns:
            temperature: Updated sampling temperature
        """
        self.reward_history.extend(rewards)

        # Keep only recent history
        if len(self.reward_history) > self._HISTORY_WINDOW:
            self.reward_history = self.reward_history[-self._HISTORY_WINDOW:]

        if len(self.reward_history) > self._MIN_HISTORY:
            mean_reward = np.mean(self.reward_history)

            if mean_reward > self.config.reward_threshold_high:
                # Model is doing well, reduce temperature for more exploitation
                self.current_temperature = max(
                    self.config.min_temperature,
                    self.current_temperature * self.config.temperature_decay,
                )
            elif mean_reward < self.config.reward_threshold_low:
                # Model needs improvement, increase temperature for exploration
                self.current_temperature = min(
                    self.config.max_temperature,
                    self.current_temperature / self.config.temperature_decay,
                )

        return self.current_temperature

    def get_temperature(self) -> float:
        """Get current temperature."""
        return self.current_temperature


# ============================================================================
# DAPO Core Algorithm
# ============================================================================

def compute_dapo_advantage(
    rewards: torch.Tensor,
    values: Optional[torch.Tensor] = None,
    gamma: float = 0.99,
    use_gae: bool = True,
    gae_lambda: float = 0.95,
) -> torch.Tensor:
    """
    Compute normalized advantages from rewards.

    If ``values`` is provided and ``use_gae`` is true, Generalized Advantage
    Estimation is run over the 1-D sequence, bootstrapping with a value of 0
    after the last element. Otherwise the mean reward is used as the baseline.
    In both cases the result is standardized (zero mean, unit std) when it has
    more than one element and a non-zero spread.

    Args:
        rewards: 1-D tensor of rewards
        values: Optional 1-D tensor of value estimates aligned with ``rewards``
        gamma: Discount factor (GAE branch only)
        use_gae: Whether to use Generalized Advantage Estimation
        gae_lambda: GAE lambda parameter

    Returns:
        advantages: Tensor of advantages with the same length as ``rewards``.
            The GAE branch returns a detached float32 tensor on ``rewards``' device.
    """
    if values is not None and use_gae:
        # Advantages are treated as constants, so cut any autograd history.
        rewards_d = rewards.detach()
        values_d = values.detach()
        count = len(rewards_d)

        # Allocate on the rewards' device so the result can be combined with
        # log-probs living on an accelerator.
        advantages = torch.zeros(count, dtype=torch.float32, device=rewards_d.device)
        gae = 0.0
        next_value = 0.0  # value after the final element is taken to be 0

        # Walk backwards, writing in place (avoids O(n^2) front-insertions).
        for i in reversed(range(count)):
            delta = rewards_d[i] + gamma * next_value - values_d[i]
            gae = delta + gamma * gae_lambda * gae
            advantages[i] = gae
            next_value = values_d[i]
    else:
        # Use mean reward as baseline
        advantages = rewards - rewards.mean()

    # Normalize advantages. std() of a single element is NaN, so guard on size
    # explicitly instead of relying on the NaN comparison being False.
    if advantages.numel() > 1 and advantages.std() > 0:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    return advantages


def compute_dapo_loss(
    policy_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    beta: float = 0.1,
    clip_coef: float = 0.2,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the clipped policy loss with a KL penalty towards the reference model.

    Loss = -mean(min(ratio * A, clip(ratio) * A)) + beta * KL - 0.01 * entropy

    The KL term uses the non-negative "k3" estimator
    ``exp(log_ref - log_pi) - (log_ref - log_pi) - 1``, which is zero exactly
    when the policy and reference log-probs agree. The previous
    ``(log_ref - log_pi).mean()`` had the wrong sign and could go negative, so
    it did not pull the policy towards the reference.

    Args:
        policy_log_probs: Current policy log probs (carries gradients)
        old_log_probs: Old policy log probs (from sampling)
        ref_log_probs: Reference model log probs
        advantages: Advantage estimates
        beta: KL coefficient
        clip_coef: Clipping coefficient

    Returns:
        loss: Total loss
        metrics: Dictionary of loss components as Python floats
    """
    # Importance ratio between the current policy and the sampling policy
    ratio = torch.exp(policy_log_probs - old_log_probs)

    # Policy loss with clipping
    clipped_ratio = torch.clamp(ratio, 1 - clip_coef, 1 + clip_coef)
    policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

    # KL divergence penalty (k3 estimator, always >= 0)
    ref_log_ratio = ref_log_probs - policy_log_probs
    kl_estimate = torch.exp(ref_log_ratio) - ref_log_ratio - 1
    kl_loss = beta * kl_estimate.mean()

    # Entropy-style bonus, kept exactly as before.
    # NOTE(review): this is computed from whatever log-probs the caller passes in
    # (per-sequence sums in DAPOTrainer), not from a per-token distribution, so it
    # is a heuristic rather than a true policy entropy.
    entropy = -(policy_log_probs.exp() * policy_log_probs).sum(dim=-1).mean()

    # Total loss
    total_loss = policy_loss + kl_loss - 0.01 * entropy  # Small entropy bonus

    metrics = {
        "total_loss": total_loss.item(),
        "policy_loss": policy_loss.item(),
        "kl_loss": kl_loss.item(),
        "entropy": entropy.item(),
        "ratio_mean": ratio.mean().item(),
        # std() of a single element is NaN, report 0 instead
        "ratio_std": ratio.std().item() if ratio.numel() > 1 else 0,
    }

    return total_loss, metrics


# ============================================================================
# Reward Function
# ============================================================================

class DAPORewardFunction:
    """
    Heuristic, rule-based reward used during DAPO training.

    The score is the sum of four independent signals:
    1. Format: the response ends with sentence-final punctuation
    2. Length: moderate lengths are rewarded, very long ones penalized
    3. Coherence: a high share of distinct words is rewarded, heavy repetition penalized
    4. Progress: an optional bonus supplied by a progress tracker
    """

    def __init__(self, progress_tracker=None):
        # Any object exposing get_progress_reward(prompt, iteration); None disables the bonus.
        self.progress_tracker = progress_tracker

    def compute(self, prompt: str, response: str, iteration: int = 0) -> float:
        """
        Score one prompt/response pair.

        Args:
            prompt: Input prompt (only forwarded to the progress tracker)
            response: Generated response
            iteration: Current training iteration (only forwarded to the progress tracker)

        Returns:
            reward: Scalar reward value
        """
        score = 0.0

        # Format: does the trimmed response close with ".", "?" or "!"?
        trimmed = response.strip()
        if trimmed.endswith(('.', '?', '!')):
            score += 0.2

        # Length: measured in characters of the untrimmed response.
        n_chars = len(response)
        if n_chars > 300:
            score -= 0.1
        elif n_chars >= 20:
            score += 0.2
        elif n_chars >= 10:
            score += 0.1

        # Coherence: fraction of whitespace-separated tokens that are distinct.
        tokens = response.split()
        if tokens:
            distinct_share = len(set(tokens)) / len(tokens)
            if distinct_share > 0.7:
                score += 0.2
            elif distinct_share < 0.3:
                score -= 0.2

        # Progress: scaled bonus from the tracker, when one was supplied.
        tracker = self.progress_tracker
        if tracker is not None:
            bonus = tracker.get_progress_reward(prompt, iteration)
            score += bonus * 0.2

        return score


class ProgressTracker:
    """
    Track per-prompt reward history and turn recent improvement into a bonus.
    """

    def __init__(self):
        self.prompt_rewards = {}  # prompt -> list of rewards, oldest first
        self.iteration = 0  # total number of update() calls

    def update(self, prompt: str, reward: float):
        """Append ``reward`` to the history of ``prompt`` and bump the call counter."""
        self.prompt_rewards.setdefault(prompt, []).append(reward)
        self.iteration += 1

    def get_progress_reward(self, prompt: str, iteration: int) -> float:
        """
        Get progress reward based on improvement history.

        Args:
            prompt: Input prompt
            iteration: Current iteration (accepted for interface compatibility; unused)

        Returns:
            progress_reward: The increase between the two most recent rewards
                recorded for ``prompt``, or 0.0 when there are fewer than two
                rewards or the latest one is not an improvement. Always a float.
        """
        history = self.prompt_rewards.get(prompt, ())
        if len(history) < 2:
            return 0.0

        # Only reward positive improvement; 0.0 keeps the return type a float.
        return max(0.0, float(history[-1] - history[-2]))


# ============================================================================
# DAPO Trainer Class
# ============================================================================

class DAPOTrainer:
"""
DAPO Trainer implementing Dynamic Actor Policy Optimization.
"""

def __init__(self, config: DAPOConfig):
    """
    Build everything the training loop needs.

    Seeds torch and numpy, creates the helper components (augmenter,
    temperature controller, progress tracker, reward function), loads a
    trainable policy model and a frozen reference copy from the same
    checkpoint, loads the tokenizer, and wires the policy and optimizer
    through Accelerate.

    Args:
        config: Training configuration
    """
    self.config = config

    # Reproducibility: torch drives sampling, numpy drives the augmentation draws.
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Helper components
    if config.use_augmentation:
        self.augmenter = PromptAugmenter()
    else:
        self.augmenter = None
    self.temperature_controller = DynamicTemperatureController(config)
    self.progress_tracker = ProgressTracker()
    self.reward_function = DAPORewardFunction(self.progress_tracker)

    # Policy and reference start from the same checkpoint with identical load options.
    checkpoint_kwargs = {
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }

    print("Loading policy model...")
    self.policy_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, **checkpoint_kwargs
    )

    print("Loading reference model...")
    self.ref_model = AutoModelForCausalLM.from_pretrained(
        config.model_path, **checkpoint_kwargs
    )
    # The reference is never trained: eval mode and no gradients.
    self.ref_model.eval()
    for frozen_weight in self.ref_model.parameters():
        frozen_weight.requires_grad = False

    # No separate value network (critic) is built; advantages use a mean baseline.

    # Tokenizer; fall back to EOS as the padding token when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(
        config.model_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    self.tokenizer = tokenizer

    self.accelerator = Accelerator()

    optimizer = AdamW(
        self.policy_model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    # Only the policy and its optimizer are handed to Accelerate; the frozen
    # reference is simply moved to the same device.
    self.policy_model, self.optimizer = self.accelerator.prepare(
        self.policy_model, optimizer
    )
    self.ref_model.to(self.accelerator.device)

    # Training state
    self.global_step = 0
    self.best_reward = -float('inf')

def generate_samples(
    self,
    prompts: List[str],
    temperature: float,
) -> Tuple[List[str], List[str], torch.Tensor]:
    """
    Sample responses from the policy and record their log-probabilities.

    Each prompt may first be replaced (with probability
    ``config.augmentation_ratio``) by one of its augmented variants, then
    ``config.num_samples_per_prompt`` responses are sampled for it.

    Args:
        prompts: List of input prompts
        temperature: Sampling temperature

    Returns:
        prompts_expanded: The prompt actually used, repeated once per sample
        responses: Generated responses (decoded text)
        old_log_probs: 1-D tensor with the summed response log-probability of
            each sample under the policy at generation time (empty when
            ``prompts`` is empty)
    """
    prompts_expanded: List[str] = []
    responses: List[str] = []
    old_log_probs: List[torch.Tensor] = []
    device = self.accelerator.device

    generation_config = GenerationConfig(
        max_new_tokens=self.config.max_response_length,
        temperature=temperature,
        top_p=self.config.top_p,
        do_sample=True,
        pad_token_id=self.tokenizer.pad_token_id,
    )

    # Distributed wrappers returned by accelerator.prepare() do not expose
    # .generate(); call it on the underlying model (save_checkpoint unwraps too).
    generator = self.accelerator.unwrap_model(self.policy_model)

    for prompt in prompts:
        # Optionally augment prompt
        if self.augmenter and np.random.random() < self.config.augmentation_ratio:
            # np.random.choice returns np.str_; convert back to a plain str.
            prompt = str(np.random.choice(self.augmenter.augment(prompt)))

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.config.max_prompt_length,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[1]

        for _ in range(self.config.num_samples_per_prompt):
            with torch.no_grad():
                outputs = generator.generate(
                    **inputs,
                    generation_config=generation_config,
                )
                response = self.tokenizer.decode(
                    outputs[0][prompt_length:],
                    skip_special_tokens=True,
                )

                # Score the re-tokenized text so the result is computed the same
                # way as compute_log_probs() does for the policy/reference.
                full_inputs = self.tokenizer(prompt + response, return_tensors="pt")
                full_inputs = {k: v.to(device) for k, v in full_inputs.items()}
                logits = self.policy_model(**full_inputs).logits

                # Logits at position t predict token t + 1.
                response_logits = logits[0, prompt_length - 1:-1, :]
                response_tokens = full_inputs["input_ids"][0, prompt_length:]

                token_log_probs = (
                    F.log_softmax(response_logits, dim=-1)
                    .gather(-1, response_tokens.unsqueeze(-1))
                    .squeeze(-1)
                )
                total_log_prob = token_log_probs.sum()

            prompts_expanded.append(prompt)
            responses.append(response)
            old_log_probs.append(total_log_prob)

    if not old_log_probs:
        # torch.stack() rejects an empty list; return an empty tensor instead.
        return prompts_expanded, responses, torch.empty(0, device=device)

    return prompts_expanded, responses, torch.stack(old_log_probs)

def compute_log_probs(
    self,
    model: nn.Module,
    prompts: List[str],
    responses: List[str],
) -> torch.Tensor:
    """
    Score each response under ``model``.

    Every pair is processed on its own: prompt and response are concatenated,
    tokenized, and the log-probabilities of the tokens after the prompt are
    summed. Gradients are disabled when ``model`` is the reference model and
    explicitly enabled otherwise.

    Returns:
        1-D tensor with one summed log-probability per pair.
    """
    device = self.accelerator.device
    needs_grad = not (model is self.ref_model)
    sequence_scores = []

    for prompt_text, response_text in zip(prompts, responses):
        encoded = self.tokenizer(
            prompt_text + response_text, return_tensors="pt", truncation=True
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Number of prompt tokens = index of the first response token.
        prompt_only = self.tokenizer(prompt_text, return_tensors="pt", truncation=True)
        boundary = prompt_only["input_ids"].shape[1]

        with torch.set_grad_enabled(needs_grad):
            logits = model(**encoded).logits

            # Position t of the logits predicts token t + 1, hence the shift by one.
            predicting_logits = logits[0, boundary - 1:-1, :]
            target_tokens = encoded["input_ids"][0, boundary:]

            per_token = (
                F.log_softmax(predicting_logits, dim=-1)
                .gather(-1, target_tokens.unsqueeze(-1))
                .squeeze(-1)
            )
            sequence_scores.append(per_token.sum())

    return torch.stack(sequence_scores)

def compute_rewards(
    self,
    prompts: List[str],
    responses: List[str],
) -> torch.Tensor:
    """
    Score every prompt/response pair and record the scores in the progress tracker.

    Pairs are handled strictly in order: the reward for a pair is computed
    first and only then pushed into the tracker, so a pair's own reward never
    influences its progress bonus.

    Returns:
        float32 tensor of rewards on the accelerator device.
    """
    step = self.global_step
    scores = []

    for pair in zip(prompts, responses):
        prompt_text = pair[0]
        score = self.reward_function.compute(*pair, step)
        self.progress_tracker.update(prompt_text, score)
        scores.append(score)

    return torch.tensor(scores, dtype=torch.float32, device=self.accelerator.device)

def train_step(self, prompts: List[str]) -> Dict[str, float]:
    """
    Run one full DAPO update on a batch of prompts.

    Pipeline: sample responses at the current temperature, score them, feed
    the rewards back into the temperature controller, turn rewards into
    advantages (mean baseline, no critic), evaluate the loss against the
    sampling-time and reference log-probs, then take a single optimizer step.

    Args:
        prompts: List of input prompts

    Returns:
        metrics: Loss components plus ``mean_reward``, ``std_reward`` and the
            temperature returned by the controller after this batch
    """
    device = self.accelerator.device

    # 1. Roll out with the temperature chosen before seeing this batch.
    rollout_temperature = self.temperature_controller.get_temperature()
    sampled_prompts, sampled_responses, behaviour_log_probs = self.generate_samples(
        prompts, rollout_temperature
    )

    # 2. Score the rollouts and let the controller react to them.
    rewards = self.compute_rewards(sampled_prompts, sampled_responses)
    updated_temperature = self.temperature_controller.update(rewards.tolist())

    # 3. Advantages; values=None selects the mean-reward baseline.
    advantages = compute_dapo_advantage(
        rewards,
        values=None,
        gamma=self.config.gamma,
    )

    # 4. Log-probs under the trainable policy (with grad) and the frozen reference.
    current_log_probs = self.compute_log_probs(
        self.policy_model, sampled_prompts, sampled_responses
    )
    reference_log_probs = self.compute_log_probs(
        self.ref_model, sampled_prompts, sampled_responses
    )

    # 5. Loss
    loss, metrics = compute_dapo_loss(
        current_log_probs,
        behaviour_log_probs.to(device),
        reference_log_probs,
        advantages,
        beta=self.config.beta,
        clip_coef=self.config.clip_coef,
    )

    # std() of a single reward is undefined, report 0 in that case.
    reward_spread = rewards.std().item() if rewards.numel() > 1 else 0
    metrics.update(
        mean_reward=rewards.mean().item(),
        std_reward=reward_spread,
        temperature=updated_temperature,
    )

    # 6. Backward and optimize
    self.accelerator.backward(loss)
    self.optimizer.step()
    self.optimizer.zero_grad()

    # 7. Bookkeeping
    if metrics["mean_reward"] > self.best_reward:
        self.best_reward = metrics["mean_reward"]
    self.global_step += 1

    return metrics

def train(self, dataset: Dataset):
    """
    Main training loop.

    Iterates over the dataset for ``config.num_epochs`` epochs, runs one
    ``train_step`` per batch, prints metrics every ``config.logging_steps``
    steps, saves a checkpoint every ``config.save_steps`` steps and saves the
    final model at the end.

    Args:
        dataset: Training dataset. Prompts are read from the first column
            present among "prompt", "instruction", "text".

    Raises:
        KeyError: If a batch contains none of the prompt columns.
    """
    print(f"Starting DAPO training on {len(dataset)} samples...")
    print(f"Dynamic temperature: {self.config.initial_temperature}")
    print(f"Data augmentation: {self.config.use_augmentation}")

    # Create dataloader
    dataloader = DataLoader(
        dataset,
        batch_size=self.config.batch_size,
        shuffle=True,
    )

    num_batches = len(dataloader) * self.config.num_epochs

    # "prompt" comes first: prepare_dataset() writes the fully formatted prompt
    # (instruction + optional input + "Response:") into that column while
    # keeping the raw "instruction" column, which must not shadow it.
    prompt_columns = ("prompt", "instruction", "text")

    for epoch in range(self.config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

        for batch in tqdm(dataloader, desc="Training"):
            column = next((name for name in prompt_columns if name in batch), None)
            if column is None:
                raise KeyError(
                    f"Batch has none of the prompt columns {prompt_columns}; "
                    f"available: {list(batch)}"
                )
            prompts = batch[column]

            if isinstance(prompts, torch.Tensor):
                prompts = prompts.tolist()

            metrics = self.train_step(prompts)

            # Logging
            if self.global_step % self.config.logging_steps == 0:
                print(
                    f"Step {self.global_step}/{num_batches}: "
                    f"loss={metrics['total_loss']:.4f}, "
                    f"reward={metrics['mean_reward']:.4f}±{metrics['std_reward']:.4f}, "
                    f"temp={metrics['temperature']:.3f}, "
                    f"kl={metrics['kl_loss']:.4f}"
                )

            # Save checkpoint
            if self.global_step % self.config.save_steps == 0:
                self.save_checkpoint(self.global_step)

    # Save final model
    self.save_checkpoint(self.global_step, final=True)
    print(f"Training complete! Model saved to: {self.config.output_dir}")
    print(f"Best reward achieved: {self.best_reward:.4f}")

def save_checkpoint(self, step: int, final: bool = False):
    """
    Write the policy model, tokenizer and a small training-state file to disk.

    Args:
        step: Step number used to name the checkpoint directory
        final: When true, write straight into ``config.output_dir`` instead of
            a ``checkpoint-<step>`` subdirectory
    """
    if final:
        target_dir = self.config.output_dir
    else:
        target_dir = os.path.join(self.config.output_dir, f"checkpoint-{step}")

    os.makedirs(target_dir, exist_ok=True)

    # Save the underlying model rather than any wrapper added by Accelerate.
    self.accelerator.unwrap_model(self.policy_model).save_pretrained(target_dir)
    self.tokenizer.save_pretrained(target_dir)

    # Lightweight training state alongside the weights
    state_path = os.path.join(target_dir, "training_state.json")
    with open(state_path, "w") as handle:
        json.dump(
            {
                "global_step": self.global_step,
                "best_reward": self.best_reward,
                "temperature": self.temperature_controller.get_temperature(),
            },
            handle,
        )

    print(f"Checkpoint saved to: {target_dir}")


# ============================================================================
# Dataset Preparation
# ============================================================================

def prepare_dataset(config: DAPOConfig) -> Dataset:
    """
    Load and prepare dataset for DAPO training.

    The dataset is loaded with ``load_dataset``; if that fails for any reason
    a local ``<DATA_DIR>/<dataset_name>.json`` file is tried instead. The
    result is cut to ``config.max_samples`` rows (when positive) and a
    ``prompt`` column is added to every row.

    Args:
        config: Training configuration (uses ``dataset_name`` and ``max_samples``)

    Returns:
        Dataset with an added "prompt" column.

    Raises:
        ValueError: If the dataset can be loaded neither remotely nor from the
            local JSON file; chained to the original loading error.
    """
    print(f"Loading dataset: {config.dataset_name}")

    try:
        dataset = load_dataset(config.dataset_name, split="train")
    except Exception as e:
        # Deliberately broad: any loading failure falls back to the local copy.
        print(f"Failed to load dataset: {e}")
        local_path = Path(DATA_DIR) / f"{config.dataset_name}.json"
        if not local_path.exists():
            # Keep the original failure attached so the root cause is visible.
            raise ValueError(f"Dataset not found: {config.dataset_name}") from e
        # JSON is UTF-8 by specification; do not depend on the platform default.
        with open(local_path, "r", encoding="utf-8") as f:
            dataset = Dataset.from_list(json.load(f))

    if config.max_samples > 0 and len(dataset) > config.max_samples:
        dataset = dataset.select(range(config.max_samples))

    def format_prompt(example):
        """Build the "prompt" column for one row."""
        if "instruction" in example:
            instruction = example["instruction"]
            input_text = example.get("input", "")
            if input_text:
                return {"prompt": f"Instruction: {instruction}\nInput: {input_text}\nResponse:"}
            return {"prompt": f"Instruction: {instruction}\nResponse:"}
        if "prompt" in example:
            return {"prompt": example["prompt"]}
        # Unknown schema: fall back to the row's string representation.
        return {"prompt": str(example)}

    dataset = dataset.map(format_prompt)

    print(f"Dataset prepared: {len(dataset)} samples")
    return dataset


# ============================================================================
# Main Entry Point
# ============================================================================

def _parse_bool_flag(value: str) -> bool:
    """Convert a command-line string such as "true", "0" or "no" to a bool."""
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays False, as it was under the previous bool() conversion.
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """
    Command-line entry point: parse flags, build the config, prepare the
    dataset and run DAPO training.
    """
    parser = argparse.ArgumentParser(description="DAPO Training Script")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR)
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATASET)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--num_samples_per_prompt", type=int, default=4)
    parser.add_argument("--initial_temperature", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    # type=bool would treat every non-empty string (including "False") as True,
    # so the flag is parsed explicitly.
    parser.add_argument("--use_augmentation", type=_parse_bool_flag, default=True)
    parser.add_argument("--deepspeed", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    config = DAPOConfig(
        model_path=args.model_path,
        output_dir=args.output_dir,
        dataset_name=args.dataset_name,
        max_samples=args.max_samples,
        num_samples_per_prompt=args.num_samples_per_prompt,
        initial_temperature=args.initial_temperature,
        beta=args.beta,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_augmentation=args.use_augmentation,
        deepspeed=args.deepspeed,
        seed=args.seed,
    )

    print("=" * 60)
    print("DAPO Training Configuration")
    print("=" * 60)
    print(f"Model: {config.model_path}")
    print(f"Output: {config.output_dir}")
    print(f"Dataset: {config.dataset_name}")
    print(f"Samples per prompt: {config.num_samples_per_prompt}")
    print(f"Initial temperature: {config.initial_temperature}")
    print(f"Data augmentation: {config.use_augmentation}")
    print(f"Beta (KL coefficient): {config.beta}")
    print("=" * 60)

    dataset = prepare_dataset(config)

    trainer = DAPOTrainer(config)
    trainer.train(dataset)


# Run DAPO training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()

五、On-Policy Distillation:在线策略蒸馏

在线策略蒸馏将知识蒸馏与策略优化相结合,在蒸馏过程中保持策略的在线更新,实现高效的知识迁移。

#!/usr/bin/env python3
"""
On-Policy Distillation Training Script
for Qwen3.5-0.8B Model

On-Policy Distillation differs from traditional distillation:
- Traditional: Static teacher outputs from pre-collected data
- On-Policy: Student generates samples, Teacher provides dynamic guidance

Key advantages:
1. Student learns from distribution it actually encounters
2. No distribution mismatch between training and deployment
3. Adaptive curriculum - harder examples get more attention

Algorithm:
1. Student generates responses for prompts
2. Teacher generates reference responses
3. Minimize KL(Student || Teacher) on student's generated distribution
4. Optionally: Use reward model to select high-quality teacher samples

Reference:
- Policy Distillation (Rusu et al., 2015): https://arxiv.org/abs/1511.06295
- Knowledge Distillation Survey: https://arxiv.org/abs/2006.05525
"""

import os
import argparse
import json
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
)
from datasets import load_dataset, Dataset
from accelerate import Accelerator
from tqdm import tqdm


# ============================================================================
# Configuration
# ============================================================================

# Defaults for the on-policy distillation script; they feed the default values of
# OnPolicyDistillConfig.
STUDENT_MODEL_PATH = "/workspace/models/Qwen3.5-0.8B"
# Defaults to the same checkpoint as the student.
TEACHER_MODEL_PATH = "/workspace/models/Qwen3.5-0.8B" # Can be different (larger) model
OUTPUT_DIR = "/workspace/train_codes/output/on_policy_distill"
DATA_DIR = "/workspace/train_codes/data"
DEFAULT_DATASET = "tatsu-lab/alpaca"


@dataclass
class OnPolicyDistillConfig:
    """
    Configuration for On-Policy Distillation training.

    Every field has a default, so ``OnPolicyDistillConfig()`` is a usable
    configuration.
    """
    student_model_path: str = STUDENT_MODEL_PATH  # trainable model; its tokenizer is the one used
    teacher_model_path: str = TEACHER_MODEL_PATH  # loaded frozen (eval mode, no gradients)
    output_dir: str = OUTPUT_DIR
    dataset_name: str = DEFAULT_DATASET
    max_samples: int = 1000
    max_prompt_length: int = 512  # token limit when tokenizing a prompt for generation
    max_response_length: int = 512  # passed to generation as max_new_tokens

    # Distillation parameters
    # KD temperature for soft labels; the trainer also passes it as the sampling
    # temperature of the teacher's generation config.
    temperature: float = 2.0
    alpha: float = 0.5 # Weight for KL loss vs hard label loss
    beta: float = 0.1 # KL coefficient
    use_hard_labels: bool = True # Also use ground truth if available
    teacher_samples_per_prompt: int = 1 # Number of teacher samples

    # Training parameters
    learning_rate: float = 1e-5  # AdamW learning rate for the student
    num_epochs: int = 1
    batch_size: int = 4
    gradient_accumulation_steps: int = 4
    weight_decay: float = 0.01  # AdamW weight decay
    max_grad_norm: float = 1.0

    # Generation parameters
    student_temperature: float = 1.0 # Temperature for student generation
    top_p: float = 0.95  # nucleus sampling threshold for both student and teacher

    # Other
    deepspeed: Optional[str] = None
    seed: int = 42  # seeds torch and numpy
    logging_steps: int = 10
    save_steps: int = 100


# ============================================================================
# On-Policy Distillation Core Algorithm
# ============================================================================

def compute_per_token_log_probs(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Average log-probability of the response tokens of each sequence.

    The model is run once on the full batch; for every position the
    log-probability assigned to the actual next token is looked up, positions
    outside the response are zeroed via ``response_mask``, and the remainder
    is averaged per sequence.

    Args:
        model: Language model
        input_ids: Full sequence [batch, seq_len]
        attention_mask: Attention mask [batch, seq_len]
        response_mask: Mask indicating response tokens [batch, seq_len]

    Returns:
        log_probs: Mean response-token log-probability per sample [batch].
            Sequences without response tokens yield 0 (the divisor is
            clamped to 1).
    """
    # Logits at position t predict token t + 1: drop the last logit and the first token.
    next_token_logits = model(
        input_ids=input_ids, attention_mask=attention_mask
    ).logits[:, :-1, :]
    next_tokens = input_ids[:, 1:]

    picked = (
        F.log_softmax(next_token_logits, dim=-1)
        .gather(-1, next_tokens.unsqueeze(-1))
        .squeeze(-1)
    )

    # Align the mask with the shifted targets.
    target_mask = response_mask[:, 1:]
    token_counts = target_mask.sum(dim=-1).clamp(min=1)

    return (picked * target_mask).sum(dim=-1) / token_counts


def compute_soft_log_probs(
    model: nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_mask: torch.Tensor,
    temperature: float = 2.0,
) -> torch.Tensor:
    """
    Temperature-softened average log-probability of the response tokens.

    Same computation as the unscaled variant except that the logits are
    divided by ``temperature`` before the softmax. A higher temperature gives
    a flatter distribution, as used for knowledge-distillation soft targets.

    Args:
        model: Language model
        input_ids: Full sequence [batch, seq_len]
        attention_mask: Attention mask [batch, seq_len]
        response_mask: Mask indicating response tokens [batch, seq_len]
        temperature: Divisor applied to the logits

    Returns:
        Mean softened response-token log-probability per sample [batch]
        (0 for sequences without response tokens).
    """
    raw_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Position t predicts token t + 1; soften before normalizing.
    softened = raw_logits[:, :-1, :] / temperature

    next_tokens = input_ids[:, 1:]
    picked = (
        F.log_softmax(softened, dim=-1)
        .gather(-1, next_tokens.unsqueeze(-1))
        .squeeze(-1)
    )

    target_mask = response_mask[:, 1:]
    token_counts = target_mask.sum(dim=-1).clamp(min=1)
    masked_total = (picked * target_mask).sum(dim=-1)

    return masked_total / token_counts


def compute_kl_divergence_distill(
    student_log_probs: torch.Tensor,
    teacher_log_probs: torch.Tensor,
    temperature: float = 2.0,
) -> torch.Tensor:
    """
    Per-sample estimate of KL(Student || Teacher) for distillation.

    The inputs are log-probabilities of student-generated samples, so the
    divergence is estimated from samples rather than summed over the
    vocabulary. With ``d = teacher_log_probs - student_log_probs`` the
    estimator is

        KL ≈ exp(d) - d - 1

    which is never negative and is zero exactly when student and teacher
    agree, so minimizing it moves the student towards the teacher. (The
    previous ``exp(-d) * (-d)`` form becomes negative whenever the student is
    below the teacher and has its minimum at ``student = teacher - 1``.)

    The result is multiplied by ``temperature ** 2`` to keep the gradient
    magnitude comparable across temperatures (Hinton et al., 2015).

    Args:
        student_log_probs: Student log-probabilities (with grad)
        teacher_log_probs: Teacher log-probabilities for the same samples
        temperature: KD temperature used for the T^2 scaling

    Returns:
        Tensor of per-sample KL estimates, same shape as the inputs.
    """
    log_ratio = teacher_log_probs - student_log_probs
    kl = torch.exp(log_ratio) - log_ratio - 1.0
    return temperature ** 2 * kl


def compute_distillation_loss(
    student_log_probs: torch.Tensor,
    teacher_soft_log_probs: torch.Tensor,
    hard_label_log_probs: Optional[torch.Tensor] = None,
    alpha: float = 0.5,
    temperature: float = 2.0,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Combine the soft (teacher) and hard (ground-truth) distillation objectives.

    Loss = alpha * mean(KL(Student || Teacher)) + (1 - alpha) * hard_loss

    where ``hard_loss`` is the negative mean of ``hard_label_log_probs``, or 0
    when no hard labels are given.

    Args:
        student_log_probs: Student's log probabilities (with grad)
        teacher_soft_log_probs: Teacher's soft log probs for the same samples
        hard_label_log_probs: Optional ground truth log probs
        alpha: Weight for soft loss vs hard loss
        temperature: KD temperature, forwarded to the KL computation

    Returns:
        loss: Total distillation loss
        metrics: Loss components as Python floats
    """
    # Soft target term
    per_sample_kl = compute_kl_divergence_distill(
        student_log_probs,
        teacher_soft_log_probs,
        temperature,
    )
    soft_loss = per_sample_kl.mean()

    # Hard target term: negative log-likelihood of the ground truth, if supplied
    if hard_label_log_probs is None:
        hard_loss = torch.tensor(0.0)
    else:
        hard_loss = -hard_label_log_probs.mean()

    total_loss = alpha * soft_loss + (1 - alpha) * hard_loss

    metrics = {
        "total_loss": total_loss.item(),
        "soft_loss": soft_loss.item(),
        "hard_loss": hard_loss.item(),
        "kl_divergence": per_sample_kl.mean().item(),
    }

    return total_loss, metrics


# ============================================================================
# On-Policy Distillation Trainer
# ============================================================================

class OnPolicyDistillTrainer:
"""
On-Policy Distillation Trainer.

Student learns from Teacher on the distribution Student actually generates.
This avoids distribution mismatch in traditional distillation.
"""

def __init__(self, config: OnPolicyDistillConfig):
    """
    Build everything the distillation loop needs.

    Creates the accelerator, seeds torch and numpy, loads a trainable student
    and a frozen teacher, loads the student's tokenizer, prepares separate
    generation configs for student and teacher, and wires the student and its
    optimizer through Accelerate.

    Args:
        config: Distillation configuration
    """
    self.config = config
    self.accelerator = Accelerator()

    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # Both checkpoints are loaded with identical options.
    checkpoint_kwargs = {
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }

    # Student: the model being trained
    print("Loading student model...")
    self.student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_path, **checkpoint_kwargs
    )

    # Teacher: frozen source of guidance (eval mode, no gradients)
    print("Loading teacher model...")
    self.teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path, **checkpoint_kwargs
    )
    self.teacher_model.eval()
    for frozen_weight in self.teacher_model.parameters():
        frozen_weight.requires_grad = False

    # The student's tokenizer is used for both models; fall back to EOS for padding.
    tokenizer = AutoTokenizer.from_pretrained(
        config.student_model_path, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    self.tokenizer = tokenizer

    # Student and teacher generation differ only in sampling temperature.
    shared_generation_kwargs = {
        "max_new_tokens": config.max_response_length,
        "top_p": config.top_p,
        "do_sample": True,
        "pad_token_id": self.tokenizer.pad_token_id,
    }
    self.student_gen_config = GenerationConfig(
        temperature=config.student_temperature,
        **shared_generation_kwargs,
    )
    # The teacher samples at the KD temperature.
    self.teacher_gen_config = GenerationConfig(
        temperature=config.temperature,
        **shared_generation_kwargs,
    )

    optimizer = AdamW(
        self.student_model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    # Only the student and its optimizer go through Accelerate; the frozen
    # teacher is moved to the same device directly.
    self.student_model, self.optimizer = self.accelerator.prepare(
        self.student_model, optimizer
    )
    self.teacher_model.to(self.accelerator.device)

    self.global_step = 0

def generate_student_samples(
self,
prompts: List[str],
) -> Tuple[List[str], List[str], torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate samples using student model.

These are "on-policy" samples - from the distribution student will encounter.
"""
prompts_expanded = []
responses = []
all_full_texts = []
all_prompt_lengths = []

for prompt in prompts:
prompt_tokens = self.tokenizer(
prompt, return_tensors="pt", truncation=True,
max_length=self.config.max_prompt_length,
add_special_tokens=False
)
prompt_length = prompt_tokens["input_ids"].shape[1]

# Generate one sample per prompt (on-policy)
with torch.no_grad():
gen_inputs = {k: v.to(self.accelerator.device) for k, v in prompt_tokens.items()}
gen_outputs = self.student_model.generate(
**gen_inputs,
generation_config=self.student_gen_config,
)

response_ids = gen_outputs[0][prompt_length:]
response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
response_length = response_ids.shape[0]

all_full_texts.append(prompt + response)
all_prompt_lengths.append(prompt_length)
prompts_expanded.append(prompt)
responses.append(response)

# Batch tokenize with padding
full_tokens = self.tokenizer(
all_full_texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=self.config.max_prompt_length + self.config.max_response_length,
)

input_ids = full_tokens["input_ids"].to(self.accelerator.device)
attention_mask = full_tokens["attention_mask"].to(self.accelerator.device)

# Create response masks
response_mask = torch.zeros_like(input_ids)
for i, prompt_len in enumerate(all_prompt_lengths):
actual_seq_len = attention_mask[i].sum().item()
response_end = min(input_ids.shape[1], actual_seq_len)
if response_end > prompt_len:
response_mask[i, prompt_len:response_end] = 1

return prompts_expanded, responses, input_ids, attention_mask, response_mask

def generate_teacher_samples(
self,
prompts: List[str],
) -> List[str]:
"""
Generate reference samples using teacher model.

Teacher provides guidance for student's generated distribution.
"""
teacher_responses = []

for prompt in prompts:
prompt_tokens = self.tokenizer(
prompt, return_tensors="pt", truncation=True,
max_length=self.config.max_prompt_length,
)
prompt_length = prompt_tokens["input_ids"].shape[1]

with torch.no_grad():
gen_inputs = {k: v.to(self.accelerator.device) for k, v in prompt_tokens.items()}
gen_outputs = self.teacher_model.generate(
**gen_inputs,
generation_config=self.teacher_gen_config,
)

response = self.tokenizer.decode(
gen_outputs[0][prompt_length:],
skip_special_tokens=True,
)
teacher_responses.append(response)

return teacher_responses

def train_step(self, prompts: List[str]) -> Dict[str, float]:
"""
One On-Policy Distillation training step.

Steps:
1. Student generates responses (on-policy samples)
2. Teacher generates reference responses
3. Compute student log probs
4. Compute teacher soft log probs (with temperature)
5. Compute distillation loss
6. Update student
"""
# Step 1: Generate student samples
prompts_expanded, student_responses, input_ids, attention_mask, response_mask = \
self.generate_student_samples(prompts)

# Step 2: Generate teacher responses for comparison
teacher_responses = self.generate_teacher_samples(prompts_expanded)

# Create teacher sequences
teacher_full_texts = [p + r for p, r in zip(prompts_expanded, teacher_responses)]
teacher_tokens = self.tokenizer(
teacher_full_texts,
return_tensors="pt",
truncation=True,
padding=True,
max_length=self.config.max_prompt_length + self.config.max_response_length,
)
teacher_input_ids = teacher_tokens["input_ids"].to(self.accelerator.device)
teacher_attention_mask = teacher_tokens["attention_mask"].to(self.accelerator.device)

# Create teacher response mask (same positions as student's prompts)
teacher_response_mask = torch.zeros_like(teacher_input_ids)
prompt_tokens = self.tokenizer(prompts_expanded, add_special_tokens=False)
for i, prompt_ids in enumerate(prompt_tokens["input_ids"]):
prompt_len = len(prompt_ids)
actual_seq_len = teacher_attention_mask[i].sum().item()
response_end = min(teacher_input_ids.shape[1], actual_seq_len)
if response_end > prompt_len:
teacher_response_mask[i, prompt_len:response_end] = 1

# Step 3: Compute student log probs (with grad)
student_log_probs = compute_per_token_log_probs(
self.student_model,
input_ids,
attention_mask,
response_mask,
)

# Step 4: Compute teacher soft log probs (with temperature, no grad)
with torch.no_grad():
teacher_soft_log_probs = compute_soft_log_probs(
self.teacher_model,
teacher_input_ids,
teacher_attention_mask,
teacher_response_mask,
temperature=self.config.temperature,
)

# Step 5: Compute distillation loss
loss, metrics = compute_distillation_loss(
student_log_probs,
teacher_soft_log_probs,
hard_label_log_probs=None, # No ground truth in this setup
alpha=1.0, # Pure distillation
temperature=self.config.temperature,
)

# Check for NaN/Inf
if torch.isnan(loss) or torch.isinf(loss):
print("Warning: NaN/Inf loss, skipping update")
self.optimizer.zero_grad()
return {"total_loss": 0.0, "error": "NaN/Inf"}

# Step 6: Update student with gradient clipping
self.accelerator.backward(loss)
self.accelerator.clip_grad_norm_(
self.student_model.parameters(),
max_norm=self.config.max_grad_norm
)
self.optimizer.step()
self.optimizer.zero_grad()

self.global_step += 1

return metrics

def train(self, dataset: Dataset):
"""Main training loop."""
print(f"Starting On-Policy Distillation on {len(dataset)} samples...")
print(f"Temperature: {self.config.temperature}")
print(f"Alpha (soft loss weight): {self.config.alpha}")

dataloader = DataLoader(
dataset, batch_size=self.config.batch_size, shuffle=True
)

num_batches = len(dataloader) * self.config.num_epochs

for epoch in range(self.config.num_epochs):
print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

for batch in tqdm(dataloader, desc="Training"):
# Extract prompts
if "instruction" in batch:
prompts = batch["instruction"]
elif "prompt" in batch:
prompts = batch["prompt"]
else:
prompts = batch["text"]

if isinstance(prompts, torch.Tensor):
prompts = prompts.tolist()

metrics = self.train_step(prompts)

if self.global_step % self.config.logging_steps == 0:
print(
f"Step {self.global_step}/{num_batches}: "
f"loss={metrics['total_loss']:.4f}, "
f"soft_loss={metrics['soft_loss']:.4f}, "
f"kl={metrics['kl_divergence']:.4f}"
)

if self.global_step % self.config.save_steps == 0:
self.save_checkpoint(self.global_step)

self.save_checkpoint(self.global_step, final=True)
print(f"Training complete! Model saved to: {self.config.output_dir}")

def save_checkpoint(self, step: int, final: bool = False):
"""Save model checkpoint."""
save_dir = (
self.config.output_dir
if final
else os.path.join(self.config.output_dir, f"checkpoint-{step}")
)

unwrapped_model = self.accelerator.unwrap_model(self.student_model)
unwrapped_model.save_pretrained(save_dir)
self.tokenizer.save_pretrained(save_dir)
print(f"Checkpoint saved to: {save_dir}")


# ============================================================================
# Dataset Preparation
# ============================================================================

def prepare_dataset(config: OnPolicyDistillConfig) -> Dataset:
    """
    Load and prepare dataset.

    The dataset named by ``config.dataset_name`` is loaded with
    ``load_dataset`` (``train`` split). If that fails, a JSON file
    ``<DATA_DIR>/<dataset_name>.json`` is used instead. The result is capped
    at ``config.max_samples`` rows (when positive) and given a ``"prompt"``
    column.

    Args:
        config: Run configuration; ``dataset_name`` and ``max_samples`` are read.

    Returns:
        The dataset with an added ``"prompt"`` column; existing columns are kept.

    Raises:
        ValueError: If loading fails and no local JSON fallback exists. The
            original loading error is attached as the cause.
    """
    print(f"Loading dataset: {config.dataset_name}")

    try:
        dataset = load_dataset(config.dataset_name, split="train")
    except Exception as e:  # any loading failure falls back to the local copy
        print(f"Failed to load dataset: {e}")
        local_path = Path(DATA_DIR) / f"{config.dataset_name}.json"
        if not local_path.exists():
            # Chain the cause so the real loading error is not lost.
            raise ValueError(f"Dataset not found: {config.dataset_name}") from e
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        dataset = Dataset.from_list(data)

    if config.max_samples > 0 and len(dataset) > config.max_samples:
        dataset = dataset.select(range(config.max_samples))

    def format_prompt(example):
        """Build the ``prompt`` field for one example."""
        if "instruction" in example:
            instruction = example["instruction"]
            input_text = example.get("input", "")
            if input_text:
                return {"prompt": f"Instruction: {instruction}\nInput: {input_text}\nResponse:"}
            return {"prompt": f"Instruction: {instruction}\nResponse:"}
        if "prompt" in example:
            return {"prompt": example["prompt"]}
        if "text" in example:
            # Plain-text datasets: use the text itself rather than the repr
            # of the whole example dict.
            return {"prompt": example["text"]}
        return {"prompt": str(example)}

    dataset = dataset.map(format_prompt)
    print(f"Dataset prepared: {len(dataset)} samples")
    return dataset


# ============================================================================
# Main
# ============================================================================

def main():
    """Parse command-line options, build the config and run distillation."""
    # (flag, value type, default) for every supported option.
    cli_options = (
        ("--student_model_path", str, STUDENT_MODEL_PATH),
        ("--teacher_model_path", str, TEACHER_MODEL_PATH),
        ("--output_dir", str, OUTPUT_DIR),
        ("--dataset_name", str, DEFAULT_DATASET),
        ("--max_samples", int, 1000),
        ("--temperature", float, 2.0),
        ("--alpha", float, 0.5),
        ("--learning_rate", float, 1e-5),
        ("--num_epochs", int, 1),
        ("--batch_size", int, 4),
        ("--deepspeed", str, None),
        ("--seed", int, 42),
    )

    parser = argparse.ArgumentParser(description="On-Policy Distillation Training Script")
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Each parsed option is passed to the config under its own name.
    config = OnPolicyDistillConfig(**vars(args))

    rule = "=" * 60
    summary_lines = (
        rule,
        "On-Policy Distillation Configuration",
        rule,
        f"Student: {config.student_model_path}",
        f"Teacher: {config.teacher_model_path}",
        f"Dataset: {config.dataset_name}",
        f"Temperature: {config.temperature}",
        f"Alpha: {config.alpha}",
        rule,
    )
    for line in summary_lines:
        print(line)

    # The dataset is prepared before the trainer is constructed.
    train_data = prepare_dataset(config)
    OnPolicyDistillTrainer(config).train(train_data)


if __name__ == "__main__":
    main()