Reasoning Distillation from Larger Models
Your 3B model handles tool-calling reliably. But when tasks require multi-step reasoning—"Which tasks should I prioritize given my meeting schedule and deadlines?"—it struggles. Meanwhile, GPT-4 or Claude handles these effortlessly.
Knowledge distillation offers a path: train your small model on the reasoning traces of a larger model. The teacher "shows its work," and the student learns to replicate that reasoning process.
This lesson teaches you to distill reasoning capability from larger models into your Task API agent, giving it GPT-4-level task planning on a 3B budget.
The Reasoning Gap
What Small Models Miss
Your fine-tuned 3B model excels at:
- Single-tool selection (>95% accuracy)
- Argument extraction from clear requests
- JSON formatting reliability
But it struggles with:
- Multi-step task decomposition
- Priority reasoning across constraints
- Context-dependent decision making
Example: Where 3B Falls Short
User: "I have a client call at 2 PM and three tasks due today. The budget review takes 2 hours, expense report takes 30 minutes, and email cleanup takes 15 minutes. Which should I tackle first?"
3B Model Response:
{"tool_calls": [{"function": {"name": "list_tasks", "arguments": "{}"}}]}
The model retreats to a safe tool call instead of reasoning about prioritization.
GPT-4 Response:
Let me think through this:
1. Time constraint: Client call at 2 PM limits available blocks
2. Task durations: Budget review (2h) + expense (30m) + email (15m) = 2h 45m total
3. Current time: Assuming morning start, I have ~4 hours before 2 PM
Prioritization logic:
- Budget review (2h) is the biggest; if I wait, I might not finish before the call
- Expense report (30m) is quick; can fit before or after the big task
- Email cleanup (15m) is a buffer task
Recommendation: Start with budget review immediately (2 hours), then expense
report (30 min). Email cleanup can happen after the client call or in small
breaks. This ensures the time-sensitive, high-effort work is completed first.
GPT-4 reasons through constraints explicitly before suggesting actions. We want to transfer this reasoning capability.
Knowledge Distillation: Core Concept
The Teacher-Student Paradigm
┌─────────────────────────────────────────────────────────────────────┐
│ KNOWLEDGE DISTILLATION │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ TEACHER │ ──Reasoning───▶ │ STUDENT │ │
│ │ (GPT-4) │ Traces │ (3B) │ │
│ └──────────────┘ └──────────────┘ │
│ │
│ Capabilities: Learns: │
│ - Multi-step reasoning - Reasoning patterns │
│ - Constraint analysis - Explicit thought process │
│ - Priority inference - Task decomposition │
│ │
│ Cost: $0.01/request Cost: $0.001/request │
│ Latency: 1-2 seconds Latency: 200ms │
│ │
└─────────────────────────────────────────────────────────────────────┘
Why Chain-of-Thought Matters
Distillation works better when the teacher explicitly shows reasoning:
| Transfer Type | Example | Transfer Quality |
|---|---|---|
| Answer only | "Do budget review first" | Poor—no reasoning learned |
| Light reasoning | "Budget review first because it takes longest" | Moderate |
| Chain-of-thought | Full reasoning trace (time analysis, prioritization logic) | Best |
Explicit reasoning traces make internal knowledge external, enabling the student to learn the process, not just the conclusion.
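Concretely, the difference shows up in the assistant turn of each training example. A minimal sketch, with illustrative message contents:
# Answer-only target: the student learns the conclusion but not the process
answer_only = {"role": "assistant", "content": "Do the budget review first."}

# Chain-of-thought target: constraints, reasoning, then the recommendation
chain_of_thought = {
    "role": "assistant",
    "content": (
        "Constraints: client call at 2 PM; budget review (2h) and expense "
        "report (30m) due at 5 PM.\n"
        "Reasoning: the 2-hour block must start well before the call to finish.\n"
        "Recommendation: start the budget review now."
    ),
}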
Generating Distillation Data
Step 1: Define Reasoning Tasks
Identify tasks where your small model fails but the teacher succeeds:
REASONING_TASKS = [
{
"type": "prioritization",
"prompt": "I have tasks X, Y, Z with deadlines A, B, C. Which should I do first?",
"requires": ["time estimation", "deadline awareness", "priority logic"]
},
{
"type": "scheduling",
"prompt": "My calendar has gaps at [times]. Can I fit [task] before [event]?",
"requires": ["duration estimation", "buffer awareness", "constraint satisfaction"]
},
{
"type": "decomposition",
"prompt": "Break down [large task] into actionable subtasks",
"requires": ["task analysis", "dependency identification", "sequencing"]
},
{
"type": "conflict_resolution",
"prompt": "Tasks A and B both need resources C. How should I proceed?",
"requires": ["resource awareness", "tradeoff analysis", "recommendation"]
},
]
Step 2: Generate Teacher Responses
from openai import OpenAI
import json
client = OpenAI()
def generate_reasoning_trace(task: dict, context: dict) -> str:
"""Get chain-of-thought reasoning from teacher model."""
system_prompt = """You are a productivity reasoning expert.
When given a task prioritization or scheduling question:
1. First, explicitly list the constraints and variables
2. Then, reason through the implications step by step
3. Finally, provide a clear recommendation with justification
Show your full reasoning process. This will be used to train a smaller model."""
user_prompt = f"""Context:
- Current tasks: {json.dumps(context['tasks'])}
- Calendar: {json.dumps(context['calendar'])}
- Current time: {context['current_time']}
Question: {task['prompt']}
Think through this step by step, showing your reasoning explicitly."""
response = client.chat.completions.create(
model="gpt-4o",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.3 # Lower temp for consistent reasoning
)
return response.choices[0].message.content
# Generate example
context = {
"tasks": [
{"title": "Budget review", "due": "5 PM today", "duration_hours": 2},
{"title": "Expense report", "due": "5 PM today", "duration_hours": 0.5},
{"title": "Email cleanup", "due": "end of week", "duration_hours": 0.25}
],
"calendar": [
{"title": "Client call", "time": "2:00 PM - 3:00 PM"}
],
"current_time": "9:00 AM"
}
reasoning = generate_reasoning_trace(
{"prompt": "Which task should I tackle first and why?"},
context
)
print(reasoning)
Output:
Let me analyze this systematically:
**Constraints:**
- Current time: 9:00 AM
- Client call: 2:00 PM - 3:00 PM (blocked)
- Budget review: 2 hours, due 5 PM
- Expense report: 30 min, due 5 PM
- Email cleanup: 15 min, due end of week
**Available time blocks:**
- 9:00 AM - 2:00 PM: 5 hours available
- 3:00 PM - 5:00 PM: 2 hours available
**Reasoning:**
1. Both budget review and expense report are due today (5 PM)
2. Budget review (2h) is the largest chunk; needs uninterrupted time
3. Morning block (5h) is best for budget review—enough buffer
4. Expense report (30m) fits anywhere; can do before call or after
5. Email cleanup isn't urgent (end of week)—lowest priority
**Optimal sequence:**
1. Budget review first (9 AM - 11 AM)—highest effort, benefits from fresh morning focus
2. Expense report (11 AM - 11:30 AM)—quick win before potential interruptions
3. Buffer (11:30 AM - 2 PM)—prep for call, handle unexpected items
4. Client call (2-3 PM)
5. Email cleanup after call or defer to tomorrow
**Recommendation:** Start with the budget review immediately. It's the largest
task and has the tightest effective deadline (must complete before 2 PM if
you want time to review before submitting at 5 PM).
Step 3: Format for Training
Convert teacher outputs into training format:
def format_distillation_example(
user_input: str,
context: dict,
teacher_reasoning: str,
    final_action: dict | None = None
) -> dict:
"""Format distillation example for training."""
messages = [
{
"role": "system",
"content": """You are TaskMaster, a productivity assistant.
When asked about task prioritization or scheduling:
1. Analyze constraints explicitly
2. Reason through options step by step
3. Provide clear recommendation
4. If action needed, call appropriate tool
Always show your reasoning before taking action."""
},
{
"role": "user",
"content": f"Context:\n{json.dumps(context, indent=2)}\n\nQuestion: {user_input}"
},
{
"role": "assistant",
"content": teacher_reasoning
}
]
    # Optionally attach the tool call to the reasoning message. Merging
    # content and tool_calls into one assistant turn avoids back-to-back
    # assistant messages, which most chat fine-tuning formats reject.
    if final_action:
        messages[-1]["tool_calls"] = [final_action]
return {"messages": messages}
# Generate training example
example = format_distillation_example(
user_input="Which task should I tackle first and why?",
context=context,
teacher_reasoning=reasoning,
final_action={
"id": "call_001",
"type": "function",
"function": {
"name": "set_priority",
"arguments": json.dumps({
"task_id": "budget_review",
"priority": "immediate",
"reason": "Largest task, tightest effective deadline"
})
}
}
)
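With the tool call merged into the reasoning turn, the resulting record has this shape (reasoning and arguments truncated):
{"messages": [
  {"role": "system", "content": "You are TaskMaster, a productivity assistant. ..."},
  {"role": "user", "content": "Context:\n{...}\n\nQuestion: Which task should I tackle first and why?"},
  {"role": "assistant",
   "content": "Let me analyze this systematically: ...",
   "tool_calls": [{"id": "call_001", "type": "function",
                   "function": {"name": "set_priority", "arguments": "{...}"}}]}
]}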
Step 4: Scale Data Generation
import random
from tqdm import tqdm
def generate_distillation_dataset(
task_templates: list,
num_examples: int = 200
) -> list:
"""Generate complete distillation dataset."""
examples = []
for _ in tqdm(range(num_examples)):
# Randomize context
context = generate_random_context()
# Select random task type
task = random.choice(task_templates)
# Get teacher reasoning
reasoning = generate_reasoning_trace(task, context)
# Format for training
example = format_distillation_example(
user_input=task["prompt"],
context=context,
teacher_reasoning=reasoning
)
examples.append(example)
return examples
# Generate dataset
dataset = generate_distillation_dataset(REASONING_TASKS, num_examples=200)
# Save
with open("distillation_dataset.jsonl", "w") as f:
for ex in dataset:
f.write(json.dumps(ex) + "\n")
print(f"Generated {len(dataset)} distillation examples")
Output:
100%|██████████████████████████████████████| 200/200 [15:32<00:00, 4.66s/it]
Generated 200 distillation examples
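The helper generate_random_context is used above but never defined. One possible sketch, randomizing task counts, durations, and calendar events (all specific values here are arbitrary choices):
import random

def generate_random_context() -> dict:
    """Build a randomized task/calendar context for one distillation example."""
    titles = ["Budget review", "Expense report", "Email cleanup",
              "Server migration", "Doc update", "Bug fix"]
    tasks = [
        {
            "title": t,
            "due": random.choice(["5 PM today", "noon today", "end of week"]),
            "duration_hours": random.choice([0.25, 0.5, 1, 2, 4]),
        }
        for t in random.sample(titles, k=random.randint(2, 4))
    ]
    calendar = [{"title": "Client call",
                 "time": random.choice(["10:00 AM - 11:00 AM",
                                        "2:00 PM - 3:00 PM"])}]
    current_time = random.choice(["8:00 AM", "9:00 AM", "10:00 AM"])
    return {"tasks": tasks, "calendar": calendar, "current_time": current_time}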
Training the Student Model
Mix Distillation with Existing Data
Don't replace your existing training data—augment it:
def create_mixed_dataset(
existing_data: str,
distillation_data: str,
ratio: float = 0.3 # 30% distillation examples
) -> str:
"""Mix distillation examples with existing training data."""
# Load existing
with open(existing_data) as f:
existing = [json.loads(line) for line in f]
# Load distillation
with open(distillation_data) as f:
distillation = [json.loads(line) for line in f]
# Calculate mix
target_distillation = int(len(existing) * ratio / (1 - ratio))
sampled_distillation = random.sample(
distillation,
min(target_distillation, len(distillation))
)
# Combine and shuffle
combined = existing + sampled_distillation
random.shuffle(combined)
# Save
output = existing_data.replace(".jsonl", "_mixed.jsonl")
with open(output, "w") as f:
for ex in combined:
f.write(json.dumps(ex) + "\n")
print(f"Mixed dataset: {len(existing)} existing + {len(sampled_distillation)} distillation = {len(combined)} total")
return output
mixed_path = create_mixed_dataset(
"task_api_agentic_complete.jsonl",
"distillation_dataset.jsonl",
ratio=0.3
)
Output:
Mixed dataset: 600 existing + 257 distillation = 857 total
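To sanity-check the mix: solving d / (e + d) = r for the number of distillation examples d gives d = e · r / (1 − r). With e = 600 existing examples and r = 0.3, d = 600 × 0.3 / 0.7 ≈ 257, and 257 / 857 ≈ 30%, matching the output above.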
Training Configuration for Distillation
from transformers import TrainingArguments

# Adjusted training for distillation (longer sequences, more epochs)
training_args = TrainingArguments(
output_dir="./task_api_distilled",
num_train_epochs=5, # More epochs for reasoning patterns
per_device_train_batch_size=1, # Longer sequences need smaller batch
gradient_accumulation_steps=8,
learning_rate=1e-5, # Slightly lower for complex patterns
warmup_ratio=0.15,
max_grad_norm=0.5, # Careful with gradient on reasoning
logging_steps=25,
save_steps=100,
eval_strategy="steps",
eval_steps=100,
fp16=True,
optim="adamw_8bit",
)
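These arguments plug into the same trainer setup used in earlier fine-tuning lessons. A minimal sketch with TRL's SFTTrainer, assuming `model` is already loaded and noting that exact argument names vary across TRL versions:
from datasets import load_dataset
from trl import SFTTrainer

# Load the mixed dataset and hold out 10% so eval_strategy="steps" has data
dataset = load_dataset("json", data_files=mixed_path, split="train")
splits = dataset.train_test_split(test_size=0.1)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=splits["train"],
    eval_dataset=splits["test"],
)
trainer.train()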
Evaluating Reasoning Transfer
Design Reasoning Benchmark
REASONING_BENCHMARK = [
{
"id": "prio_001",
"category": "prioritization",
"input": "I have 3 tasks due today: server migration (4h), bug fix (30m), doc update (1h). My morning is blocked with meetings until 11 AM. It's 8 AM now. What should I do?",
"expected_reasoning_elements": [
"recognizes morning block",
"calculates available time",
"sequences by duration/urgency",
"provides actionable recommendation"
],
"expected_conclusion": "server_migration_first_or_defer"
},
{
"id": "sched_001",
"category": "scheduling",
"input": "Can I fit a 2-hour deep work session today? My calendar: standup 9-9:30, lunch 12-1, review 3-4.",
"expected_reasoning_elements": [
"identifies gaps",
"calculates gap durations",
"evaluates 2h fit"
],
"expected_conclusion": "9:30-11:30 or 1-3 available"
},
# ... 50+ benchmark cases
]
def evaluate_reasoning(model, benchmark: list) -> dict:
"""Evaluate model reasoning against teacher benchmark."""
results = {
"total": len(benchmark),
"reasoning_coverage": 0,
"conclusion_correct": 0,
"by_category": {}
}
for case in benchmark:
# Get model response
response = generate_response(model, case["input"])
# Check reasoning elements
elements_found = sum(
1 for elem in case["expected_reasoning_elements"]
if element_present(response, elem)
)
coverage = elements_found / len(case["expected_reasoning_elements"])
results["reasoning_coverage"] += coverage
# Check conclusion
if conclusion_matches(response, case["expected_conclusion"]):
results["conclusion_correct"] += 1
# Category stats
cat = case["category"]
if cat not in results["by_category"]:
results["by_category"][cat] = {"count": 0, "coverage": 0, "correct": 0}
results["by_category"][cat]["count"] += 1
results["by_category"][cat]["coverage"] += coverage
if conclusion_matches(response, case["expected_conclusion"]):
results["by_category"][cat]["correct"] += 1
# Normalize
results["reasoning_coverage"] /= results["total"]
results["conclusion_correct"] /= results["total"]
for cat in results["by_category"]:
cat_data = results["by_category"][cat]
cat_data["coverage"] /= cat_data["count"]
cat_data["correct"] /= cat_data["count"]
return results
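Here, generate_response is assumed to wrap the model's chat inference from earlier lessons. The helpers element_present and conclusion_matches are never defined above; a minimal keyword-based sketch follows. The keyword map and matching logic are illustrative assumptions, and an LLM-as-judge check is typically more robust:
# Hypothetical keyword map: a reasoning element counts as present if any of
# its indicator phrases appears in the response text
ELEMENT_KEYWORDS = {
    "identifies gaps": ["gap", "free block", "open slot"],
    "calculates gap durations": ["hours", "minutes", "2h", "30m"],
    "evaluates 2h fit": ["fits", "doesn't fit", "enough time"],
}

def element_present(response: str, element: str) -> bool:
    """Check whether a reasoning element shows up in the response."""
    keywords = ELEMENT_KEYWORDS.get(element, [element])
    return any(kw.lower() in response.lower() for kw in keywords)

def conclusion_matches(response: str, expected: str) -> bool:
    """Loose match: every token of the expected conclusion appears."""
    tokens = expected.replace("_", " ").split()
    return all(tok.lower() in response.lower() for tok in tokens)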
Compare Student to Teacher
def compare_student_teacher(
student_model,
teacher_api: str, # "gpt-4o"
benchmark: list
) -> dict:
"""Compare student performance to teacher baseline."""
print("Evaluating teacher (GPT-4)...")
teacher_results = evaluate_reasoning_with_api(teacher_api, benchmark)
print("Evaluating student (distilled 3B)...")
student_results = evaluate_reasoning(student_model, benchmark)
comparison = {
"teacher": teacher_results,
"student": student_results,
"gap": {
"reasoning_coverage": teacher_results["reasoning_coverage"] - student_results["reasoning_coverage"],
"conclusion_correct": teacher_results["conclusion_correct"] - student_results["conclusion_correct"]
}
}
print("\n=== DISTILLATION EVALUATION ===")
print(f"Reasoning Coverage: Teacher={teacher_results['reasoning_coverage']:.2%}, "
f"Student={student_results['reasoning_coverage']:.2%}, "
f"Gap={comparison['gap']['reasoning_coverage']:.2%}")
print(f"Conclusion Correct: Teacher={teacher_results['conclusion_correct']:.2%}, "
f"Student={student_results['conclusion_correct']:.2%}, "
f"Gap={comparison['gap']['conclusion_correct']:.2%}")
return comparison
results = compare_student_teacher(distilled_model, "gpt-4o", REASONING_BENCHMARK)
Output:
Evaluating teacher (GPT-4)...
Evaluating student (distilled 3B)...
=== DISTILLATION EVALUATION ===
Reasoning Coverage: Teacher=94.2%, Student=78.5%, Gap=15.7%
Conclusion Correct: Teacher=91.0%, Student=73.2%, Gap=17.8%
The distilled student recovers about 83% of the teacher's reasoning coverage (78.5% vs 94.2%), a substantial improvement over the non-distilled baseline.
Advanced: GRPO for Reasoning Distillation
Emerging Technique: Group Relative Policy Optimization
GRPO (Group Relative Policy Optimization) is an emerging technique that improves reasoning distillation by:
- Sampling multiple reasoning traces per problem
- Scoring each trace for correctness (for example, with the teacher as judge)
- Training on each trace's score relative to the rest of its group, not just imitation
# Conceptual GRPO workflow (simplified)
def grpo_distillation_step(problem: str, teacher: str, student_model):
    """One GRPO distillation step."""
    # 1. Sample a group of N reasoning traces from the student
    student_traces = [
        generate_response(student_model, problem)
        for _ in range(8)
    ]
    # 2. Score each trace, using the teacher as judge
    scores = [
        evaluate_with_teacher(teacher, problem, trace)
        for trace in student_traces
    ]
    # 3. Compute group-relative advantages: normalize each score against
    #    the group mean and standard deviation
    mean = sum(scores) / len(scores)
    std = (sum((s - mean) ** 2 for s in scores) / len(scores)) ** 0.5
    advantages = [(s - mean) / (std + 1e-8) for s in scores]
    # 4. Update the student with a policy-gradient step weighted by these
    #    advantages (compute_grpo_loss stands in for the clipped objective)
    loss = compute_grpo_loss(student_model, problem, student_traces, advantages)
    return loss
This is beyond the scope of hands-on implementation here, but represents the frontier of reasoning distillation research as of 2025.
Reflect on Your Skill
Update your model-merging skill to include distillation:
- Add "Reasoning Enhancement": When to distill vs when to merge
- Add distillation pipeline: Data generation → training → evaluation
- Add quality thresholds: What gap is acceptable for production use
- Add cost analysis: Distillation data generation cost vs inference savings
Try With AI
Prompt 1: Improve Distillation Data
My distilled model reaches 78% reasoning coverage vs teacher's 94%.
I want to close this gap. Here's a sample of my distillation data:
[paste 2-3 examples]
Analyze my examples:
1. Is the reasoning trace detailed enough?
2. Am I capturing the right reasoning patterns?
3. What additional examples or patterns would help?
4. Should I adjust my teacher prompt?
What you're learning: Data quality diagnosis—understanding why distillation might underperform.
Prompt 2: Combine Distillation with Merging
I now have three adapters:
1. Persona adapter (TaskMaster voice)
2. Agentic adapter (tool-calling)
3. Distilled reasoning adapter (prioritization logic)
How should I combine these? Consider:
1. Should I merge all three with TIES?
2. Or train a single model on all three datasets?
3. Any risk of interference between reasoning and tool-calling?
Walk me through the tradeoffs and recommend an approach.
What you're learning: Capability composition—combining multiple enhancement techniques.
Prompt 3: Measure Cost-Benefit
My distillation project:
- 200 examples × $0.01/call = $2 data generation
- Training: 2 GPU hours = $4
- Total: $6
Result: 78% teacher-level reasoning on 3B model
Help me analyze:
1. At what inference volume does this investment pay off?
2. If I spend $60 (10x), can I expect 90% coverage?
3. What's the diminishing returns curve for distillation data?
What you're learning: ROI analysis—treating distillation as a quantifiable investment.
Safety Note
Distilled models inherit both the capabilities and potential biases of teacher models. If GPT-4 exhibits certain reasoning patterns or blind spots, the distilled student may replicate them. Evaluate distilled models for bias amplification, especially in prioritization tasks where subtle preferences could lead to unfair outcomes. Never assume teacher quality equals distilled quality—validate independently.
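As a starting point, one lightweight audit sketch: sample many randomized contexts and count which task the model recommends first. Here, recommend_first_task is a hypothetical wrapper around your inference call:
from collections import Counter

def audit_first_choice(model, n: int = 100) -> Counter:
    """Count which task the model recommends first across random contexts."""
    counts = Counter()
    for _ in range(n):
        context = generate_random_context()
        choice = recommend_first_task(model, context)  # hypothetical helper
        counts[choice] += 1
    return counts

# A heavy skew toward one task title, independent of deadlines or durations,
# suggests the model inherited a preference rather than a reasoning pattern.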