Tutorial 1 — MNIST → Noisy MNIST: Source-Only vs MMD vs DANN
End-to-end walkthrough of the built-in benchmark: adapting a CNN from clean MNIST digits (source) to the same digits corrupted with Gaussian noise (target).
The complete runnable script lives at `examples/mnist_mmd.py`.
Run the script
The script is configured through a `CONFIG` block at the top of the file:
```python
MODEL_TYPE = "cnn"      # "cnn" or "mlp"
LATENT_DIM = 128
EPOCHS = 10
BATCH_SIZE = 128
LR = 1e-3
MMD_WEIGHT = 1.0        # λ, weight of the MMD loss term
DOMAIN_WEIGHT = 1.0     # DANN domain loss weight
NOISE_STD = 0.3         # Gaussian noise std on target
```
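`MMD_WEIGHT` (λ) scales an adaptation term that is added to the source classification loss. The sketch below illustrates the idea. It computes the biased squared MMD between two batches of feature vectors using a single Gaussian RBF kernel, then combines the result as `loss = cross_entropy + λ · MMD²`. The kernel, the bandwidth `sigma` and all function names are assumptions made for illustration. shiftkit's `MMDTrainer` may use a different kernel (for example a sum over several bandwidths) or a different estimator.

```python
import numpy as np

def rbf_kernel(a, b, sigma=1.0):
    """Gaussian RBF kernel matrix: k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 sigma^2))."""
    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
    sq = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
    sq = np.maximum(sq, 0.0)  # clip tiny negatives caused by rounding
    return np.exp(-sq / (2.0 * sigma ** 2))

def mmd2(source_feats, target_feats, sigma=1.0):
    """Biased estimate of squared MMD between two samples (rows = examples)."""
    k_ss = rbf_kernel(source_feats, source_feats, sigma)
    k_tt = rbf_kernel(target_feats, target_feats, sigma)
    k_st = rbf_kernel(source_feats, target_feats, sigma)
    return k_ss.mean() + k_tt.mean() - 2.0 * k_st.mean()

def total_loss(cls_loss, source_feats, target_feats, mmd_weight=1.0, sigma=1.0):
    """Hypothetical combination: classification loss + λ * MMD²."""
    return cls_loss + mmd_weight * mmd2(source_feats, target_feats, sigma)

# Low-dimensional stand-in features keep the bandwidth easy to reason about
rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(300, 2))
tgt_same = rng.normal(0.0, 1.0, size=(300, 2))   # same distribution as src
tgt_shift = rng.normal(1.0, 1.0, size=(300, 2))  # mean shifted by 1 on each axis
print(mmd2(src, tgt_same), mmd2(src, tgt_shift))
```

MMD² is near zero when the two batches come from the same distribution and grows as they drift apart. Minimising it therefore pulls source and target features together, and λ trades this off against classification accuracy.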
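Step-by-step breakdown
Step 1 — Load data
```python
from shiftkit.data import DataManager

dm = DataManager(root="./data", batch_size=128, num_workers=0)

# Each call returns a (source_loader, target_loader) pair for one split
train_src, train_tgt = dm.load("mnist_noisy_mnist", train=True, noise_std=0.3)
test_src, test_tgt = dm.load("mnist_noisy_mnist", train=False, noise_std=0.3)
```
`dm.load` returns a `(source_loader, target_loader)` pair of DataLoaders for the requested split (`train=True` or `train=False`). The target domain is MNIST with additive Gaussian noise (`noise_std=0.3`, the same as `NOISE_STD`). The digits and their class labels are the same as in the source, so only the input distribution shifts.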
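Conceptually, each target image is a source image plus independent zero-mean Gaussian noise with standard deviation `noise_std`. The NumPy sketch below shows that corruption on a batch of images. Two details are assumptions: whether shiftkit adds the noise before or after normalisation, and whether it clips pixel values back to [0, 1]. The `clip` flag and the function name are hypothetical.

```python
import numpy as np

def add_gaussian_noise(images, noise_std=0.3, clip=True, seed=None):
    """Return a noisy copy of `images` (pixel values assumed to be in [0, 1]).

    noise_std: standard deviation of the additive zero-mean Gaussian noise.
    clip: if True, clip the result back to [0, 1]; an assumption, not
          confirmed to match shiftkit's behaviour.
    """
    rng = np.random.default_rng(seed)
    noisy = images + rng.normal(0.0, noise_std, size=images.shape)  # new array; input untouched
    return np.clip(noisy, 0.0, 1.0) if clip else noisy

# A fake batch of 4 single-channel 28x28 "digits" with values in [0, 1]
clean = np.random.default_rng(0).random((4, 1, 28, 28))
noisy = add_gaussian_noise(clean, noise_std=0.3, seed=1)
print(noisy.shape, noisy.min() >= 0.0, noisy.max() <= 1.0)
```

Labels are left untouched, which is why the two domains share the same classes.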
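Step 2 — Build independent models
All three models use the same architecture, so the comparison is fair. Each is a separate instance with its own parameters, so training one does not affect the others. They also start from different random initialisations, so small accuracy differences between methods may partly reflect initialisation rather than the method itself.
```python
from shiftkit.models import CNN

# Same architecture and settings (LATENT_DIM = 128) for every method
model_baseline = CNN(latent_dim=128, num_classes=10, dropout=0.3)
model_mmd = CNN(latent_dim=128, num_classes=10, dropout=0.3)
model_dann = CNN(latent_dim=128, num_classes=10, dropout=0.3)
```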
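You can remove initialisation as a source of variation, although the script does not do this. Build a single template model and deep-copy it: every method then starts from identical weights while still training its own independent parameters. `copy.deepcopy` works on PyTorch `nn.Module` objects, but that `CNN` is an `nn.Module` is an assumption here, and `make_identical_copies` is a hypothetical helper.

```python
import copy

def make_identical_copies(template, n):
    """Return `n` independent deep copies of `template`.

    All copies start with identical state (e.g. identical weights when
    `template` is a torch nn.Module) but share no objects, so updating
    one copy leaves the others untouched.
    """
    return [copy.deepcopy(template) for _ in range(n)]

# Hypothetical usage with the tutorial's model (assumes CNN is an nn.Module):
# from shiftkit.models import CNN
# template = CNN(latent_dim=128, num_classes=10, dropout=0.3)
# model_baseline, model_mmd, model_dann = make_identical_copies(template, 3)

# Stand-in "model state" showing that the copies are independent
template = {"weights": [0.1, 0.2, 0.3]}
run_a, run_b = make_identical_copies(template, 2)
run_a["weights"][0] = 99.0  # "train" one copy
print(run_b["weights"][0], template["weights"][0])
```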
Step 3 — Train
```python
from shiftkit.methods import SourceOnlyTrainer, MMDTrainer, DANNTrainer

# Same learning rate (LR) for all three; mmd_weight = MMD_WEIGHT (λ), domain_weight = DOMAIN_WEIGHT
baseline = SourceOnlyTrainer(model_baseline, train_src, train_tgt, lr=1e-3)
mmd = MMDTrainer(model_mmd, train_src, train_tgt, mmd_weight=1.0, lr=1e-3)
dann = DANNTrainer(model_dann, train_src, train_tgt, domain_weight=1.0, lr=1e-3)

# Train each model for EPOCHS = 10; the returned histories are plotted in Step 5
history_baseline = baseline.fit(epochs=10)
history_mmd = mmd.fit(epochs=10)
history_dann = dann.fit(epochs=10)
```
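DANN trains the feature extractor adversarially. A domain classifier learns to tell source features from target features. A gradient reversal layer sits between that classifier and the feature extractor, and during backpropagation it multiplies the classifier's gradient by a negative coefficient, pushing the features toward domain invariance. In the original DANN work (Ganin & Lempitsky), this coefficient is ramped from 0 towards 1 as λ_p = 2 / (1 + exp(−γ·p)) − 1, where p is training progress and γ = 10.

The sketch below computes that schedule. This tutorial does not say whether `DANNTrainer` ramps `domain_weight` this way or keeps it constant, so treat the function as illustrative. The batches-per-epoch figure also assumes 60,000 training images at batch size 128 with the last partial batch kept.

```python
import math

def grl_coefficient(step, total_steps, gamma=10.0, max_weight=1.0):
    """DANN-style ramp for the gradient-reversal coefficient.

    lambda_p = max_weight * (2 / (1 + exp(-gamma * p)) - 1),  p = step / total_steps
    Rises smoothly from 0 at the start of training to ~max_weight at the end.
    Treating max_weight as DOMAIN_WEIGHT is an assumption, not documented behaviour.
    """
    p = min(max(step / total_steps, 0.0), 1.0)  # training progress clamped to [0, 1]
    return max_weight * (2.0 / (1.0 + math.exp(-gamma * p)) - 1.0)

batches_per_epoch = 469          # ceil(60000 / 128), an assumption about the loader
total = 10 * batches_per_epoch   # EPOCHS = 10
for epoch in (0, 1, 5, 10):
    print(epoch, round(grl_coefficient(epoch * batches_per_epoch, total), 3))
```

The ramp keeps the noisy domain-classifier signal from disrupting the feature extractor early in training, before the features are useful for classification.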
Step 4 — Evaluate
Classification accuracy of each trained model, per domain and split:

| Split        | Source-Only | MMD    | DANN   |
|--------------|------------:|-------:|-------:|
| source-train | 99.44%      | 99.58% | 99.51% |
| source-test  | 98.89%      | 99.14% | 98.97% |
| target-test  | 95.27%      | 96.80% | 97.10% |
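The snippet below recomputes two derived quantities from the reported accuracies and nothing else. The first is each method's target-test gain over Source-Only. The second is its source-test to target-test gap, a rough measure of how much shift remains. It also expresses the gain as the fraction of Source-Only's target-test error that each method removes.

```python
# Accuracies (%) copied from the Step 4 table
results = {
    "Source-Only": {"source_test": 98.89, "target_test": 95.27},
    "MMD":         {"source_test": 99.14, "target_test": 96.80},
    "DANN":        {"source_test": 98.97, "target_test": 97.10},
}

baseline_tgt = results["Source-Only"]["target_test"]
baseline_err = 100.0 - baseline_tgt  # target-test error of the source-only model

summary = {}
for name, acc in results.items():
    gain = acc["target_test"] - baseline_tgt       # percentage points over Source-Only
    gap = acc["source_test"] - acc["target_test"]  # remaining source/target gap
    err_reduction = gain / baseline_err            # fraction of baseline target error removed
    summary[name] = (round(gain, 2), round(gap, 2), round(100 * err_reduction, 1))
    print(f"{name:12s} gain={gain:+.2f} pp  gap={gap:.2f} pp  error cut={100 * err_reduction:.1f}%")
```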
Step 5 — Plot training history
```python
from shiftkit.diagnostics import plot_training_history

plot_training_history(
    histories={
        "Source Only": history_baseline,
        "MMD": history_mmd,
        "DANN": history_dann,
    },
    save_path="outputs/training_history.png",
)
```
The left panel shows the total training loss per epoch. The right panel overlays source (solid) and target (dashed) accuracy; the gap between the solid and dashed curves indicates how much of the domain shift each model has not yet bridged.
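To get the gap as numbers rather than reading it off the plot, a helper like the one below works on per-epoch accuracy lists. The `history` layout it expects, a dict with `"src_acc"` and `"tgt_acc"` lists, is a hypothetical assumption. Check what `fit()` actually returns and adjust the keys.

```python
def accuracy_gap(history, src_key="src_acc", tgt_key="tgt_acc"):
    """Per-epoch source-minus-target accuracy gap.

    `history` is assumed (hypothetically) to map `src_key` and `tgt_key`
    to equal-length lists of per-epoch accuracies.
    """
    src, tgt = history[src_key], history[tgt_key]
    if len(src) != len(tgt):
        raise ValueError("source and target accuracy lists differ in length")
    return [s - t for s, t in zip(src, tgt)]

# Made-up numbers purely to show the call; these are not results from the tutorial
example = {"src_acc": [0.95, 0.98, 0.99], "tgt_acc": [0.85, 0.92, 0.95]}
gaps = accuracy_gap(example)
print([round(g, 3) for g in gaps], "final gap:", round(gaps[-1], 3))
```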
Step 6 — Plot latent spaces
```python
from shiftkit.diagnostics import compare_latent_spaces

compare_latent_spaces(
    models={
        "Source Only": model_baseline,
        "MMD": model_mmd,
        "DANN": model_dann,
    },
    source_loader=test_src,
    target_loader=test_tgt,
    max_samples=2000,  # caps the number of test points projected
    save_path="outputs/latent_space_comparison.png",
)
```
Each panel shows a 2-D t-SNE projection of one model's latent features on the test sets, coloured by class (filled markers = source, open markers = target). Better alignment means each target cluster overlaps the source cluster of the same class.
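What to look for
- Source-Only: source clusters (filled) are tight, while target clusters (open) are shifted away from the matching source clusters.
- MMD: matching the overall source and target feature distributions pulls the target clusters toward the source ones and reduces the gap.
- DANN: adversarial training against a domain classifier typically aligns the domains even more closely; in the Step 4 table it reaches the highest target-test accuracy (97.10%).
- The gain in target-test accuracy over Source-Only measures the benefit of domain adaptation (DA): +1.53 points for MMD and +1.83 points for DANN in the Step 4 table.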
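The t-SNE plots are qualitative, and distances in a t-SNE embedding are not reliable. For a number to accompany them, measure alignment directly in the latent space instead. One simple, illustrative check (not part of shiftkit) is the distance between the source and target centroids of each class, where shifted target clusters show up as large distances. The sketch assumes the latent features and labels have already been extracted as NumPy arrays; how to get them from the models is not covered here.

It also uses target labels. That is fine for evaluation on this benchmark but would not be possible with a truly unlabelled target domain.

```python
import numpy as np

def class_centroid_distances(z_src, y_src, z_tgt, y_tgt, num_classes=10):
    """Euclidean distance between source and target class centroids in latent space.

    z_*: (n, d) latent feature arrays; y_*: (n,) integer class labels.
    Returns an array of length num_classes (NaN for classes missing in either domain).
    """
    dists = np.full(num_classes, np.nan)
    for c in range(num_classes):
        src_c, tgt_c = z_src[y_src == c], z_tgt[y_tgt == c]
        if len(src_c) and len(tgt_c):
            dists[c] = np.linalg.norm(src_c.mean(axis=0) - tgt_c.mean(axis=0))
    return dists

# Synthetic 2-D example: target features = source features shifted by +1 on every axis
rng = np.random.default_rng(0)
y = rng.integers(0, 3, size=300)
z_src = rng.normal(size=(300, 2)) + y[:, None] * 5.0  # three well-separated classes
z_tgt = z_src + 1.0
print(class_centroid_distances(z_src, y, z_tgt, y, num_classes=3))
```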