
Tutorial 5 — Tuning the DA weight λ

The domain-adaptation (DA) weight λ controls the trade-off between source classification accuracy and domain alignment. Depending on the trainer, it is passed as mmd_weight (MMD), lmmd_weight (LMMD), or domain_weight (DANN).
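In the usual formulation of these methods, λ scales the alignment term in the total training loss. Assuming shiftkit combines its losses this way (the exact form is not stated on this page), the objective is:

$$
\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{cls}}(\text{source}) + \lambda \, \mathcal{L}_{\text{DA}}(\text{source}, \text{target})
$$

Here $\mathcal{L}_{\text{DA}}$ is the MMD, LMMD, or domain-classifier loss. For DANN the domain-classifier term is usually applied through a gradient-reversal layer, so λ also scales the reversed gradient that reaches the feature extractor.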


Quick λ sweep

from shiftkit.models  import CNN
from shiftkit.methods import MMDTrainer

# train_src, train_tgt and test_tgt are assumed to be the source/target loaders built in the earlier tutorials.
results = {}
for lam in [0.1, 0.5, 1.0, 2.0, 5.0]:
    model   = CNN(latent_dim=128, num_classes=10)   # fresh model for every λ
    trainer = MMDTrainer(model=model, source_loader=train_src, target_loader=train_tgt,
                         mmd_weight=lam, lr=1e-3)
    trainer.fit(epochs=10)
    stats = trainer.evaluate(test_tgt, domain="target-test")
    results[lam] = stats["accuracy"]                 # fraction in [0, 1]
    print(f"λ={lam:<5}  target acc={stats['accuracy']*100:.2f}%")

Typical output (exact figures vary with dataset and random seed):

λ=0.1   target acc=95.41%
λ=0.5   target acc=96.83%
λ=1.0   target acc=97.20%
λ=2.0   target acc=96.95%
λ=5.0   target acc=94.10%
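In this run target accuracy peaks at λ=1.0 and falls off on both sides, most sharply at λ=5.0. Since the sweep stores each λ's target accuracy in the results dict, picking the best value is a one-liner. The sketch below hard-codes the typical figures above (as fractions, matching stats["accuracy"]) so that it runs on its own; in practice you would reuse the results dict from the loop.

```python
# Typical target accuracies from the sweep above, stored as fractions like stats["accuracy"].
results = {0.1: 0.9541, 0.5: 0.9683, 1.0: 0.9720, 2.0: 0.9695, 5.0: 0.9410}

# λ with the highest target accuracy.
best_lam = max(results, key=results.get)

# How far each λ falls below the best one, in percentage points.
gaps = {lam: (results[best_lam] - acc) * 100 for lam, acc in results.items()}

print(f"best λ = {best_lam} ({results[best_lam] * 100:.2f}%)")  # best λ = 1.0 (97.20%)
for lam, gap in gaps.items():
    print(f"λ={lam:<5}  {gap:.2f} pp below best")
```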

Interpreting the results

| Symptom | Likely cause | Remedy |
|---|---|---|
| Source accuracy drops and target accuracy is low | λ too large: the DA loss overwhelms the classification loss | Reduce λ, or increase warmup_epochs |
| Source and target latent spaces remain separated | λ too small | Increase λ |
| Source accuracy is good but target accuracy stays poor | Fundamental distribution mismatch; tuning λ helps only if the two distributions can be aligned at all | Try DANN or LMMD |
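For the second row, one way to check whether the latent spaces are still separated is to estimate the MMD between source and target latent features directly. The sketch below is a standalone NumPy implementation of the biased squared-MMD estimate with a Gaussian (RBF) kernel. It is not shiftkit's internal MMD loss, whose kernel and bandwidth may differ, and it assumes you have already extracted latent vectors for a batch of source and target samples as (n, d) arrays (how to do that depends on the model and is not shown). Synthetic Gaussian data stands in for real features.

```python
import numpy as np


def rbf_mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased estimate of squared MMD between samples x (n, d) and y (m, d), Gaussian kernel."""
    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Pairwise squared Euclidean distances, shape (len(a), len(b)).
        d2 = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2.0 * a @ b.T
        # Clip tiny negative values caused by floating-point cancellation.
        return np.exp(-np.maximum(d2, 0.0) / (2.0 * sigma**2))

    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())


rng = np.random.default_rng(0)
src         = rng.normal(0.0, 1.0, size=(200, 8))  # stand-in for source latents
tgt_aligned = rng.normal(0.0, 1.0, size=(200, 8))  # same distribution as source
tgt_shifted = rng.normal(1.5, 1.0, size=(200, 8))  # mean-shifted target latents

mmd_aligned = rbf_mmd2(src, tgt_aligned, sigma=4.0)
mmd_shifted = rbf_mmd2(src, tgt_shifted, sigma=4.0)
print(f"aligned: {mmd_aligned:.4f}   shifted: {mmd_shifted:.4f}")
```

A value that stays well above the aligned baseline as training proceeds is consistent with λ being too small. The bandwidth sigma matters; a common heuristic is to set it from the median pairwise distance between the pooled samples.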

Using TrainerRegistry for sweeps

TrainerRegistry creates trainers by name, so once the methods are registered you can sweep over (name, kwargs) pairs:

from shiftkit.models   import CNN
from shiftkit.methods  import TrainerRegistry

sweep = [
    ("mmd",  {"mmd_weight": 0.5}),
    ("mmd",  {"mmd_weight": 1.0}),
    ("lmmd", {"num_classes": 10, "lmmd_weight": 1.0}),
    ("dann", {"domain_weight": 1.0}),
]

for name, kwargs in sweep:
    model   = CNN(latent_dim=128, num_classes=10)
    trainer = TrainerRegistry.create(
        name,
        model=model,
        source_loader=train_src,
        target_loader=train_tgt,
        **kwargs,
    )
    history = trainer.fit(epochs=10)
    result  = trainer.evaluate(test_tgt, domain="target-test")
    print(f"{name} {kwargs} -> {result['accuracy']*100:.2f}%")
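TrainerRegistry's internals are not documented on this page. To illustrate the name-based lookup that the loop above relies on, here is a minimal, hypothetical registry: a decorator records each class under a name, and create() instantiates it with keyword arguments. SimpleRegistry, register, and the dummy trainer classes are made up for this sketch and are not shiftkit's API.

```python
from typing import Any, Callable, Dict, Type


class SimpleRegistry:
    """Hypothetical name -> class registry; illustrates the pattern, not shiftkit's code."""

    _classes: Dict[str, Type[Any]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[Any]], Type[Any]]:
        # Decorator that records the decorated class under `name`.
        def decorator(klass: Type[Any]) -> Type[Any]:
            cls._classes[name] = klass
            return klass
        return decorator

    @classmethod
    def create(cls, name: str, **kwargs: Any) -> Any:
        # Look the class up by name and forward all keyword arguments to its constructor.
        if name not in cls._classes:
            raise KeyError(f"unknown trainer {name!r}; registered: {sorted(cls._classes)}")
        return cls._classes[name](**kwargs)


@SimpleRegistry.register("mmd")
class DummyMMDTrainer:
    def __init__(self, model: Any, mmd_weight: float = 1.0) -> None:
        self.model = model
        self.mmd_weight = mmd_weight


@SimpleRegistry.register("dann")
class DummyDANNTrainer:
    def __init__(self, model: Any, domain_weight: float = 1.0) -> None:
        self.model = model
        self.domain_weight = domain_weight


trainer = SimpleRegistry.create("mmd", model="cnn", mmd_weight=0.5)
print(type(trainer).__name__, trainer.mmd_weight)  # DummyMMDTrainer 0.5
```

Because each method takes its own weight keyword, the sweep list pairs every name with the kwargs that method expects. In this sketch, passing mmd_weight to the "dann" entry fails immediately with a TypeError from the constructor.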

Warmup + λ interaction

Setting warmup_epochs delays the DA term: for the first warmup_epochs epochs the model trains on source classification alone, so the classifier gets a head start before alignment begins. This typically lets you use a somewhat larger λ without source accuracy collapsing:

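from shiftkit.models  import CNN
from shiftkit.methods import MMDTrainer

model   = CNN(latent_dim=128, num_classes=10)
trainer = MMDTrainer(
    model=model, source_loader=train_src, target_loader=train_tgt,
    mmd_weight=2.0,      # a higher λ is usually tolerable because warmup stabilises source accuracy
    warmup_epochs=5,     # source-only training before the MMD term switches on
    lr=1e-3,
)
trainer.fit(epochs=20)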
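The interaction can be pictured as a per-epoch schedule for the effective DA weight. The sketch below assumes the simplest behaviour consistent with the description above: the weight is 0 for the first warmup_epochs epochs (counted from 0) and equal to λ afterwards. shiftkit may schedule it differently, for example with a gradual ramp, so treat effective_weight as a hypothetical helper.

```python
def effective_weight(epoch: int, lam: float, warmup_epochs: int) -> float:
    """Hypothetical DA weight at a 0-based `epoch`: 0 during warmup, λ afterwards."""
    return 0.0 if epoch < warmup_epochs else lam


epochs, lam, warmup_epochs = 20, 2.0, 5
schedule = [effective_weight(e, lam, warmup_epochs) for e in range(epochs)]
da_epochs = sum(1 for w in schedule if w > 0)  # epochs in which alignment is active

print(schedule)
print(f"DA active for {da_epochs} of {epochs} epochs")
```

Under this assumption, with epochs=20 and warmup_epochs=5 the MMD term is active for 15 epochs, still more than the 10 used in the quick sweep.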

Tip

Start at λ=1.0 with warmup_epochs=5. Only tune λ after confirming that source accuracy remains stable.
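One way to confirm that source accuracy remains stable is to compare the source accuracy reached at the end of warmup with its values once the DA term is active. The helper below is a hypothetical sketch that works on a plain list of per-epoch source accuracies (fractions). How you collect that list depends on your setup, for instance from the history object returned by trainer.fit() if it records source accuracy. The 1-point tolerance and the two accuracy curves are made-up illustrative values, not results.

```python
from typing import Sequence


def source_acc_stable(acc_per_epoch: Sequence[float], warmup_epochs: int,
                      tolerance: float = 0.01) -> bool:
    """Hypothetical check: after warmup, source accuracy never falls more than
    `tolerance` (a fraction; 0.01 = 1 point) below its end-of-warmup value."""
    if warmup_epochs < 1 or len(acc_per_epoch) <= warmup_epochs:
        raise ValueError("need at least one warmup epoch and one epoch after warmup")
    baseline = acc_per_epoch[warmup_epochs - 1]      # accuracy after the last warmup epoch
    worst_after = min(acc_per_epoch[warmup_epochs:])  # lowest accuracy once DA is on
    return baseline - worst_after <= tolerance


stable   = [0.90, 0.95, 0.97, 0.98, 0.985, 0.984, 0.983, 0.986, 0.987]
collapse = [0.90, 0.95, 0.97, 0.98, 0.985, 0.96, 0.93, 0.91, 0.90]
print(source_acc_stable(stable, warmup_epochs=5))    # True
print(source_acc_stable(collapse, warmup_epochs=5))  # False
```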