""" DAPO (Dynamic Actor Policy Optimization) Training Script for Qwen3.5-0.8B Model
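DAPO is an RLHF variant that combines:
- A dynamic sampling strategy (adjusting the sampling temperature based on the mean of recent rewards)
- Prompt augmentation for improved robustness
- An optional dual-critic hook for better value estimation (compute_dapo_advantage accepts value estimates, but the trainer currently passes none and uses a mean-reward baseline)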
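Key Features:
1. Dynamic Temperature: adjusts the sampling temperature from the recent mean reward (lower when rewards are high, higher when they are low)
2. Augmented Data: uses prompt augmentation to improve training diversity
3. Progressive Learning: intended to gradually increase task difficulty (configured via `use_curriculum` / `curriculum_start_ratio`, but not yet applied by the training loop)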
Usage:
    Single GPU:
        python scripts/run_dapo.py
    Multi GPU with DeepSpeed:
        deepspeed --num_gpus=2 scripts/run_dapo.py --deepspeed configs/ds_config_zero3.json
    Note: the --deepspeed path is stored in DAPOConfig.deepspeed but is not passed to the Accelerator.
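Reference:
- Related to dynamic sampling in RLHF
- Inspired by curriculum learning and data augmentation
- Despite the shared acronym, this is not an implementation of the published DAPO algorithm (Decoupled Clip and Dynamic sAmpling Policy Optimization, Yu et al., 2025)
"""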
import os import argparse import json import re from pathlib import Path from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Tuple import numpy as np
import torch import torch.nn as nn import torch.nn.functional as F from torch.optim import AdamW from torch.utils.data import DataLoader from transformers import ( AutoModelForCausalLM, AutoTokenizer, GenerationConfig, ) from datasets import load_dataset, Dataset from accelerate import Accelerator from tqdm import tqdm
# Default filesystem locations. DAPOConfig and the CLI in main() use
# MODEL_PATH / OUTPUT_DIR as defaults; prepare_dataset() falls back to
# JSON files under DATA_DIR when the hub download fails.
MODEL_PATH: str = "/workspace/models/Qwen3.5-0.8B"
OUTPUT_DIR: str = "/workspace/train_codes/output/dapo"
DATA_DIR: str = "/workspace/train_codes/data"

# Dataset id passed to datasets.load_dataset() unless overridden on the CLI.
DEFAULT_DATASET: str = "tatsu-lab/alpaca"
@dataclass
class DAPOConfig:
    """
    Configuration for DAPO training.

    Some fields are accepted but not read anywhere by the training code in
    this script: ``critic_learning_rate``, ``gradient_accumulation_steps``,
    ``warmup_ratio``, ``deepspeed``, ``use_curriculum`` and
    ``curriculum_start_ratio``.

    Raises:
        ValueError: from ``__post_init__`` when a value would make training
            crash later (e.g. ``temperature_decay == 0`` divides by zero in
            the temperature controller, ``logging_steps == 0`` is a modulo
            by zero in the training loop).
    """
    # --- Paths and data ---
    model_path: str = MODEL_PATH
    output_dir: str = OUTPUT_DIR
    dataset_name: str = DEFAULT_DATASET
    max_samples: int = 1000  # <= 0 keeps the whole dataset
    max_prompt_length: int = 512  # prompt truncation length in tokens
    max_response_length: int = 512  # max_new_tokens for generation

    # --- Sampling / dynamic temperature ---
    num_samples_per_prompt: int = 4
    initial_temperature: float = 1.0
    min_temperature: float = 0.5
    max_temperature: float = 1.5
    temperature_decay: float = 0.95  # multiplicative step, must be in (0, 1]
    reward_threshold_high: float = 0.8  # mean reward above this lowers temperature
    reward_threshold_low: float = 0.3  # mean reward below this raises temperature

    # --- Policy optimisation ---
    beta: float = 0.1  # KL penalty coefficient
    clip_coef: float = 0.2  # ratio clipping range
    gamma: float = 0.99  # discount factor; only used on the GAE path

    # --- Prompt augmentation ---
    use_augmentation: bool = True
    augmentation_ratio: float = 0.3  # probability that a prompt is augmented

    # --- Optimiser / schedule ---
    learning_rate: float = 1e-6
    critic_learning_rate: float = 5e-6
    num_epochs: int = 1
    batch_size: int = 4
    gradient_accumulation_steps: int = 8
    warmup_ratio: float = 0.03
    weight_decay: float = 0.01

    top_p: float = 0.95

    deepspeed: Optional[str] = None

    seed: int = 42
    logging_steps: int = 10
    save_steps: int = 100

    use_curriculum: bool = True
    curriculum_start_ratio: float = 0.3

    def __post_init__(self) -> None:
        """Reject values that would otherwise fail deep inside training."""
        if not 0.0 < self.temperature_decay <= 1.0:
            raise ValueError(
                f"temperature_decay must be in (0, 1], got {self.temperature_decay}"
            )
        if not 0.0 < self.min_temperature <= self.max_temperature:
            raise ValueError(
                "expected 0 < min_temperature <= max_temperature, got "
                f"{self.min_temperature} and {self.max_temperature}"
            )
        if self.initial_temperature <= 0.0:
            raise ValueError(
                f"initial_temperature must be positive, got {self.initial_temperature}"
            )
        if self.reward_threshold_low > self.reward_threshold_high:
            raise ValueError(
                "reward_threshold_low must not exceed reward_threshold_high, got "
                f"{self.reward_threshold_low} > {self.reward_threshold_high}"
            )
        if not 0.0 <= self.augmentation_ratio <= 1.0:
            raise ValueError(
                f"augmentation_ratio must be in [0, 1], got {self.augmentation_ratio}"
            )
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError(f"top_p must be in (0, 1], got {self.top_p}")
        # These are used as counts or as modulo divisors by the trainer.
        for name in ("num_samples_per_prompt", "batch_size", "logging_steps", "save_steps"):
            if getattr(self, name) < 1:
                raise ValueError(f"{name} must be >= 1, got {getattr(self, name)}")
class PromptAugmenter:
    """
    Augment prompts to improve training diversity.

    Strategies applied by ``augment``:
      1. Context addition: prompts shorter than 100 characters also appear
         with a fixed "Context: ..." line prepended.
      2. Question reformulation: prompts containing "?" are rewritten with
         the first ``max_paraphrases`` paraphrase templates, applied to the
         prompt with every "?" removed.
    """

    def __init__(self):
        # Each template wraps the question's core text.
        self.paraphrase_templates = [
            lambda x: f"Can you explain {x}?",
            lambda x: f"Please describe {x} in detail.",
            lambda x: f"What do you know about {x}?",
            lambda x: f"I'd like to understand {x}.",
        ]

    def augment(self, prompt: str, max_paraphrases: int = 2) -> List[str]:
        """
        Generate augmented versions of the prompt.

        Args:
            prompt: Original prompt.
            max_paraphrases: How many paraphrase templates (taken in order)
                to apply to question prompts. Defaults to 2.

        Returns:
            augmented_prompts: The original prompt first, followed by any
            augmented variants.
        """
        augmented = [prompt]

        if len(prompt) < 100:
            augmented.append(f"Context: This is a learning exercise.\n{prompt}")

        if "?" in prompt:
            core = prompt.replace("?", "").strip()
            # The templates are plain f-strings and cannot raise, so no
            # exception handling is needed (previously a bare except hid
            # any error, including KeyboardInterrupt).
            augmented.extend(
                template(core)
                for template in self.paraphrase_templates[:max_paraphrases]
            )

        return augmented
class DynamicTemperatureController:
    """
    Dynamically adjust sampling temperature based on reward feedback.

    The controller keeps a sliding window of the last ``history_size``
    rewards. Once more than ``min_history`` rewards have been seen, each
    ``update`` compares the window's MEAN reward with the configured
    thresholds:
      - mean > ``reward_threshold_high``: temperature *= ``temperature_decay``
        (more exploitation), floored at ``min_temperature``.
      - mean < ``reward_threshold_low``: temperature /= ``temperature_decay``
        (more exploration), capped at ``max_temperature``.
      - otherwise the temperature is unchanged.
    """

    history_size = 100  # number of most recent rewards kept
    min_history = 10  # adjust only once the window holds more than this

    def __init__(self, config: "DAPOConfig"):
        self.config = config
        self.current_temperature = config.initial_temperature
        self.reward_history: List[float] = []

    def update(self, rewards: List[float]) -> float:
        """
        Update temperature based on recent rewards.

        Args:
            rewards: Newly observed rewards, appended to the window.

        Returns:
            temperature: The (possibly adjusted) sampling temperature.
        """
        self.reward_history.extend(rewards)
        if len(self.reward_history) > self.history_size:
            self.reward_history = self.reward_history[-self.history_size:]

        if len(self.reward_history) <= self.min_history:
            return self.current_temperature

        mean_reward = float(np.mean(self.reward_history))
        cfg = self.config
        if mean_reward > cfg.reward_threshold_high:
            self.current_temperature = max(
                cfg.min_temperature,
                self.current_temperature * cfg.temperature_decay,
            )
        elif mean_reward < cfg.reward_threshold_low:
            self.current_temperature = min(
                cfg.max_temperature,
                self.current_temperature / cfg.temperature_decay,
            )

        return self.current_temperature

    def get_temperature(self) -> float:
        """Get current temperature."""
        return self.current_temperature
def compute_dapo_advantage(
    rewards: torch.Tensor,
    values: Optional[torch.Tensor] = None,
    gamma: float = 0.99,
    use_gae: bool = True,
    gae_lambda: float = 0.95,
    group_size: Optional[int] = None,
) -> torch.Tensor:
    """
    Compute DAPO advantages.

    Modes:
      * ``values`` given and ``use_gae`` true: Generalized Advantage
        Estimation over ``rewards``, treating consecutive entries as
        consecutive time steps, followed by global normalisation.
      * otherwise, with ``group_size``: rewards are split into consecutive
        groups of ``group_size`` (e.g. the samples drawn for one prompt),
        and each group is centred and scaled by its own standard deviation.
      * otherwise: subtract the batch mean reward, then normalise globally.
    Normalisation is skipped when there is one element or zero spread.

    Args:
        rewards: Tensor of rewards (assumed 1-D).
        values: Optional tensor of value estimates aligned with ``rewards``.
        gamma: Discount factor (GAE path only).
        use_gae: Whether to use Generalized Advantage Estimation.
        gae_lambda: GAE lambda parameter.
        group_size: Optional group size for per-group normalisation.

    Returns:
        advantages: Tensor of advantages on ``rewards.device``. It is
        float32 on the GAE path and otherwise has the dtype of ``rewards``.

    Raises:
        ValueError: if ``group_size`` is < 1 or does not divide the number
            of rewards.
    """
    if values is not None and use_gae:
        # Value estimates and rewards only supply targets; never
        # backpropagate through them. Keep the result on the reward device
        # so it can be combined with on-device log-probs.
        rewards_f = rewards.detach().to(dtype=torch.float32)
        values_f = values.detach().to(device=rewards.device, dtype=torch.float32)
        num_steps = rewards_f.shape[0]
        advantages = torch.zeros(num_steps, dtype=torch.float32, device=rewards.device)
        gae = 0.0
        for i in reversed(range(num_steps)):
            next_value = values_f[i + 1] if i + 1 < num_steps else 0.0
            delta = rewards_f[i] + gamma * next_value - values_f[i]
            gae = delta + gamma * gae_lambda * gae
            advantages[i] = gae
    elif group_size is not None:
        if group_size < 1 or rewards.numel() % group_size != 0:
            raise ValueError(
                f"group_size={group_size} must be >= 1 and divide "
                f"the number of rewards ({rewards.numel()})"
            )
        grouped = rewards.reshape(-1, group_size)
        centered = grouped - grouped.mean(dim=1, keepdim=True)
        if group_size > 1:
            # Groups with identical rewards are already all-zero after centring.
            centered = centered / (grouped.std(dim=1, keepdim=True) + 1e-8)
        return centered.reshape(-1)
    else:
        advantages = rewards - rewards.mean()

    # std() of a single element is nan; guard explicitly instead of relying
    # on nan comparing False.
    if advantages.numel() > 1:
        std = advantages.std()
        if std > 0:
            advantages = (advantages - advantages.mean()) / (std + 1e-8)

    return advantages
def compute_dapo_loss(
    policy_log_probs: torch.Tensor,
    old_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    beta: float = 0.1,
    clip_coef: float = 0.2,
    clip_coef_high: Optional[float] = None,
    entropy_coef: float = 0.01,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the clipped policy loss with a KL penalty and entropy bonus.

        ratio = exp(log pi - log pi_old)
        Loss  = -mean(min(ratio * A, clip(ratio, 1 - eps_low, 1 + eps_high) * A))
                + beta * mean(log pi - log pi_ref)
                - entropy_coef * H

    The KL term uses the k1 estimator on the sampled sequences, so it grows
    as the policy moves away from the reference. Here H = -mean(log pi) is a
    Monte-Carlo estimate of sequence entropy from the samples.

    Args:
        policy_log_probs: Current policy log probs (carries the graph).
        old_log_probs: Old policy log probs (from sampling).
        ref_log_probs: Reference model log probs.
        advantages: Advantage estimates, one per sequence.
        beta: KL coefficient.
        clip_coef: Lower clipping range eps_low (and eps_high by default).
        clip_coef_high: Optional separate upper clipping range eps_high.
            ``None`` gives symmetric clipping.
        entropy_coef: Weight of the entropy bonus.

    Returns:
        loss: Total loss (scalar, float32).
        metrics: Dictionary of loss components as Python floats.
    """
    # Sequence log-probs are sums over many tokens. In bf16 their rounding
    # error makes exp(new - old) meaningless, so work in float32. The
    # non-policy inputs are constants for the gradient.
    policy_lp = policy_log_probs.float()
    device = policy_lp.device
    old_lp = old_log_probs.detach().to(device=device, dtype=torch.float32)
    ref_lp = ref_log_probs.detach().to(device=device, dtype=torch.float32)
    adv = advantages.detach().to(device=device, dtype=torch.float32)

    eps_high = clip_coef if clip_coef_high is None else clip_coef_high

    ratio = torch.exp(policy_lp - old_lp)
    clipped_ratio = torch.clamp(ratio, 1.0 - clip_coef, 1.0 + eps_high)
    policy_loss = -torch.min(ratio * adv, clipped_ratio * adv).mean()

    # The sign matters: penalise log pi exceeding log pi_ref. The previous
    # (ref - policy) form rewarded divergence from the reference.
    kl_loss = beta * (policy_lp - ref_lp).mean()

    entropy = -policy_lp.mean()

    total_loss = policy_loss + kl_loss - entropy_coef * entropy

    metrics = {
        "total_loss": total_loss.item(),
        "policy_loss": policy_loss.item(),
        "kl_loss": kl_loss.item(),
        "entropy": entropy.item(),
        "ratio_mean": ratio.mean().item(),
        "ratio_std": ratio.std().item() if ratio.numel() > 1 else 0.0,
    }

    return total_loss, metrics
class DAPORewardFunction:
    """
    Heuristic reward for DAPO training.

    A response's score is the sum of these terms, added in this order:
      * +0.2 if the stripped response ends with '.', '?' or '!'
      * length in characters: +0.2 for 20..300, +0.1 for 10..19, -0.1 above
        300, nothing below 10
      * word diversity (distinct / total whitespace-split words): +0.2 if
        above 0.7, -0.2 if below 0.3
      * 0.2 x the tracker's progress reward, when a tracker was supplied
    """

    _TERMINAL_PUNCTUATION = ('.', '?', '!')

    def __init__(self, progress_tracker=None):
        self.progress_tracker = progress_tracker

    def compute(self, prompt: str, response: str, iteration: int = 0) -> float:
        """
        Score one prompt/response pair.

        Args:
            prompt: Input prompt (used only for the progress lookup).
            response: Generated response being scored.
            iteration: Training iteration, forwarded to the progress tracker.

        Returns:
            reward: Scalar reward value.
        """
        score = 0.0

        if response.strip().endswith(self._TERMINAL_PUNCTUATION):
            score += 0.2

        n_chars = len(response)
        if n_chars > 300:
            score -= 0.1
        elif n_chars >= 20:
            score += 0.2
        elif n_chars >= 10:
            score += 0.1

        tokens = response.split()
        if tokens:
            distinct_fraction = len(set(tokens)) / len(tokens)
            if distinct_fraction > 0.7:
                score += 0.2
            elif distinct_fraction < 0.3:
                score -= 0.2

        tracker = self.progress_tracker
        if tracker is not None:
            score += tracker.get_progress_reward(prompt, iteration) * 0.2

        return score
class ProgressTracker:
    """
    Keep every reward observed for each prompt and expose the most recent
    improvement as a non-negative progress bonus.
    """

    def __init__(self):
        self.prompt_rewards = {}  # prompt text -> rewards in arrival order
        self.iteration = 0  # total number of update() calls

    def update(self, prompt: str, reward: float):
        """Append ``reward`` to the history of ``prompt`` and count the call."""
        self.prompt_rewards.setdefault(prompt, []).append(reward)
        self.iteration += 1

    def get_progress_reward(self, prompt: str, iteration: int) -> float:
        """
        Return the latest reward change for ``prompt``, clipped below at 0.

        Args:
            prompt: Input prompt whose history is inspected.
            iteration: Accepted for interface compatibility; not used.

        Returns:
            progress_reward: 0.0 with fewer than two recorded rewards,
            otherwise ``max(0, latest - previous)``.
        """
        history = self.prompt_rewards.get(prompt, ())
        if len(history) < 2:
            return 0.0
        previous, latest = history[-2], history[-1]
        return max(0, latest - previous)
class DAPOTrainer:
    """
    DAPO Trainer implementing Dynamic Actor Policy Optimization.

    Each ``train_step`` works as follows:
      1. Sample ``num_samples_per_prompt`` responses per prompt at the
         controller's current temperature.
      2. Score the responses with the heuristic reward.
      3. Turn the rewards into advantages.
      4. Take one clipped policy-gradient step with a KL penalty against a
         frozen copy of the initial model.
    """

    def __init__(self, config: DAPOConfig):
        self.config = config

        torch.manual_seed(config.seed)
        np.random.seed(config.seed)

        self.augmenter = PromptAugmenter() if config.use_augmentation else None
        self.temperature_controller = DynamicTemperatureController(config)
        self.progress_tracker = ProgressTracker()
        self.reward_function = DAPORewardFunction(self.progress_tracker)

        print("Loading policy model...")
        self.policy_model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )

        print("Loading reference model...")
        self.ref_model = AutoModelForCausalLM.from_pretrained(
            config.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_path,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.accelerator = Accelerator()

        self.optimizer = AdamW(
            self.policy_model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        self.policy_model, self.optimizer = self.accelerator.prepare(
            self.policy_model, self.optimizer
        )
        self.ref_model.to(self.accelerator.device)

        self.global_step = 0
        self.best_reward = -float('inf')

    # ------------------------------------------------------------------
    # Scoring helpers
    # ------------------------------------------------------------------
    def _encode_pair(self, prompt: str, response: str) -> Tuple[torch.Tensor, int]:
        """Return (1, L) token ids for prompt+response and the prompt token count.

        Prompt and response are tokenised separately and concatenated. This
        way the prompt/response boundary is exact, and the old, policy and
        reference log-probs are all computed on the same token sequence.
        The prompt is truncated exactly as in generation.
        """
        prompt_ids = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.config.max_prompt_length,
        )["input_ids"]
        response_ids = self.tokenizer(
            response,
            add_special_tokens=False,
            truncation=True,
            max_length=self.config.max_response_length,
        )["input_ids"]
        input_ids = torch.tensor(
            [list(prompt_ids) + list(response_ids)],
            dtype=torch.long,
            device=self.accelerator.device,
        )
        return input_ids, len(prompt_ids)

    def _sequence_log_prob(self, model: nn.Module, prompt: str, response: str) -> torch.Tensor:
        """Sum of log-probabilities of the response tokens given the prompt.

        Gradient tracking is controlled by the caller's grad mode.
        """
        input_ids, prompt_length = self._encode_pair(prompt, response)
        # Without any prompt token there is no logit predicting the first
        # response token, so scoring starts one position later.
        start = max(prompt_length, 1)

        logits = model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
        ).logits

        # logits[:, t] predicts token t + 1, hence the one-position shift.
        # Upcast to float32: bf16 log-probs summed over hundreds of tokens
        # are too coarse for the exp(new - old) ratio.
        response_logits = logits[0, start - 1:-1, :].float()
        response_tokens = input_ids[0, start:]
        token_logits = response_logits.gather(-1, response_tokens.unsqueeze(-1)).squeeze(-1)
        token_log_probs = token_logits - torch.logsumexp(response_logits, dim=-1)
        return token_log_probs.sum()

    # ------------------------------------------------------------------
    # Rollout / scoring
    # ------------------------------------------------------------------
    def generate_samples(
        self,
        prompts: List[str],
        temperature: float,
    ) -> Tuple[List[str], List[str], torch.Tensor]:
        """
        Generate samples with dynamic temperature.

        Args:
            prompts: List of input prompts.
            temperature: Sampling temperature.

        Returns:
            prompts_expanded: Prompt actually used for each sample. This may
                be an augmented variant, repeated once per sample.
            responses: Generated responses (special tokens stripped).
            old_log_probs: 1-D tensor of response log-probabilities under the
                current policy at generation time.
        """
        prompts_expanded = []
        responses = []
        old_log_probs = []

        generation_config = GenerationConfig(
            max_new_tokens=self.config.max_response_length,
            temperature=temperature,
            top_p=self.config.top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            num_return_sequences=self.config.num_samples_per_prompt,
        )

        # Distributed wrappers (e.g. DDP) do not expose ``generate``, so use
        # the underlying model. It is also used for the no-grad scoring pass.
        base_model = self.accelerator.unwrap_model(self.policy_model)

        for prompt in prompts:
            if self.augmenter is not None and np.random.random() < self.config.augmentation_ratio:
                prompt = str(np.random.choice(self.augmenter.augment(prompt)))

            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=self.config.max_prompt_length,
            )
            inputs = {k: v.to(self.accelerator.device) for k, v in inputs.items()}
            prompt_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                # All samples for this prompt are drawn in one call.
                outputs = base_model.generate(
                    **inputs,
                    generation_config=generation_config,
                )

                for sequence in outputs:
                    response = self.tokenizer.decode(
                        sequence[prompt_length:],
                        skip_special_tokens=True,
                    )
                    prompts_expanded.append(prompt)
                    responses.append(response)
                    old_log_probs.append(
                        self._sequence_log_prob(base_model, prompt, response)
                    )

        if not old_log_probs:
            return prompts_expanded, responses, torch.empty(0, device=self.accelerator.device)
        return prompts_expanded, responses, torch.stack(old_log_probs)

    def compute_log_probs(
        self,
        model: nn.Module,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """
        Compute log probabilities for prompt-response pairs.

        Gradients are tracked for every model except the frozen reference
        model.

        Returns:
            1-D float32 tensor with one summed log-probability per pair.
        """
        # NOTE(review): under multi-process DDP, several forward passes feed a
        # single backward here; this may need DDP static_graph or per-sample
        # backward. Confirm before multi-GPU runs.
        track_grad = model is not self.ref_model
        with torch.set_grad_enabled(track_grad):
            log_probs = [
                self._sequence_log_prob(model, prompt, response)
                for prompt, response in zip(prompts, responses)
            ]
        return torch.stack(log_probs)

    def compute_rewards(
        self,
        prompts: List[str],
        responses: List[str],
    ) -> torch.Tensor:
        """
        Compute rewards for all samples.

        Each reward is also recorded in the progress tracker. Later samples of
        the same prompt in this batch therefore see it as history.
        """
        rewards = []
        for prompt, response in zip(prompts, responses):
            reward = self.reward_function.compute(prompt, response, self.global_step)
            rewards.append(reward)
            self.progress_tracker.update(prompt, reward)

        return torch.tensor(rewards, dtype=torch.float32, device=self.accelerator.device)

    # ------------------------------------------------------------------
    # Optimisation
    # ------------------------------------------------------------------
    def train_step(self, prompts: List[str]) -> Dict[str, float]:
        """
        Perform one DAPO training step.

        Args:
            prompts: List of input prompts.

        Returns:
            metrics: Loss components plus ``mean_reward``, ``std_reward`` and
            ``temperature``. The temperature is the value after the
            controller update, i.e. the one used for the next step.
        """
        temperature = self.temperature_controller.get_temperature()

        all_prompts, all_responses, old_log_probs = self.generate_samples(
            prompts, temperature
        )

        rewards = self.compute_rewards(all_prompts, all_responses)

        temperature = self.temperature_controller.update(rewards.tolist())

        advantages = compute_dapo_advantage(
            rewards,
            values=None,
            gamma=self.config.gamma,
        )

        policy_log_probs = self.compute_log_probs(self.policy_model, all_prompts, all_responses)
        ref_log_probs = self.compute_log_probs(self.ref_model, all_prompts, all_responses)

        loss, metrics = compute_dapo_loss(
            policy_log_probs,
            old_log_probs.to(self.accelerator.device),
            ref_log_probs,
            advantages,
            beta=self.config.beta,
            clip_coef=self.config.clip_coef,
        )

        metrics["mean_reward"] = rewards.mean().item()
        metrics["std_reward"] = rewards.std().item() if rewards.numel() > 1 else 0
        metrics["temperature"] = temperature

        self.accelerator.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.best_reward = max(self.best_reward, metrics["mean_reward"])

        self.global_step += 1

        return metrics

    @staticmethod
    def _extract_prompts(batch: Dict[str, Any]) -> List[str]:
        """Pick the prompt column from a collated batch.

        The fully formatted ``prompt`` column (built by prepare_dataset) is
        preferred. It is followed by ``instruction`` and then ``text``.
        """
        for key in ("prompt", "instruction", "text"):
            if key in batch:
                prompts = batch[key]
                break
        else:
            raise KeyError(
                "Batch has none of the columns 'prompt', 'instruction', 'text': "
                f"{list(batch)}"
            )
        if isinstance(prompts, torch.Tensor):
            prompts = prompts.tolist()
        return list(prompts)

    def train(self, dataset: Dataset):
        """
        Main training loop.

        Args:
            dataset: Training dataset.
        """
        self.accelerator.print(f"Starting DAPO training on {len(dataset)} samples...")
        self.accelerator.print(f"Dynamic temperature: {self.config.initial_temperature}")
        self.accelerator.print(f"Data augmentation: {self.config.use_augmentation}")

        dataloader = DataLoader(
            dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
        )
        # Shard batches across processes. Otherwise every rank trains on the
        # same prompts with the same seed, which duplicates work.
        dataloader = self.accelerator.prepare(dataloader)

        num_batches = len(dataloader) * self.config.num_epochs

        for epoch in range(self.config.num_epochs):
            self.accelerator.print(f"\nEpoch {epoch + 1}/{self.config.num_epochs}")

            for batch in tqdm(
                dataloader,
                desc="Training",
                disable=not self.accelerator.is_local_main_process,
            ):
                prompts = self._extract_prompts(batch)

                metrics = self.train_step(prompts)

                if self.global_step % self.config.logging_steps == 0:
                    self.accelerator.print(
                        f"Step {self.global_step}/{num_batches}: "
                        f"loss={metrics['total_loss']:.4f}, "
                        f"reward={metrics['mean_reward']:.4f}±{metrics['std_reward']:.4f}, "
                        f"temp={metrics['temperature']:.3f}, "
                        f"kl={metrics['kl_loss']:.4f}"
                    )

                if self.global_step % self.config.save_steps == 0:
                    self.save_checkpoint(self.global_step)

        self.save_checkpoint(self.global_step, final=True)
        self.accelerator.print(f"Training complete! Model saved to: {self.config.output_dir}")
        self.accelerator.print(f"Best reward achieved: {self.best_reward:.4f}")

    def save_checkpoint(self, step: int, final: bool = False):
        """
        Save model checkpoint.

        The final save goes to ``output_dir``; intermediate saves go to
        ``output_dir/checkpoint-<step>``. Every process must call this (it
        contains a barrier), but only the main process writes files.
        """
        save_dir = (
            self.config.output_dir
            if final
            else os.path.join(self.config.output_dir, f"checkpoint-{step}")
        )

        self.accelerator.wait_for_everyone()
        # Gathered on all ranks so sharded setups can assemble full weights.
        state_dict = self.accelerator.get_state_dict(self.policy_model)
        if not self.accelerator.is_main_process:
            return

        os.makedirs(save_dir, exist_ok=True)

        unwrapped_model = self.accelerator.unwrap_model(self.policy_model)
        unwrapped_model.save_pretrained(
            save_dir,
            state_dict=state_dict,
            save_function=self.accelerator.save,
        )
        self.tokenizer.save_pretrained(save_dir)

        state = {
            "global_step": self.global_step,
            "best_reward": self.best_reward,
            "temperature": self.temperature_controller.get_temperature(),
        }
        with open(os.path.join(save_dir, "training_state.json"), "w", encoding="utf-8") as f:
            json.dump(state, f)

        print(f"Checkpoint saved to: {save_dir}")
def prepare_dataset(config: DAPOConfig) -> Dataset:
    """
    Load and prepare dataset for DAPO training.

    The ``train`` split is loaded with ``load_dataset(config.dataset_name)``.
    If that fails, the loader falls back to ``DATA_DIR/<dataset_name>.json``
    (a JSON list of records). At most ``max_samples`` examples are kept
    (``<= 0`` keeps all). Every example then gets a ``prompt`` column:
      * records with ``instruction`` (and optional ``input``) are formatted
        as "Instruction: ...\\n[Input: ...\\n]Response:";
      * otherwise an existing ``prompt`` or ``text`` field is used as-is;
      * as a last resort the whole record is stringified.

    Raises:
        ValueError: if neither the hub dataset nor the local file exists.
    """
    print(f"Loading dataset: {config.dataset_name}")

    try:
        dataset = load_dataset(config.dataset_name, split="train")
    except Exception as e:  # hub, network or format errors: try a local copy
        print(f"Failed to load dataset: {e}")
        local_path = Path(DATA_DIR) / f"{config.dataset_name}.json"
        if not local_path.exists():
            raise ValueError(f"Dataset not found: {config.dataset_name}") from e
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        dataset = Dataset.from_list(data)

    if config.max_samples > 0 and len(dataset) > config.max_samples:
        dataset = dataset.select(range(config.max_samples))

    def format_prompt(example):
        if "instruction" in example:
            instruction = example["instruction"]
            input_text = example.get("input", "")
            if input_text:
                return {"prompt": f"Instruction: {instruction}\nInput: {input_text}\nResponse:"}
            return {"prompt": f"Instruction: {instruction}\nResponse:"}
        if "prompt" in example:
            return {"prompt": example["prompt"]}
        if "text" in example:
            # Use raw text rather than the repr of the whole record.
            return {"prompt": example["text"]}
        return {"prompt": str(example)}

    dataset = dataset.map(format_prompt)

    print(f"Dataset prepared: {len(dataset)} samples")
    return dataset
def parse_bool_arg(value: Any) -> bool:
    """Parse a command-line boolean such as true/false, yes/no, on/off or 1/0.

    ``argparse`` with ``type=bool`` returns True for every non-empty string,
    including "False", so it cannot be used for value-taking flags.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    """Parse CLI arguments, build a DAPOConfig, and run DAPO training."""
    parser = argparse.ArgumentParser(description="DAPO Training Script")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--output_dir", type=str, default=OUTPUT_DIR)
    parser.add_argument("--dataset_name", type=str, default=DEFAULT_DATASET)
    parser.add_argument("--max_samples", type=int, default=1000)
    parser.add_argument("--num_samples_per_prompt", type=int, default=4)
    parser.add_argument("--initial_temperature", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    # Accepts "--use_augmentation false" as well as a bare "--use_augmentation".
    parser.add_argument(
        "--use_augmentation",
        type=parse_bool_arg,
        nargs="?",
        const=True,
        default=True,
    )
    # NOTE(review): stored on DAPOConfig but not forwarded to the Accelerator.
    parser.add_argument("--deepspeed", type=str, default=None)
    parser.add_argument("--seed", type=int, default=42)
    # The deepspeed / torch.distributed launchers append --local_rank; accept
    # it so argparse does not reject the documented multi-GPU command.
    parser.add_argument("--local_rank", type=int, default=-1)

    args = parser.parse_args()

    config = DAPOConfig(
        model_path=args.model_path,
        output_dir=args.output_dir,
        dataset_name=args.dataset_name,
        max_samples=args.max_samples,
        num_samples_per_prompt=args.num_samples_per_prompt,
        initial_temperature=args.initial_temperature,
        beta=args.beta,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        use_augmentation=args.use_augmentation,
        deepspeed=args.deepspeed,
        seed=args.seed,
    )

    print("=" * 60)
    print("DAPO Training Configuration")
    print("=" * 60)
    print(f"Model: {config.model_path}")
    print(f"Output: {config.output_dir}")
    print(f"Dataset: {config.dataset_name}")
    print(f"Samples per prompt: {config.num_samples_per_prompt}")
    print(f"Initial temperature: {config.initial_temperature}")
    print(f"Data augmentation: {config.use_augmentation}")
    print(f"Beta (KL coefficient): {config.beta}")
    print("=" * 60)

    dataset = prepare_dataset(config)

    trainer = DAPOTrainer(config)
    trainer.train(dataset)


if __name__ == "__main__":
    main()