Sharded Merging for RAM-Constrained Systems
Your laptop has 16GB RAM. You want to merge two 8B-parameter adapters. Naive merging attempts to load both models plus the output—40+ GB of memory. Your system crashes.
This lesson teaches sharded merging: processing models layer-by-layer with lazy loading, keeping memory usage constant regardless of model size. You'll merge models that shouldn't fit in your RAM.
The Memory Problem
Why Naive Merging Fails
Standard merging loads entire models into memory:
Standard Merge Memory Usage:
┌──────────────────────────────────────────────────┐
│                                                  │
│   Model A (full):    16 GB                       │
│   Model B (full):    16 GB                       │
│   Output buffer:      8 GB (growing)             │
│   Working memory:     4 GB                       │
│   ─────────────────────────                      │
│   TOTAL:             44 GB  ❌ Exceeds 16GB RAM  │
│                                                  │
└──────────────────────────────────────────────────┘
Even with swap, this is impractical—swapping multi-GB tensors causes minutes-long pauses.
Memory Profile of Naive Merge
import torch
import psutil
import os

def profile_naive_merge(model_a_path: str, model_b_path: str):
    """Profile memory usage during a naive full-load merge.

    Uses psutil to read resident set size (RSS); tracemalloc would miss
    PyTorch's C++-side tensor allocations.
    """
    process = psutil.Process(os.getpid())

    def rss_gb() -> float:
        return process.memory_info().rss / 1e9

    # Load Model A - PEAK 1
    print("Loading Model A...")
    model_a = torch.load(f"{model_a_path}/pytorch_model.bin", map_location="cpu")
    print(f"  After Model A: {rss_gb():.1f} GB resident")

    # Load Model B - PEAK 2
    print("Loading Model B...")
    model_b = torch.load(f"{model_b_path}/pytorch_model.bin", map_location="cpu")
    print(f"  After Model B: {rss_gb():.1f} GB resident")

    # Merge - PEAK 3: a third full copy of the weights
    print("Merging...")
    merged = {}
    for key in model_a.keys():
        merged[key] = 0.5 * model_a[key] + 0.5 * model_b[key]
    print(f"  After Merge: {rss_gb():.1f} GB resident")

    return merged
Output:
Loading Model A...
  After Model A: 16.2 GB resident
Loading Model B...
  After Model B: 32.4 GB resident
Merging...
  After Merge: 44.1 GB resident ❌ OOM on a 16GB system
The Solution: Sharded Merging
Layer-by-Layer Processing
Instead of loading entire models, process one layer at a time:
Sharded Merge Memory Usage:
┌──────────────────────────────────────────────────┐
│                                                  │
│   Layer N from Model A:   0.5 GB                 │
│   Layer N from Model B:   0.5 GB                 │
│   Layer N merged:         0.5 GB                 │
│   Working memory:         0.5 GB                 │
│   ─────────────────────────                      │
│   TOTAL:                   2 GB  ✅ Fits in 16GB │
│                                                  │
└──────────────────────────────────────────────────┘
The key insight: each layer can be processed independently, then immediately written to disk.
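As a preview of the pattern, here is a minimal sketch of that loop using safetensors lazy file handles. The paths and the merge_one rule are hypothetical placeholders; the full TIES implementation appears later in this lesson.

from safetensors import safe_open
from safetensors.torch import save_file

def merge_one(a, b):
    # Placeholder merge rule: simple 50/50 average
    return 0.5 * a + 0.5 * b

# Hypothetical single-file checkpoints, for illustration only
path_a, path_b = "model_a/model.safetensors", "model_b/model.safetensors"

merged = {}
with safe_open(path_a, framework="pt", device="cpu") as fa, \
     safe_open(path_b, framework="pt", device="cpu") as fb:
    for name in fa.keys():
        # Only one tensor from each model is resident at a time
        merged[name] = merge_one(fa.get_tensor(name), fb.get_tensor(name))
save_file(merged, "merged/model.safetensors")

This sketch still accumulates the merged dict in RAM; the full implementation later in this lesson flushes completed shards to disk as soon as they hit a size limit, which is what keeps memory constant.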
Memory Comparison
| Approach | Peak RAM | 8B Model Feasible on 16GB? |
|---|---|---|
| Naive | 40+ GB | No |
| Sharded (layer-by-layer) | 2-4 GB | Yes |
| Lazy + Sharded | 1-2 GB | Yes (with margin) |
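To sanity-check the table's first row against your own hardware, here is a rough feasibility estimate. It is an illustrative helper, assuming fp16 weights (2 bytes per parameter) and the overhead figures from the diagram above.

def naive_merge_fits(params_billion: float, ram_gb: float,
                     num_models: int = 2, overhead_gb: float = 4.0) -> bool:
    """Rough peak-RAM estimate for a naive (full-load) merge."""
    model_gb = params_billion * 2      # fp16: 2 bytes per parameter
    output_gb = model_gb / 2           # partially built output, as in the diagram
    peak_gb = num_models * model_gb + output_gb + overhead_gb
    print(f"Estimated peak: {peak_gb:.0f} GB vs {ram_gb:.0f} GB RAM")
    return peak_gb <= ram_gb

naive_merge_fits(8.0, 16.0)   # Estimated peak: 44 GB vs 16 GB RAM -> False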
MergeKit Lazy Loading
MergeKit supports lazy loading through the --lazy-unpickle flag:
Basic Lazy Merge
mergekit-yaml config.yaml ./output --lazy-unpickle --out-shard-size 1B
Flags:
- --lazy-unpickle: load model weights on demand instead of upfront (experimental lazy unpickler)
- --low-cpu-memory: keep tensors on the GPU when one is available, reducing CPU RAM pressure
- --out-shard-size: number of parameters per output shard (e.g. 1B, roughly 2 GB at fp16)
- --no-cuda: force CPU processing (if GPU RAM is also limited)
Lazy Merge Configuration
# lazy_merge.yaml
merge_method: ties
slices:
  - sources:
      - model: ./adapters/task_api_persona
        layer_range: [0, 28]
      - model: ./adapters/task_api_agentic
        layer_range: [0, 28]
parameters:
  weight: 0.5
  density: 0.5
base_model: unsloth/Llama-3.2-3B-Instruct
dtype: float16
# Memory optimization
tokenizer_source: base
Note that output shard size is a command-line option (--out-shard-size, measured in parameters), not a YAML key, so it appears in the command below.
Execute with Memory Monitoring
# Monitor memory during merge (GNU time; macOS's /usr/bin/time lacks -v)
/usr/bin/time -v mergekit-yaml lazy_merge.yaml ./merged_lazy \
  --lazy-unpickle --out-shard-size 1B 2>&1 | tee merge_log.txt
# Check peak memory
grep "Maximum resident set size" merge_log.txt
Output:
Loading config...
Processing layer 0/28 (lazy)
Loaded: model.layers.0.self_attn.q_proj (32MB)
Loaded: model.layers.0.self_attn.k_proj (32MB)
Merging...
Written to shard
Freed memory
Processing layer 1/28 (lazy)
...
Processing layer 27/28 (lazy)
Merge complete.
Maximum resident set size (kbytes): 4523648 # ~4.5 GB peak ✅
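Since macOS's /usr/bin/time does not support -v, a portable alternative is to launch the merge as a subprocess and read the child's peak resident set size from the standard-library resource module. This is a sketch; the merge command is the same one shown above.

import resource
import subprocess
import sys

# Launch the merge as a child process
cmd = ["mergekit-yaml", "lazy_merge.yaml", "./merged_lazy", "--lazy-unpickle"]
subprocess.run(cmd, check=True)

# Peak RSS across waited-for children.
# ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
peak = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss
scale = 1e6 if sys.platform.startswith("linux") else 1e9
print(f"Peak child RSS: {peak / scale:.2f} GB")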
Implementing Custom Sharded Merge
For maximum control, implement layer-by-layer merging yourself:
Sharded Merge Script
import gc
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file

def sharded_ties_merge(
    base_model: str,
    adapters: list[str],
    output_dir: str,
    density: float = 0.5,
    max_shard_gb: float = 2.0,
) -> None:
    """Memory-efficient TIES merge, processing one tensor at a time.

    base_model must point to a local directory of safetensors shards;
    TIES operates on task vectors (deltas from the base weights), so the
    base model is read tensor-by-tensor alongside the adapters.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Get tensor names from the first adapter (one key per weight tensor)
    layer_names = get_layer_names(adapters[0])
    print(f"Processing {len(layer_names)} tensors")

    current_shard = {}
    current_shard_size = 0
    shard_index = 0
    max_shard_bytes = max_shard_gb * 1e9

    for layer_name in layer_names:
        print(f"Processing: {layer_name}")

        # Load this tensor from the base and from each adapter (one at a time)
        base_weight = load_single_layer(base_model, layer_name)
        deltas = [
            load_single_layer(adapter_path, layer_name) - base_weight
            for adapter_path in adapters
        ]

        # Apply TIES to the task vectors, then add the result back to the base
        merged_weight = base_weight + ties_merge_layer(deltas, density)

        # Add to current shard
        current_shard[layer_name] = merged_weight
        current_shard_size += merged_weight.numel() * merged_weight.element_size()

        # Write shard if it exceeds the limit
        if current_shard_size >= max_shard_bytes:
            shard_file = output_path / f"model-{shard_index:05d}-of-XXXXX.safetensors"
            save_file(current_shard, str(shard_file))
            print(f"  Wrote shard {shard_index}: {current_shard_size / 1e9:.2f} GB")
            current_shard = {}
            current_shard_size = 0
            shard_index += 1

        # Clear per-tensor buffers from memory
        del deltas, base_weight
        gc.collect()

    # Write final shard
    if current_shard:
        shard_file = output_path / f"model-{shard_index:05d}-of-XXXXX.safetensors"
        save_file(current_shard, str(shard_file))
        print(f"  Wrote final shard {shard_index}: {current_shard_size / 1e9:.2f} GB")

    # Replace the XXXXX placeholder with the final shard count
    update_shard_names(output_path, shard_index + 1)
    print(f"Merge complete: {shard_index + 1} shards in {output_dir}")

def get_layer_names(model_path: str) -> list[str]:
    """Collect every tensor name across a model's safetensors shards."""
    names = []
    for shard_file in sorted(Path(model_path).glob("*.safetensors")):
        with safe_open(str(shard_file), framework="pt", device="cpu") as f:
            names.extend(f.keys())
    return names

def update_shard_names(output_path: Path, total_shards: int) -> None:
    """Rename model-NNNNN-of-XXXXX.safetensors files with the real shard count."""
    for shard_file in output_path.glob("model-*-of-XXXXX.safetensors"):
        new_name = shard_file.name.replace("XXXXX", f"{total_shards:05d}")
        shard_file.rename(output_path / new_name)

def load_single_layer(model_path: str, layer_name: str) -> torch.Tensor:
    """Load only one tensor from a (possibly sharded) safetensors checkpoint."""
    # Find which shard contains this tensor
    for shard_file in Path(model_path).glob("*.safetensors"):
        with safe_open(str(shard_file), framework="pt", device="cpu") as f:
            if layer_name in f.keys():
                return f.get_tensor(layer_name)
    raise KeyError(f"Tensor {layer_name} not found in {model_path}")

def ties_merge_layer(weights: list[torch.Tensor], density: float) -> torch.Tensor:
    """Apply TIES merge to the task vectors for a single tensor."""
    # Stack task vectors: shape (num_models, *tensor_shape)
    stacked = torch.stack(weights).float()

    # Trim: keep the top density fraction by magnitude, per model
    flat = stacked.view(len(weights), -1)
    magnitudes = flat.abs()
    # Note: torch.quantile rejects very large inputs (~16M+ elements);
    # switch to torch.kthvalue or topk for very large layers
    threshold = torch.quantile(magnitudes, 1 - density, dim=1, keepdim=True)
    trimmed = flat * (magnitudes >= threshold)

    # Elect signs by majority across models
    elected_sign = trimmed.sign().sum(dim=0).sign()

    # Average the trimmed values that match the elected sign
    aligned = trimmed * (trimmed.sign() == elected_sign)
    counts = (aligned != 0).sum(dim=0).clamp(min=1)
    merged = aligned.sum(dim=0) / counts
    return merged.view(weights[0].shape).to(weights[0].dtype)
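One caveat: Hugging Face transformers expects a model.safetensors.index.json beside sharded weights, which the script above does not emit. A small post-processing step can generate it; this is a sketch following the standard index layout, and config.json plus tokenizer files still need to be copied from the base model separately.

import json
from pathlib import Path
from safetensors import safe_open

def write_safetensors_index(output_dir: str) -> None:
    """Generate model.safetensors.index.json for a directory of shards."""
    output_path = Path(output_dir)
    weight_map, total_size = {}, 0
    for shard in sorted(output_path.glob("model-*.safetensors")):
        with safe_open(str(shard), framework="pt", device="cpu") as f:
            for name in f.keys():
                weight_map[name] = shard.name
                t = f.get_tensor(name)  # one tensor at a time, then discarded
                total_size += t.numel() * t.element_size()
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    (output_path / "model.safetensors.index.json").write_text(json.dumps(index, indent=2))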
Memory-Monitored Execution
import os
import threading

import psutil

def run_with_monitoring(merge_fn, *args, **kwargs):
    """Run a merge function while sampling memory on a background thread."""
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss
    peak_memory = initial_memory
    stop = threading.Event()

    def sampler():
        # Poll resident set size every 100 ms until the merge finishes;
        # without this, any peak between start and finish would be missed.
        nonlocal peak_memory
        while not stop.is_set():
            peak_memory = max(peak_memory, process.memory_info().rss)
            stop.wait(0.1)

    thread = threading.Thread(target=sampler, daemon=True)
    thread.start()
    try:
        result = merge_fn(*args, **kwargs)
    finally:
        stop.set()
        thread.join()

    final_memory = process.memory_info().rss
    print("\nMemory Report:")
    print(f"  Initial: {initial_memory / 1e9:.2f} GB")
    print(f"  Peak: {peak_memory / 1e9:.2f} GB")
    print(f"  Final: {final_memory / 1e9:.2f} GB")
    print(f"  Delta: {(peak_memory - initial_memory) / 1e9:.2f} GB")
    return result

# Run sharded merge with monitoring. The base model must be a local
# safetensors snapshot, e.g. downloaded with:
#   huggingface-cli download unsloth/Llama-3.2-3B-Instruct --local-dir ./base
run_with_monitoring(
    sharded_ties_merge,
    base_model="./base",
    adapters=["./adapters/persona", "./adapters/agentic"],
    output_dir="./merged_sharded",
    density=0.5,
    max_shard_gb=2.0,
)
Output:
Processing 255 tensors
Processing: model.embed_tokens.weight
Processing: model.layers.0.self_attn.q_proj.weight
...
  Wrote shard 0: 2.01 GB
...
Processing: model.layers.27.mlp.down_proj.weight
  Wrote final shard 7: 1.45 GB
Merge complete: 8 shards in ./merged_sharded
Memory Report:
  Initial: 0.42 GB
  Peak: 3.21 GB
  Delta: 2.79 GB ✅ Well under 16GB limit
Tuning for Your Hardware
Finding Optimal Shard Size
def find_optimal_shard_size(
    available_ram_gb: float,
    layer_size_gb: float,
    num_adapters: int,
) -> float:
    """Calculate an output shard size that fits your hardware."""
    # Memory needed while processing one layer:
    #   - one copy per adapter:      num_adapters * layer_size
    #   - working memory for merge:  ~2x layer_size
    #   - output buffer:             layer_size
    per_layer_memory = (num_adapters + 3) * layer_size_gb

    # Leave 2GB headroom for the system
    available_for_merge = available_ram_gb - 2.0

    # Number of layers we can buffer before flushing a shard
    max_buffered_layers = int(available_for_merge / per_layer_memory)

    optimal_shard_gb = max_buffered_layers * layer_size_gb
    print("Hardware Analysis:")
    print(f"  Available RAM: {available_ram_gb} GB")
    print(f"  Per-layer memory: {per_layer_memory:.2f} GB")
    print(f"  Max buffered layers: {max_buffered_layers}")
    print(f"  Recommended shard size: {optimal_shard_gb:.1f} GB")
    return optimal_shard_gb

# Example: 16GB RAM, 8B model, 2 adapters
optimal = find_optimal_shard_size(
    available_ram_gb=16.0,
    layer_size_gb=0.5,  # ~500MB per layer for an 8B model at fp16
    num_adapters=2,
)
Output:
Hardware Analysis:
Available RAM: 16.0 GB
Per-layer memory: 2.50 GB
Max buffered layers: 5
Recommended shard size: 2.5 GB
Speed vs Memory Tradeoff
| Shard Size | Peak RAM | Merge Time (8B) | Disk I/O |
|---|---|---|---|
| 0.5 GB | ~2 GB | 25 min | High |
| 2.0 GB | ~5 GB | 12 min | Medium |
| 4.0 GB | ~9 GB | 8 min | Low |
| No sharding | ~40 GB | 5 min | Minimal |
Choose based on your constraints:
- Tight RAM (12GB): Use 1-2 GB shards
- Moderate RAM (16GB): Use 2-4 GB shards
- Ample RAM (32GB+): Can skip sharding
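The same tiers can be encoded as a tiny helper. This is a convenience sketch; the 17-31 GB branch is an interpolation of the list above, not part of it.

def recommended_shard_gb(ram_gb: float) -> float | None:
    """Map available RAM to the shard-size tiers above (None = skip sharding)."""
    if ram_gb <= 12:
        return 1.0    # tight RAM: 1-2 GB shards
    if ram_gb <= 16:
        return 2.0    # moderate RAM: 2-4 GB shards
    if ram_gb < 32:
        return 4.0    # in between: larger shards are fine
    return None       # ample RAM (32GB+): sharding optional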
Verification After Sharded Merge
Validate Output Integrity
def verify_sharded_output(output_dir: str, expected_layers: int) -> bool:
    """Verify the sharded output is complete and valid."""
    shards = sorted(Path(output_dir).glob("*.safetensors"))
    print(f"Found {len(shards)} shards")

    # Collect all tensor names across shards
    all_tensors = set()
    for shard in shards:
        with safe_open(str(shard), framework="pt", device="cpu") as f:
            all_tensors.update(f.keys())
    print(f"Total tensors: {len(all_tensors)}")

    # Spot-check expected tensor names (note the .weight suffix on real keys)
    expected_keys = [
        "model.embed_tokens.weight",
        "model.layers.0.self_attn.q_proj.weight",
        f"model.layers.{expected_layers - 1}.mlp.down_proj.weight",
        "lm_head.weight",
    ]
    for key in expected_keys:
        if key not in all_tensors:
            print(f"MISSING: {key}")
            return False
        print(f"Found: {key}")

    # Sanity-check a few shapes
    with safe_open(str(shards[0]), framework="pt", device="cpu") as f:
        for key in list(f.keys())[:3]:
            tensor = f.get_tensor(key)
            print(f"  {key}: {tensor.shape}")
    return True

verify_sharded_output("./merged_sharded", expected_layers=28)
Output:
Found 8 shards
Total tensors: 255
Found: model.embed_tokens.weight
Found: model.layers.0.self_attn.q_proj.weight
Found: model.layers.27.mlp.down_proj.weight
Found: lm_head.weight
  model.layers.0.self_attn.q_proj.weight: torch.Size([3072, 3072])
  model.layers.0.self_attn.k_proj.weight: torch.Size([1024, 3072])
  model.layers.0.self_attn.v_proj.weight: torch.Size([1024, 3072])
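Beyond key names and shapes, the strongest check is that the merged model actually loads and generates. Here is a minimal smoke test, assuming transformers is installed and the config/tokenizer files were copied into the output directory.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./merged_sharded")
model = AutoModelForCausalLM.from_pretrained(
    "./merged_sharded",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,  # streams shards instead of double-materializing
)
inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))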
Reflect on Your Skill
Update your model-merging skill with RAM-constrained guidance:
- Add hardware requirements section: Memory formula for merge feasibility
- Add sharding parameters: Shard size recommendations by RAM tier
- Add verification checklist: Post-merge integrity checks
- Add troubleshooting: Common OOM causes and fixes
Try With AI
Prompt 1: Diagnose OOM Error
I'm trying to merge two 8B adapters on my 16GB MacBook Pro. MergeKit crashes:
mergekit-yaml config.yaml ./output --lazy-unpickle
...
RuntimeError: [MPS] out of memory
I thought --lazy-unpickle would prevent this. What's happening and how do I fix it?
Consider:
1. Why lazy loading might still cause OOM
2. Platform-specific issues (Mac MPS vs CUDA)
3. Additional flags I should use
What you're learning: Platform-specific debugging—understanding that memory management differs across hardware.
Prompt 2: Optimize Merge Speed
My sharded merge works but takes 45 minutes for a 3B model. This seems slow.
Current setup:
- 16GB RAM Mac M2
- 1GB shard size
- Processing on CPU (--no-cuda)
Can I speed this up while staying under 16GB? Walk me through:
1. Optimal shard size for my hardware
2. Whether MPS acceleration helps
3. Parallelization opportunities
4. Any caching strategies
What you're learning: Performance optimization—finding the sweet spot between memory and speed.
Prompt 3: Scale to Larger Models
I successfully merged 3B adapters on 16GB RAM. Now I want to merge 70B adapters.
Obviously my laptop won't work. What are my options?
Consider:
1. Cloud instances (what specs do I need?)
2. Can I stream shards to/from cloud storage?
3. Distributed merging across multiple machines?
4. Any quantized merging approaches?
Help me design an approach that minimizes cost.
What you're learning: Scaling strategies—extending techniques to production-scale models.
Safety Note
Sharded merging creates many intermediate files. Ensure sufficient disk space (2-3x model size) before starting. Interrupted merges may leave partial shards that waste disk space. Always verify merge completion before deleting source files. For critical merges, maintain backups of source adapters until the merged model is validated.
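A pre-flight check for the disk-space requirement might look like this sketch; the 3x multiplier mirrors the guidance above.

import shutil
from pathlib import Path

def check_disk_space(output_dir: str, model_size_gb: float, factor: float = 3.0) -> bool:
    """Verify there is at least factor x model size of free disk space."""
    target = Path(output_dir).resolve()
    # Walk up to the nearest existing directory before querying the filesystem
    while not target.exists():
        target = target.parent
    free_gb = shutil.disk_usage(target).free / 1e9
    needed_gb = model_size_gb * factor
    ok = free_gb >= needed_gb
    print(f"Free: {free_gb:.0f} GB, needed: {needed_gb:.0f} GB -> {'OK' if ok else 'INSUFFICIENT'}")
    return ok

check_disk_space("./merged_sharded", model_size_gb=16.0)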