
Training Orchestration: SFT to DPO Pipeline

Individual training scripts work in isolation. Production systems require orchestration—coordinating multiple stages, managing checkpoints, validating outputs, and handling failures gracefully.

This lesson teaches you to build an automated training pipeline that takes your Task API model from base checkpoint through SFT to DPO alignment. By the end, you'll have a reproducible, auditable training workflow.

The Training Pipeline

┌─────────────────────────────────────────────────────────────────────┐
│                   TRAINING ORCHESTRATION PIPELINE                   │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────────┐   │
│  │   STAGE 1    │      │   STAGE 2    │      │     STAGE 3      │   │
│  │     SFT      │─────▶│    Merge     │─────▶│  DPO Alignment   │   │
│  │              │      │   Adapter    │      │                  │   │
│  │  LoRA/QLoRA  │      │              │      │   Preference     │   │
│  │   Training   │      │    Base +    │      │  Optimization    │   │
│  │              │      │     LoRA     │      │                  │   │
│  └──────┬───────┘      └──────┬───────┘      └────────┬─────────┘   │
│         │                     │                       │             │
│         ▼                     ▼                       ▼             │
│  ┌──────────────┐      ┌──────────────┐      ┌──────────────────┐   │
│  │  Validation  │      │  Validation  │      │   Validation     │   │
│  │    Gate 1    │      │    Gate 2    │      │     Gate 3       │   │
│  │              │      │              │      │                  │   │
│  │  - Loss < X  │      │  - Merge OK  │      │  - Alignment OK  │   │
│  │  - No NaN    │      │  - Size OK   │      │  - Safety Pass   │   │
│  └──────────────┘      └──────────────┘      └────────┬─────────┘   │
│                                                       │             │
│                                                       ▼             │
│                                              ┌──────────────────┐   │
│                                              │   FINAL MODEL    │   │
│                                              │                  │   │
│                                              │  task-api-v1.0   │   │
│                                              │    (aligned)     │   │
│                                              └──────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

Project Structure

training-pipeline/
├── pipeline/
│   ├── __init__.py
│   ├── config.py              # Pipeline configuration
│   ├── stages/
│   │   ├── sft.py             # SFT training stage
│   │   ├── merge.py           # Adapter merge stage
│   │   └── dpo.py             # DPO alignment stage
│   ├── validators/
│   │   ├── sft_validator.py
│   │   ├── merge_validator.py
│   │   └── dpo_validator.py
│   └── orchestrator.py        # Main orchestration
├── configs/
│   ├── sft_config.yaml
│   ├── dpo_config.yaml
│   └── pipeline_config.yaml
├── artifacts/                 # Training outputs
├── run_pipeline.py
└── requirements.txt

Pipeline Configuration

Create pipeline/config.py:

from dataclasses import dataclass, field
from typing import Optional, List
from pathlib import Path
import yaml

@dataclass
class StageConfig:
    name: str
    enabled: bool = True
    config_path: Optional[str] = None
    timeout_minutes: int = 120
    retry_count: int = 1

@dataclass
class ValidationConfig:
    max_loss: float = 2.0
    min_accuracy: float = 0.7
    check_nan: bool = True
    custom_checks: List[str] = field(default_factory=list)

@dataclass
class PipelineConfig:
    pipeline_id: str
    base_model: str
    output_dir: Path
    stages: List[StageConfig]
    validation: ValidationConfig

    # Resource settings
    gpu_memory_fraction: float = 0.9
    checkpoint_every_n_steps: int = 100

    # Artifact settings
    save_intermediate: bool = True
    upload_to_hub: bool = False
    hub_repo_id: Optional[str] = None

def load_pipeline_config(config_path: str) -> PipelineConfig:
    """Load pipeline configuration from YAML."""
    with open(config_path) as f:
        data = yaml.safe_load(f)

    stages = [StageConfig(**s) for s in data.get("stages", [])]
    validation = ValidationConfig(**data.get("validation", {}))

    return PipelineConfig(
        pipeline_id=data["pipeline_id"],
        base_model=data["base_model"],
        output_dir=Path(data["output_dir"]),
        stages=stages,
        validation=validation,
        gpu_memory_fraction=data.get("gpu_memory_fraction", 0.9),
        checkpoint_every_n_steps=data.get("checkpoint_every_n_steps", 100),
        save_intermediate=data.get("save_intermediate", True),
        upload_to_hub=data.get("upload_to_hub", False),
        hub_repo_id=data.get("hub_repo_id")
    )

Create configs/pipeline_config.yaml:

pipeline_id: task-api-v1.0
base_model: unsloth/Qwen2.5-3B-Instruct-bnb-4bit
output_dir: ./artifacts

stages:
  - name: sft
    enabled: true
    config_path: ./configs/sft_config.yaml
    timeout_minutes: 180
    retry_count: 2

  - name: merge
    enabled: true
    timeout_minutes: 30
    retry_count: 1

  - name: dpo
    enabled: true
    config_path: ./configs/dpo_config.yaml
    timeout_minutes: 120
    retry_count: 2

validation:
  max_loss: 2.0
  min_accuracy: 0.7
  check_nan: true
  custom_checks:
    - output_format_valid
    - no_harmful_content

gpu_memory_fraction: 0.9
checkpoint_every_n_steps: 100
save_intermediate: true
upload_to_hub: false
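
Before wiring the loader into the orchestrator, it's worth a quick sanity check that the YAML round-trips into the dataclasses correctly. A minimal check (the printed values simply echo the config above):

from pipeline.config import load_pipeline_config

config = load_pipeline_config("configs/pipeline_config.yaml")
print(config.pipeline_id)                            # task-api-v1.0
print([s.name for s in config.stages if s.enabled])  # ['sft', 'merge', 'dpo']
print(config.validation.max_loss)                    # 2.0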

Stage Implementations

Stage 1: SFT Training

Create pipeline/stages/sft.py:

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional
import json
import yaml

@dataclass
class SFTResult:
    success: bool
    adapter_path: Path
    final_loss: float
    training_time_seconds: float
    metrics: Dict[str, Any]
    error: Optional[str] = None

def run_sft_stage(
    base_model: str,
    dataset_path: str,
    output_dir: Path,
    config_path: str
) -> SFTResult:
    """Execute SFT training stage."""
    from unsloth import FastLanguageModel
    from trl import SFTTrainer
    from transformers import TrainingArguments
    from datasets import load_dataset
    import time

    start_time = time.time()
    output_dir = Path(output_dir) / "sft"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load configuration
    with open(config_path) as f:
        config = yaml.safe_load(f)

    try:
        # Load model with LoRA
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=base_model,
            max_seq_length=config.get("max_seq_length", 2048),
            load_in_4bit=True
        )

        model = FastLanguageModel.get_peft_model(
            model,
            r=config.get("lora_r", 16),
            lora_alpha=config.get("lora_alpha", 32),
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            lora_dropout=config.get("lora_dropout", 0.05)
        )

        # Load dataset (expects a "text" column in the JSONL file)
        dataset = load_dataset("json", data_files=dataset_path)["train"]

        # Training arguments
        training_args = TrainingArguments(
            output_dir=str(output_dir / "checkpoints"),
            num_train_epochs=config.get("epochs", 3),
            per_device_train_batch_size=config.get("batch_size", 4),
            gradient_accumulation_steps=config.get("gradient_accumulation", 4),
            learning_rate=config.get("learning_rate", 2e-4),
            warmup_ratio=config.get("warmup_ratio", 0.1),
            logging_steps=10,
            save_steps=config.get("save_steps", 100),
            save_total_limit=3,
            fp16=True,
            report_to="none"
        )

        # Train
        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset,
            dataset_text_field="text",
            args=training_args,
            max_seq_length=config.get("max_seq_length", 2048)
        )

        train_result = trainer.train()

        # Save adapter
        adapter_path = output_dir / "adapter"
        model.save_pretrained(adapter_path)
        tokenizer.save_pretrained(adapter_path)

        # Record metrics (trainer.train() returns global_step and
        # training_loss; epoch count comes from the config)
        training_time = time.time() - start_time
        epochs = config.get("epochs", 3)
        metrics = {
            "final_loss": train_result.training_loss,
            "total_steps": train_result.global_step,
            "epochs_completed": epochs,
            "samples_processed": len(dataset) * epochs
        }

        # Save metrics
        with open(output_dir / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        return SFTResult(
            success=True,
            adapter_path=adapter_path,
            final_loss=train_result.training_loss,
            training_time_seconds=training_time,
            metrics=metrics
        )

    except Exception as e:
        return SFTResult(
            success=False,
            adapter_path=Path(),
            final_loss=float("inf"),
            training_time_seconds=time.time() - start_time,
            metrics={},
            error=str(e)
        )
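
The stage pulls its hyperparameters from configs/sft_config.yaml via the config.get() calls above. A minimal example matching those keys (values are illustrative starting points, not tuned recommendations):

max_seq_length: 2048
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
epochs: 3
batch_size: 4
gradient_accumulation: 4
learning_rate: 2.0e-4
warmup_ratio: 0.1
save_steps: 100

Because the trainer is given dataset_text_field="text", each line of the SFT JSONL dataset must carry a text field holding one fully formatted training example. A hypothetical record:

{"text": "<|im_start|>user\nAdd a task: review the Q3 report<|im_end|>\n<|im_start|>assistant\n{\"title\": \"Review the Q3 report\", \"status\": \"todo\"}<|im_end|>"}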

Stage 2: Adapter Merge

Create pipeline/stages/merge.py:

from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import json
import time

@dataclass
class MergeResult:
    success: bool
    merged_model_path: Path
    model_size_gb: float
    merge_time_seconds: float
    error: Optional[str] = None

def run_merge_stage(
    base_model: str,
    adapter_path: Path,
    output_dir: Path
) -> MergeResult:
    """Merge LoRA adapter with base model."""
    from unsloth import FastLanguageModel

    start_time = time.time()
    output_dir = Path(output_dir) / "merged"
    output_dir.mkdir(parents=True, exist_ok=True)

    try:
        # Load base model with adapter
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=str(adapter_path),
            max_seq_length=2048,
            load_in_4bit=False  # Full precision for merge
        )

        # Merge adapter into base weights
        merged_model = model.merge_and_unload()

        # Save merged model
        merged_path = output_dir / "model"
        merged_model.save_pretrained(merged_path)
        tokenizer.save_pretrained(merged_path)

        # Calculate size on disk
        total_size = sum(f.stat().st_size for f in merged_path.rglob("*") if f.is_file())
        size_gb = total_size / (1024**3)

        merge_time = time.time() - start_time

        # Save merge info
        with open(output_dir / "merge_info.json", "w") as f:
            json.dump({
                "base_model": base_model,
                "adapter_path": str(adapter_path),
                "merged_path": str(merged_path),
                "size_gb": size_gb,
                "merge_time_seconds": merge_time
            }, f, indent=2)

        return MergeResult(
            success=True,
            merged_model_path=merged_path,
            model_size_gb=size_gb,
            merge_time_seconds=merge_time
        )

    except Exception as e:
        return MergeResult(
            success=False,
            merged_model_path=Path(),
            model_size_gb=0.0,
            merge_time_seconds=time.time() - start_time,
            error=str(e)
        )

Stage 3: DPO Alignment

Create pipeline/stages/dpo.py:

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional
import json
import yaml
import time

@dataclass
class DPOResult:
    success: bool
    aligned_model_path: Path
    final_loss: float
    training_time_seconds: float
    metrics: Dict[str, Any]
    error: Optional[str] = None

def run_dpo_stage(
    model_path: Path,
    preference_dataset_path: str,
    output_dir: Path,
    config_path: str
) -> DPOResult:
    """Execute DPO alignment stage."""
    from unsloth import FastLanguageModel
    from trl import DPOTrainer, DPOConfig
    from datasets import load_dataset

    start_time = time.time()
    output_dir = Path(output_dir) / "dpo"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load configuration
    with open(config_path) as f:
        config = yaml.safe_load(f)

    try:
        # Load merged model with a fresh LoRA for DPO
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=str(model_path),
            max_seq_length=config.get("max_seq_length", 2048),
            load_in_4bit=True
        )

        model = FastLanguageModel.get_peft_model(
            model,
            r=config.get("lora_r", 8),  # Smaller rank for alignment
            lora_alpha=config.get("lora_alpha", 16),
            target_modules=["q_proj", "v_proj"],
            lora_dropout=config.get("lora_dropout", 0.05)
        )

        # Load preference dataset (prompt / chosen / rejected columns)
        dataset = load_dataset("json", data_files=preference_dataset_path)["train"]

        # DPO configuration
        dpo_config = DPOConfig(
            output_dir=str(output_dir / "checkpoints"),
            num_train_epochs=config.get("epochs", 1),
            per_device_train_batch_size=config.get("batch_size", 2),
            gradient_accumulation_steps=config.get("gradient_accumulation", 8),
            learning_rate=config.get("learning_rate", 5e-5),
            beta=config.get("beta", 0.1),
            warmup_ratio=0.1,
            logging_steps=10,
            save_steps=50,
            fp16=True,
            report_to="none"
        )

        # Train
        trainer = DPOTrainer(
            model=model,
            ref_model=None,  # PEFT model: base weights act as implicit reference
            args=dpo_config,
            train_dataset=dataset,
            tokenizer=tokenizer
        )

        train_result = trainer.train()

        # Save aligned model
        aligned_path = output_dir / "aligned"
        model.save_pretrained(aligned_path)
        tokenizer.save_pretrained(aligned_path)

        training_time = time.time() - start_time
        metrics = {
            "final_loss": train_result.training_loss,
            "total_steps": train_result.global_step,
            "beta_used": config.get("beta", 0.1)
        }

        with open(output_dir / "metrics.json", "w") as f:
            json.dump(metrics, f, indent=2)

        return DPOResult(
            success=True,
            aligned_model_path=aligned_path,
            final_loss=train_result.training_loss,
            training_time_seconds=training_time,
            metrics=metrics
        )

    except Exception as e:
        return DPOResult(
            success=False,
            aligned_model_path=Path(),
            final_loss=float("inf"),
            training_time_seconds=time.time() - start_time,
            metrics={},
            error=str(e)
        )
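
As with SFT, this stage reads its hyperparameters from configs/dpo_config.yaml. A minimal example matching the keys referenced above (values are illustrative):

max_seq_length: 2048
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
epochs: 1
batch_size: 2
gradient_accumulation: 8
learning_rate: 5.0e-5
beta: 0.1

DPOTrainer expects prompt, chosen, and rejected columns, so each line of the preference JSONL looks like this (hypothetical record):

{"prompt": "Create a task for the quarterly review", "chosen": "{\"title\": \"Quarterly review\", \"status\": \"todo\"}", "rejected": "Sure, I can help with tasks! What should it be called?"}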

Validation Gates

Create pipeline/validators/sft_validator.py:

from dataclasses import dataclass
from typing import List
from pathlib import Path
import json
import math

@dataclass
class ValidationResult:
    passed: bool
    checks: List[dict]
    errors: List[str]

def validate_sft_stage(
    adapter_path: Path,
    metrics_path: Path,
    max_loss: float = 2.0
) -> ValidationResult:
    """Validate SFT training results."""
    checks = []
    errors = []

    # Check 1: Adapter files exist
    required_files = ["adapter_config.json", "adapter_model.safetensors"]
    files_exist = all((adapter_path / f).exists() for f in required_files)
    checks.append({
        "name": "adapter_files_exist",
        "passed": files_exist,
        "details": f"Required files: {required_files}"
    })
    if not files_exist:
        errors.append("Missing adapter files")

    # Check 2: Loss below threshold
    if metrics_path.exists():
        with open(metrics_path) as f:
            metrics = json.load(f)

        final_loss = metrics.get("final_loss", float("inf"))
        loss_ok = final_loss < max_loss and not math.isnan(final_loss)
        checks.append({
            "name": "loss_acceptable",
            "passed": loss_ok,
            "details": f"Final loss: {final_loss:.4f}, threshold: {max_loss}"
        })
        if not loss_ok:
            errors.append(f"Loss {final_loss:.4f} exceeds threshold {max_loss}")

        # Check 3: No NaN in metrics
        has_nan = any(math.isnan(v) if isinstance(v, float) else False
                      for v in metrics.values())
        checks.append({
            "name": "no_nan_values",
            "passed": not has_nan,
            "details": "Checking metrics for NaN values"
        })
        if has_nan:
            errors.append("NaN values detected in metrics")
    else:
        errors.append("Metrics file not found")

    return ValidationResult(
        passed=len(errors) == 0,
        checks=checks,
        errors=errors
    )
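
The project layout also calls for pipeline/validators/merge_validator.py, which backs Gate 2 in the diagram (merge succeeded, size sane). The lesson doesn't list it, so here is a minimal sketch following the same ValidationResult pattern; the size bounds are illustrative assumptions, not recommendations:

from pathlib import Path
import json

from .sft_validator import ValidationResult  # reuse the same result shape

def validate_merge_stage(
    merged_model_path: Path,
    merge_info_path: Path,
    min_size_gb: float = 1.0,   # assumed bounds: a merged 3B model
    max_size_gb: float = 20.0   # should land well inside this range
) -> ValidationResult:
    """Validate merge results: weights present, size within bounds."""
    checks = []
    errors = []

    # Check 1: merged model directory contains weight files
    has_weights = merged_model_path.exists() and any(
        merged_model_path.glob("*.safetensors")
    )
    checks.append({
        "name": "merged_weights_exist",
        "passed": has_weights,
        "details": f"Looked for *.safetensors in {merged_model_path}"
    })
    if not has_weights:
        errors.append("Merged model weights not found")

    # Check 2: recorded size within plausible bounds
    if merge_info_path.exists():
        with open(merge_info_path) as f:
            info = json.load(f)
        size_gb = info.get("size_gb", 0.0)
        size_ok = min_size_gb <= size_gb <= max_size_gb
        checks.append({
            "name": "size_in_bounds",
            "passed": size_ok,
            "details": f"Size: {size_gb:.2f} GB, bounds: [{min_size_gb}, {max_size_gb}]"
        })
        if not size_ok:
            errors.append(f"Merged model size {size_gb:.2f} GB outside bounds")
    else:
        errors.append("merge_info.json not found")

    return ValidationResult(
        passed=len(errors) == 0,
        checks=checks,
        errors=errors
    )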

Main Orchestrator

Create pipeline/orchestrator.py:

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime
import json
import logging

from .config import PipelineConfig, load_pipeline_config
from .stages.sft import run_sft_stage, SFTResult
from .stages.merge import run_merge_stage, MergeResult
from .stages.dpo import run_dpo_stage, DPOResult
from .validators.sft_validator import validate_sft_stage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class PipelineResult:
    success: bool
    pipeline_id: str
    final_model_path: Optional[Path]
    stages_completed: List[str]
    total_time_seconds: float
    stage_results: Dict[str, Any]
    errors: List[str]

class TrainingOrchestrator:
    """Orchestrates the complete training pipeline."""

    def __init__(self, config: PipelineConfig):
        self.config = config
        self.config.output_dir.mkdir(parents=True, exist_ok=True)
        self.stage_results: Dict[str, Any] = {}
        self.errors: List[str] = []

    def run(
        self,
        sft_dataset: str,
        preference_dataset: str
    ) -> PipelineResult:
        """Execute the complete pipeline."""
        import time

        start_time = time.time()
        stages_completed = []
        current_model_path = None

        logger.info(f"Starting pipeline: {self.config.pipeline_id}")

        # Stage 1: SFT
        sft_stage = next((s for s in self.config.stages if s.name == "sft"), None)
        if sft_stage and sft_stage.enabled:
            logger.info("Running SFT stage...")

            sft_result = run_sft_stage(
                base_model=self.config.base_model,
                dataset_path=sft_dataset,
                output_dir=self.config.output_dir,
                config_path=sft_stage.config_path
            )

            self.stage_results["sft"] = {
                "success": sft_result.success,
                "loss": sft_result.final_loss,
                "time_seconds": sft_result.training_time_seconds
            }

            if not sft_result.success:
                self.errors.append(f"SFT failed: {sft_result.error}")
                return self._create_result(stages_completed, None, time.time() - start_time)

            # Validate SFT
            validation = validate_sft_stage(
                sft_result.adapter_path,
                self.config.output_dir / "sft" / "metrics.json",
                self.config.validation.max_loss
            )

            if not validation.passed:
                self.errors.extend(validation.errors)
                return self._create_result(stages_completed, None, time.time() - start_time)

            stages_completed.append("sft")
            current_model_path = sft_result.adapter_path
            logger.info(f"SFT complete. Loss: {sft_result.final_loss:.4f}")

        # Stage 2: Merge
        merge_stage = next((s for s in self.config.stages if s.name == "merge"), None)
        if merge_stage and merge_stage.enabled and current_model_path:
            logger.info("Running merge stage...")

            merge_result = run_merge_stage(
                base_model=self.config.base_model,
                adapter_path=current_model_path,
                output_dir=self.config.output_dir
            )

            self.stage_results["merge"] = {
                "success": merge_result.success,
                "size_gb": merge_result.model_size_gb,
                "time_seconds": merge_result.merge_time_seconds
            }

            if not merge_result.success:
                self.errors.append(f"Merge failed: {merge_result.error}")
                return self._create_result(stages_completed, None, time.time() - start_time)

            stages_completed.append("merge")
            current_model_path = merge_result.merged_model_path
            logger.info(f"Merge complete. Size: {merge_result.model_size_gb:.2f} GB")

        # Stage 3: DPO
        dpo_stage = next((s for s in self.config.stages if s.name == "dpo"), None)
        if dpo_stage and dpo_stage.enabled and current_model_path:
            logger.info("Running DPO stage...")

            dpo_result = run_dpo_stage(
                model_path=current_model_path,
                preference_dataset_path=preference_dataset,
                output_dir=self.config.output_dir,
                config_path=dpo_stage.config_path
            )

            self.stage_results["dpo"] = {
                "success": dpo_result.success,
                "loss": dpo_result.final_loss,
                "time_seconds": dpo_result.training_time_seconds
            }

            if not dpo_result.success:
                self.errors.append(f"DPO failed: {dpo_result.error}")
                return self._create_result(stages_completed, None, time.time() - start_time)

            stages_completed.append("dpo")
            current_model_path = dpo_result.aligned_model_path
            logger.info(f"DPO complete. Loss: {dpo_result.final_loss:.4f}")

        total_time = time.time() - start_time
        return self._create_result(stages_completed, current_model_path, total_time)

    def _create_result(
        self,
        stages_completed: List[str],
        final_path: Optional[Path],
        total_time: float
    ) -> PipelineResult:
        """Create final pipeline result."""
        result = PipelineResult(
            success=len(self.errors) == 0,
            pipeline_id=self.config.pipeline_id,
            final_model_path=final_path,
            stages_completed=stages_completed,
            total_time_seconds=total_time,
            stage_results=self.stage_results,
            errors=self.errors
        )

        # Save result
        result_path = self.config.output_dir / "pipeline_result.json"
        with open(result_path, "w") as f:
            json.dump({
                "success": result.success,
                "pipeline_id": result.pipeline_id,
                "final_model_path": str(result.final_model_path) if result.final_model_path else None,
                "stages_completed": result.stages_completed,
                "total_time_seconds": result.total_time_seconds,
                "stage_results": result.stage_results,
                "errors": result.errors,
                "timestamp": datetime.now().isoformat()
            }, f, indent=2)

        return result

Running the Pipeline

Create run_pipeline.py:

#!/usr/bin/env python
from pipeline.config import load_pipeline_config
from pipeline.orchestrator import TrainingOrchestrator
import argparse

def main():
    parser = argparse.ArgumentParser(description="Run training pipeline")
    parser.add_argument("--config", default="configs/pipeline_config.yaml")
    parser.add_argument("--sft-dataset", required=True)
    parser.add_argument("--preference-dataset", required=True)
    args = parser.parse_args()

    config = load_pipeline_config(args.config)
    orchestrator = TrainingOrchestrator(config)

    print(f"Starting pipeline: {config.pipeline_id}")
    print(f"Base model: {config.base_model}")
    print(f"Output: {config.output_dir}")
    print()

    result = orchestrator.run(
        sft_dataset=args.sft_dataset,
        preference_dataset=args.preference_dataset
    )

    print("\n" + "=" * 50)
    print("PIPELINE RESULT")
    print("=" * 50)
    print(f"Success: {result.success}")
    print(f"Stages completed: {', '.join(result.stages_completed)}")
    print(f"Total time: {result.total_time_seconds / 60:.1f} minutes")

    if result.final_model_path:
        print(f"Final model: {result.final_model_path}")

    if result.errors:
        print("\nErrors:")
        for error in result.errors:
            print(f"  - {error}")

if __name__ == "__main__":
    main()
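
Invoke the script from the project root. The dataset paths below are placeholders for your own files:

python run_pipeline.py \
  --config configs/pipeline_config.yaml \
  --sft-dataset data/task_api_sft.jsonl \
  --preference-dataset data/task_api_preferences.jsonl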

Output:

Starting pipeline: task-api-v1.0
Base model: unsloth/Qwen2.5-3B-Instruct-bnb-4bit
Output: ./artifacts

INFO:pipeline.orchestrator:Running SFT stage...
INFO:pipeline.orchestrator:SFT complete. Loss: 0.8234
INFO:pipeline.orchestrator:Running merge stage...
INFO:pipeline.orchestrator:Merge complete. Size: 6.24 GB
INFO:pipeline.orchestrator:Running DPO stage...
INFO:pipeline.orchestrator:DPO complete. Loss: 0.3421

==================================================
PIPELINE RESULT
==================================================
Success: True
Stages completed: sft, merge, dpo
Total time: 45.3 minutes
Final model: ./artifacts/dpo/aligned

Try With AI

Prompt 1: Add Resumption Support

My training pipeline sometimes fails mid-stage due to OOM or timeouts.
I need to add pipeline resumption capability:

1. Checkpoint the pipeline state after each stage
2. Detect incomplete runs on startup
3. Resume from the last successful stage
4. Handle partial stage completion (mid-training checkpoints)

Show me how to modify my orchestrator to support this, including
the checkpoint format and resumption logic.

What you're learning: Implementing fault-tolerant pipeline execution.

Prompt 2: Add Parallel Validation

My validation gates run sequentially after each stage. I want to:

1. Run some validations in parallel (format checks, size checks)
2. Keep blocking validations sequential (loss threshold, safety)
3. Add timeout handling for slow validations
4. Collect and report all validation results together

Modify my validator structure to support this pattern.

What you're learning: Optimizing validation workflows with parallelization.

Prompt 3: Design Multi-Run Experimentation

I want to run multiple pipeline configurations to find optimal hyperparameters:

1. Define a sweep space (learning rates, LoRA ranks, epochs)
2. Run pipelines in parallel (if GPU memory allows) or sequential
3. Track results in a comparison table
4. Identify best configuration based on validation metrics

Design an experimentation layer on top of my training pipeline.

What you're learning: Building systematic experimentation frameworks for LLMOps.