Sharded Merging for RAM-Constrained Systems

Your laptop has 16GB RAM. You want to merge two 8B-parameter adapters. Naive merging attempts to load both models plus the output—40+ GB of memory. Your system crashes.

This lesson teaches sharded merging: processing models layer-by-layer with lazy loading, keeping memory usage constant regardless of model size. You'll merge models that shouldn't fit in your RAM.

The Memory Problem

Why Naive Merging Fails

Standard merging loads entire models into memory:

Standard Merge Memory Usage:
┌──────────────────────────────────────────────────┐
│                                                  │
│  Model A (full):    16 GB                        │
│  Model B (full):    16 GB                        │
│  Output buffer:      8 GB (growing)              │
│  Working memory:     4 GB                        │
│  ─────────────────────────                       │
│  TOTAL:             44 GB  ❌ Exceeds 16GB RAM   │
│                                                  │
└──────────────────────────────────────────────────┘

Even with swap, this is impractical—swapping multi-GB tensors causes minutes-long pauses.
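
To see where that number comes from, here is a quick back-of-the-envelope calculation using the figures from the diagram above (float16, 2 bytes per parameter; the working-memory figure is a rough allowance, not a measurement):

# Rough peak-memory estimate for a naive fp16 merge of two 8B-parameter models.
params = 8e9
bytes_per_param = 2                       # float16

model_a = params * bytes_per_param        # ~16 GB
model_b = params * bytes_per_param        # ~16 GB
output_buffer = params * bytes_per_param  # grows toward ~16 GB as layers merge
working = 4e9                             # temporaries during tensor ops (rough)

peak_gb = (model_a + model_b + output_buffer / 2 + working) / 1e9
print(f"Estimated peak: {peak_gb:.0f} GB")  # Estimated peak: 44 GB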

Memory Profile of Naive Merge

import os

import psutil
import torch

def profile_naive_merge(model_a_path: str, model_b_path: str):
    """Profile memory usage during a naive full-model merge.

    Uses psutil's resident set size: PyTorch tensor storage is allocated
    outside Python's allocator, so tracemalloc would not see it.
    """
    process = psutil.Process(os.getpid())
    peak = 0.0

    def peak_gb() -> float:
        nonlocal peak
        peak = max(peak, process.memory_info().rss)
        return peak / 1e9

    # Load Model A - PEAK 1
    print("Loading Model A...")
    model_a = torch.load(f"{model_a_path}/pytorch_model.bin", map_location="cpu")
    print(f"  After Model A: {peak_gb():.1f} GB peak")

    # Load Model B - PEAK 2
    print("Loading Model B...")
    model_b = torch.load(f"{model_b_path}/pytorch_model.bin", map_location="cpu")
    print(f"  After Model B: {peak_gb():.1f} GB peak")

    # Merge - PEAK 3
    print("Merging...")
    merged = {}
    for key in model_a.keys():
        merged[key] = 0.5 * model_a[key] + 0.5 * model_b[key]
    print(f"  After Merge: {peak_gb():.1f} GB peak")

    return merged

Output:

Loading Model A...
  After Model A: 16.2 GB peak
Loading Model B...
  After Model B: 32.4 GB peak
Merging...
  After Merge: 44.1 GB peak ❌ OOM on 16GB system

The Solution: Sharded Merging

Layer-by-Layer Processing

Instead of loading entire models, process one layer at a time:

Sharded Merge Memory Usage:
┌────────────────────────────────────────────────────────────┐
│                                                            │
│  Layer N from Model A:   0.5 GB                            │
│  Layer N from Model B:   0.5 GB                            │
│  Layer N merged:         0.5 GB                            │
│  Working memory:         0.5 GB                            │
│  ─────────────────────────                                 │
│  TOTAL:                  2 GB  ✅ Fits in 16GB with room   │
│                                                            │
└────────────────────────────────────────────────────────────┘

The key insight: each layer can be processed independently, then immediately written to disk.
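
Here is a minimal sketch of that pattern, assuming both checkpoints are single-file safetensors with identical tensor names (the function name and paths are placeholders for illustration):

from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file


def average_merge_streaming(path_a: str, path_b: str, out_dir: str) -> None:
    """Stream a simple 50/50 average merge, one tensor resident at a time."""
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    with safe_open(path_a, framework="pt", device="cpu") as fa, \
         safe_open(path_b, framework="pt", device="cpu") as fb:
        for i, key in enumerate(fa.keys()):
            a = fa.get_tensor(key)   # only this tensor is loaded
            b = fb.get_tensor(key)
            merged = 0.5 * a + 0.5 * b
            # Write immediately so nothing accumulates in RAM
            save_file({key: merged}, f"{out_dir}/tensor-{i:05d}.safetensors")
            del a, b, merged

Writing one tiny file per tensor is crude; the full script later in this lesson batches tensors into properly sized shards instead.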

Memory Comparison

Approach                   Peak RAM   8B model feasible on 16GB?
Naive                      40+ GB     No
Sharded (layer-by-layer)   2-4 GB     Yes
Lazy + Sharded             1-2 GB     Yes (with margin)

MergeKit Lazy Loading

MergeKit supports lazy weight loading through the --lazy-unpickle flag:

Basic Lazy Merge

mergekit-yaml config.yaml ./output --lazy-unpickle --low-cpu-memory

Flags:

  • --lazy-unpickle: load tensor data from checkpoints on demand instead of unpickling everything upfront (experimental)
  • --low-cpu-memory: keep results and intermediate values on the GPU rather than in CPU RAM (useful when VRAM exceeds RAM)
  • --no-cuda: force CPU processing (if GPU RAM is also limited)

Lazy Merge Configuration

# lazy_merge.yaml
merge_method: ties
slices:
  - sources:
      - model: ./adapters/task_api_persona
        layer_range: [0, 28]
      - model: ./adapters/task_api_agentic
        layer_range: [0, 28]
parameters:
  weight: 0.5
  density: 0.5
base_model: unsloth/Llama-3.2-3B-Instruct
dtype: float16
tokenizer_source: base

# Memory optimization: the output shard size is a CLI option, not a YAML key.
# Pass --out-shard-size on the command line (see below) to cap shard size.

Execute with Memory Monitoring

# Monitor memory during merge; --out-shard-size takes a parameter count
# (1B parameters ≈ 2 GB per shard at float16)
/usr/bin/time -v mergekit-yaml lazy_merge.yaml ./merged_lazy \
  --lazy-unpickle --low-cpu-memory --out-shard-size 1B 2>&1 | tee merge_log.txt

# Check peak memory
grep "Maximum resident set size" merge_log.txt

Output:

Loading config...
Processing layer 0/28 (lazy)
Loaded: model.layers.0.self_attn.q_proj (32MB)
Loaded: model.layers.0.self_attn.k_proj (32MB)
Merging...
Written to shard
Freed memory
Processing layer 1/28 (lazy)
...
Processing layer 27/28 (lazy)
Merge complete.
Maximum resident set size (kbytes): 4523648 # ~4.5 GB peak ✅

Implementing Custom Sharded Merge

For maximum control, implement layer-by-layer merging yourself:

Sharded Merge Script

import gc
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file


def sharded_ties_merge(
    base_model: str,
    adapters: list[str],
    output_dir: str,
    density: float = 0.5,
    max_shard_gb: float = 2.0,
) -> None:
    """Memory-efficient TIES-style merge, processing one layer at a time.

    Note: this simplified version merges the adapters' weights directly;
    base_model is kept for interface parity but is not used to form task
    vectors here.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Get the layer list from the first adapter
    layer_names = get_layer_names(adapters[0])
    print(f"Processing {len(layer_names)} layers")

    current_shard = {}
    current_shard_size = 0
    shard_index = 0
    max_shard_bytes = max_shard_gb * 1e9

    for layer_name in layer_names:
        print(f"Processing: {layer_name}")

        # Load this layer from all adapters (one at a time)
        layer_weights = []
        for adapter_path in adapters:
            weight = load_single_layer(adapter_path, layer_name)
            layer_weights.append(weight)

        # Apply TIES merge for this layer
        merged_weight = ties_merge_layer(layer_weights, density)

        # Add to the current shard
        layer_size = merged_weight.numel() * merged_weight.element_size()
        current_shard[layer_name] = merged_weight
        current_shard_size += layer_size

        # Write the shard once it exceeds the size limit
        if current_shard_size >= max_shard_bytes:
            shard_file = output_path / f"model-{shard_index:05d}-of-XXXXX.safetensors"
            save_file(current_shard, str(shard_file))
            print(f"  Wrote shard {shard_index}: {current_shard_size / 1e9:.2f} GB")

            current_shard = {}
            current_shard_size = 0
            shard_index += 1
            gc.collect()  # Free memory

        # Clear this layer's weights from memory
        del layer_weights
        gc.collect()

    # Write the final shard (if any tensors remain)
    if current_shard:
        shard_file = output_path / f"model-{shard_index:05d}-of-XXXXX.safetensors"
        save_file(current_shard, str(shard_file))
        print(f"  Wrote final shard {shard_index}: {current_shard_size / 1e9:.2f} GB")
        shard_index += 1

    # Update the shard count in filenames now that the total is known
    update_shard_names(output_path, shard_index)

    print(f"Merge complete: {shard_index} shards in {output_dir}")


def load_single_layer(model_path: str, layer_name: str) -> torch.Tensor:
    """Load a single tensor from a (possibly sharded) safetensors checkpoint."""
    # Find which shard contains this tensor
    for shard_file in Path(model_path).glob("*.safetensors"):
        with safe_open(str(shard_file), framework="pt", device="cpu") as f:
            if layer_name in f.keys():
                return f.get_tensor(layer_name)

    raise KeyError(f"Layer {layer_name} not found in {model_path}")


def ties_merge_layer(weights: list[torch.Tensor], density: float) -> torch.Tensor:
    """Apply a TIES-style merge (trim, elect signs, average) to a single layer."""
    # Stack weights: shape (num_models, num_elements)
    stacked = torch.stack(weights)
    flat = stacked.view(len(weights), -1)
    magnitudes = flat.abs().float()

    # Trim: keep the top density fraction of each model's weights by magnitude.
    # kthvalue is used instead of torch.quantile, which rejects very large inputs.
    k = max(1, int((1 - density) * flat.shape[1]))
    threshold = magnitudes.kthvalue(k, dim=1, keepdim=True).values
    mask = magnitudes >= threshold
    trimmed = flat * mask

    # Elect signs by majority across models
    sign_sum = trimmed.sign().sum(dim=0)
    elected_sign = sign_sum.sign()

    # Average the values that match the elected sign
    aligned = trimmed * (trimmed.sign() == elected_sign)
    counts = (aligned != 0).sum(dim=0).clamp(min=1)
    merged = aligned.sum(dim=0) / counts

    return merged.view(weights[0].shape)
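
The script above assumes two helpers that are not shown. Minimal versions might look like this (illustrative sketches for this lesson, not part of MergeKit):

from pathlib import Path

from safetensors import safe_open


def get_layer_names(model_path: str) -> list[str]:
    """Collect all tensor names across a checkpoint's safetensors shards."""
    names: list[str] = []
    for shard_file in sorted(Path(model_path).glob("*.safetensors")):
        with safe_open(str(shard_file), framework="pt", device="cpu") as f:
            names.extend(f.keys())
    return names


def update_shard_names(output_path: Path, total_shards: int) -> None:
    """Replace the XXXXX placeholder with the final shard count."""
    for shard_file in output_path.glob("model-*-of-XXXXX.safetensors"):
        new_name = shard_file.name.replace("XXXXX", f"{total_shards:05d}")
        shard_file.rename(shard_file.with_name(new_name))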

Memory-Monitored Execution

import os
import threading
import time

import psutil


def run_with_monitoring(merge_fn, *args, **kwargs):
    """Run a merge function while sampling memory in a background thread."""
    process = psutil.Process(os.getpid())

    initial_memory = process.memory_info().rss
    peak_memory = initial_memory
    stop_sampling = threading.Event()

    def sample_memory():
        # Poll resident set size so the peak during the merge is actually captured
        nonlocal peak_memory
        while not stop_sampling.is_set():
            peak_memory = max(peak_memory, process.memory_info().rss)
            time.sleep(0.1)

    sampler = threading.Thread(target=sample_memory, daemon=True)
    sampler.start()

    try:
        result = merge_fn(*args, **kwargs)
    finally:
        stop_sampling.set()
        sampler.join()

    final_memory = process.memory_info().rss
    peak_memory = max(peak_memory, final_memory)

    print("\nMemory Report:")
    print(f"  Initial: {initial_memory / 1e9:.2f} GB")
    print(f"  Peak: {peak_memory / 1e9:.2f} GB")
    print(f"  Final: {final_memory / 1e9:.2f} GB")
    print(f"  Delta: {(peak_memory - initial_memory) / 1e9:.2f} GB")

    return result


# Run the sharded merge with monitoring
run_with_monitoring(
    sharded_ties_merge,
    base_model="unsloth/Llama-3.2-3B-Instruct",
    adapters=["./adapters/persona", "./adapters/agentic"],
    output_dir="./merged_sharded",
    density=0.5,
    max_shard_gb=2.0,
)

Output:

Processing 28 layers
Processing: model.layers.0.self_attn.q_proj
Processing: model.layers.0.self_attn.k_proj
Processing: model.layers.0.self_attn.v_proj
...
  Wrote shard 0: 2.01 GB
...
Processing: model.layers.27.mlp.down_proj
  Wrote final shard 7: 1.45 GB
Merge complete: 8 shards in ./merged_sharded

Memory Report:
  Initial: 0.42 GB
  Peak: 3.21 GB
  Delta: 2.79 GB ✅ Well under 16GB limit

Tuning for Your Hardware

Finding Optimal Shard Size

def find_optimal_shard_size(
    available_ram_gb: float,
    model_layers: int,
    layer_size_gb: float,
    num_adapters: int,
) -> float:
    """Calculate an optimal shard size for your hardware."""

    # Memory needed per layer during processing:
    # - Load from each adapter: num_adapters * layer_size
    # - Working memory for the merge: ~2x layer_size
    # - Output buffer: layer_size
    per_layer_memory = (num_adapters + 3) * layer_size_gb

    # Leave 2GB headroom for the system
    available_for_merge = available_ram_gb - 2.0

    # Number of layers we can buffer before writing a shard
    max_buffered_layers = int(available_for_merge / per_layer_memory)

    # Convert to a shard size
    optimal_shard_gb = max_buffered_layers * layer_size_gb

    print("Hardware Analysis:")
    print(f"  Available RAM: {available_ram_gb} GB")
    print(f"  Per-layer memory: {per_layer_memory:.2f} GB")
    print(f"  Max buffered layers: {max_buffered_layers}")
    print(f"  Recommended shard size: {optimal_shard_gb:.1f} GB")

    return optimal_shard_gb


# Example: 16GB RAM, 8B model, 2 adapters
optimal = find_optimal_shard_size(
    available_ram_gb=16.0,
    model_layers=28,
    layer_size_gb=0.5,  # ~500MB per layer for an 8B model
    num_adapters=2,
)

Output:

Hardware Analysis:
Available RAM: 16.0 GB
Per-layer memory: 2.50 GB
Max buffered layers: 5
Recommended shard size: 2.5 GB

Speed vs Memory Tradeoff

Shard size    Peak RAM   Merge time (8B)   Disk I/O
0.5 GB        ~2 GB      25 min            High
2.0 GB        ~5 GB      12 min            Medium
4.0 GB        ~9 GB      8 min             Low
No sharding   ~40 GB     5 min             Minimal

Choose based on your constraints:

  • Tight RAM (12GB): Use 1-2 GB shards
  • Moderate RAM (16GB): Use 2-4 GB shards
  • Ample RAM (32GB+): Can skip sharding

Verification After Sharded Merge

Validate Output Integrity

def verify_sharded_output(output_dir: str, expected_layers: int) -> bool:
    """Verify that the sharded output is complete and valid."""

    shards = list(Path(output_dir).glob("*.safetensors"))
    print(f"Found {len(shards)} shards")

    # Collect all tensor names across shards
    all_layers = set()
    for shard in shards:
        with safe_open(str(shard), framework="pt", device="cpu") as f:
            all_layers.update(f.keys())

    print(f"Total layers: {len(all_layers)}")

    # Check for expected layers (prefix match, since stored keys may end in
    # suffixes like ".weight")
    expected_patterns = [
        "model.embed_tokens",
        "model.layers.0.self_attn.q_proj",
        f"model.layers.{expected_layers - 1}.mlp.down_proj",
        "lm_head.weight",
    ]

    for pattern in expected_patterns:
        if not any(layer.startswith(pattern) for layer in all_layers):
            print(f"MISSING: {pattern}")
            return False
        print(f"Found: {pattern}")

    # Spot-check a few tensor shapes
    with safe_open(str(shards[0]), framework="pt", device="cpu") as f:
        for key in list(f.keys())[:3]:
            tensor = f.get_tensor(key)
            print(f"  {key}: {tensor.shape}")

    return True


verify_sharded_output("./merged_sharded", expected_layers=28)

Output:

Found 8 shards
Total layers: 225
Found: model.embed_tokens
Found: model.layers.0.self_attn.q_proj
Found: model.layers.27.mlp.down_proj
Found: lm_head.weight
  model.layers.0.self_attn.q_proj: torch.Size([3072, 3072])
  model.layers.0.self_attn.k_proj: torch.Size([1024, 3072])
  model.layers.0.self_attn.v_proj: torch.Size([1024, 3072])
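
One thing the verification above does not cover: to load the merged checkpoint with transformers, the output directory also needs a model.safetensors.index.json mapping each tensor to its shard. A small sketch that builds one from the shards written above (reading one tensor at a time so memory stays low):

import json
from pathlib import Path

from safetensors import safe_open


def write_safetensors_index(output_dir: str) -> None:
    """Build model.safetensors.index.json so transformers can load the shards."""
    output_path = Path(output_dir)
    weight_map = {}
    total_size = 0

    for shard in sorted(output_path.glob("*.safetensors")):
        with safe_open(str(shard), framework="pt", device="cpu") as f:
            for key in f.keys():
                weight_map[key] = shard.name
                tensor = f.get_tensor(key)  # one tensor resident at a time
                total_size += tensor.numel() * tensor.element_size()

    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    index_path = output_path / "model.safetensors.index.json"
    index_path.write_text(json.dumps(index, indent=2))


write_safetensors_index("./merged_sharded")

You will also need the base model's config.json and tokenizer files in the same directory before transformers can load the merged model.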

Reflect on Your Skill

Update your model-merging skill with RAM-constrained guidance:

  1. Add hardware requirements section: Memory formula for merge feasibility (see the sketch after this list)
  2. Add sharding parameters: Shard size recommendations by RAM tier
  3. Add verification checklist: Post-merge integrity checks
  4. Add troubleshooting: Common OOM causes and fixes
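
For item 1, a starting point might look like the sketch below. It assumes float16 weights (2 GB per billion parameters) and reuses this lesson's rough naive-vs-sharded overhead figures; treat the constants as assumptions to tune for your setup, not measurements.

# Rough feasibility check: can this merge run in the given RAM?
def merge_fits_in_ram(params_billions: float, num_models: int,
                      ram_gb: float, sharded: bool = True,
                      layer_fraction: float = 1 / 28) -> bool:
    model_gb = params_billions * 2          # fp16: ~2 GB per billion params
    if sharded:
        # Only one layer per model, plus the merged layer and working memory,
        # is resident at once; 2 GB covers process baseline and buffers.
        peak = (num_models + 2) * model_gb * layer_fraction + 2.0
    else:
        # All source models plus the output buffer are resident at once.
        peak = (num_models + 1) * model_gb + 4.0
    return peak <= ram_gb - 2.0             # keep ~2 GB headroom for the OS

# Two 8B models on 16 GB: naive fails, sharded fits.
print(merge_fits_in_ram(8, 2, 16, sharded=False))  # False
print(merge_fits_in_ram(8, 2, 16, sharded=True))   # True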

Try With AI

Prompt 1: Diagnose OOM Error

I'm trying to merge two 8B adapters on my 16GB MacBook Pro. MergeKit crashes:

mergekit-yaml config.yaml ./output --lazy-unpickle
...
RuntimeError: [MPS] out of memory

I thought --lazy-unpickle would prevent this. What's happening and how do I fix it?
Consider:
1. Why lazy loading might still cause OOM
2. Platform-specific issues (Mac MPS vs CUDA)
3. Additional flags I should use

What you're learning: Platform-specific debugging—understanding that memory management differs across hardware.

Prompt 2: Optimize Merge Speed

My sharded merge works but takes 45 minutes for a 3B model. This seems slow.

Current setup:
- 16GB RAM Mac M2
- 1GB shard size
- Processing on CPU (--no-cuda)

Can I speed this up while staying under 16GB? Walk me through:
1. Optimal shard size for my hardware
2. Whether MPS acceleration helps
3. Parallelization opportunities
4. Any caching strategies

What you're learning: Performance optimization—finding the sweet spot between memory and speed.

Prompt 3: Scale to Larger Models

I successfully merged 3B adapters on 16GB RAM. Now I want to merge 70B adapters.
Obviously my laptop won't work. What are my options?

Consider:
1. Cloud instances (what specs do I need?)
2. Can I stream shards to/from cloud storage?
3. Distributed merging across multiple machines?
4. Any quantized merging approaches?

Help me design an approach that minimizes cost.

What you're learning: Scaling strategies—extending techniques to production-scale models.

Safety Note

Sharded merging creates many intermediate files. Ensure sufficient disk space (2-3x model size) before starting. Interrupted merges may leave partial shards that waste disk space. Always verify merge completion before deleting source files. For critical merges, maintain backups of source adapters until the merged model is validated.