
Methods

shiftkit.methods provides domain adaptation training loops. All trainers record identical per-epoch history dicts so their results can be directly compared.


Available methods

| Trainer | DA mechanism | Key parameter |
|---|---|---|
| SourceOnlyTrainer | No adaptation (baseline) | - |
| MMDTrainer | Latent distribution matching via MMD | mmd_weight λ |
| DANNTrainer | Adversarial discriminator + GRL | domain_weight λ |
| LMMDTrainer | Per-class subdomain alignment via local MMD | num_classes |
| CORALTrainer | Covariance alignment (second-order statistics) | coral_weight λ |
| SIDDATrainer | Sinkhorn optimal transport + learnable η weights | warmup_epochs |
| KLIEPTrainer | Instance-based importance weighting via density ratio estimation | n_centers, weight_clip |
| Custom | Your own method | - |

All trainers share the same interface:

trainer = AnyTrainer(model, source_loader, target_loader, ...)
history = trainer.fit(epochs=10)
result  = trainer.evaluate(test_loader, domain="target-test")
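
Because the calls are the same for every trainer, a comparison loop only has to be written once. The following is a minimal sketch, not part of shiftkit: `compare_trainers` is a hypothetical helper, and it assumes the trainers have already been constructed with their own loaders and hyperparameters. It relies only on the `fit(epochs=...)` and `evaluate(test_loader, domain=...)` calls shown above.

```python
def compare_trainers(trainers, test_loader, epochs=10):
    """Fit and evaluate each trainer through the shared interface.

    `trainers` maps a label of your choice (for example "mmd") to an
    already-constructed trainer object. This helper is hypothetical and
    only uses the fit() / evaluate() calls documented above.
    """
    results = {}
    for name, trainer in trainers.items():
        # One history entry per epoch, in the shared format.
        history = trainer.fit(epochs=epochs)
        # Evaluate on the held-out target split.
        result = trainer.evaluate(test_loader, domain="target-test")
        results[name] = {"history": history, "result": result}
    return results
```

The returned dict keeps each method's history next to its evaluation result, so later analysis does not need to know which trainer produced which entry.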

Method comparison

| | Source Only | MMD | LMMD | CORAL | DANN | SIDDA | KLIEP |
|---|---|---|---|---|---|---|---|
| DA family | - | Feature-based | Feature-based | Feature-based | Feature-based | Feature-based | Instance-based |
| Alignment target | None | Full distribution | Per-class subdomains | Covariance matrix | Domain labels | Optimal transport plan | Sample weights (density ratio) |
| What is matched | - | All moments (via kernel) | Class-conditional moments | 2nd-order statistics | Domain membership | Entire marginal distribution | p_tgt(x) / p_src(x) |
| Kernel required | No | Yes (RBF, bandwidth σ) | Yes (RBF, bandwidth σ) | No | No | No (Sinkhorn entropic OT) | Yes (RBF in input space) |
| Needs source labels | Yes | Yes | Yes | Yes | Yes | Yes | Yes |
| Needs target labels | No | No | Pseudo-labels (soft) | No | No | No | No |
| Adversarial training | No | No | No | No | Yes (GRL) | No | No |
| Learnable loss weights | No | No | No | No | No | Yes (η₁, η₂) | No |
| Alignment cost | - | Every batch | Every batch | Every batch | Every batch | Every batch | Once at init |
| Model interface | forward() | encode() + classify() | encode() + classify() | encode() + classify() | encode() + classify() | full SIDDA interface | forward() only |
| Computation per batch | O(n·d) | O(n²) kernel matrices | O(n²) per class | O(n·d²) covariance | O(n·d) + discriminator | O(n²) Sinkhorn iterations | O(n·m) weight lookup |
| Key hyperparameter | - | mmd_weight λ | lmmd_weight λ | coral_weight λ | domain_weight λ | warmup_epochs, blur schedule | n_centers, weight_clip |
| Warmup supported | - | Yes | Yes | Yes | Yes | Mandatory | No |
| Covariate shift assumption | No | No | No | No | No | No | Yes |
| Best suited for | Reference baseline | General distribution shift | Class-level shift with label imbalance | Shift in feature scale / correlation | Strong covariate shift with large batches | Unknown shift type; automatically reweights objectives | Covariate shift on tabular / low-dim data |

Choosing a method: Start with the Source-Only baseline to measure the domain gap. For most tasks, MMD or CORAL is a fast, strong first attempt. Use LMMD when class distributions differ across domains. Use DANN when the shift is severe and batch sizes are large enough to train the discriminator. Use SIDDA when you want automatic loss balancing without manual λ tuning. Use KLIEP when the covariate shift assumption holds and you prefer instance reweighting over feature alignment; it is especially effective on tabular data.
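
To make the "O(n²) kernel matrices" entry for MMD concrete, here is a minimal pure-Python sketch of a biased MMD² estimate with a single-bandwidth RBF kernel. It is not shiftkit's implementation: the names `rbf_kernel`, `mean_kernel`, and `mmd2`, the default `sigma=1.0`, and the choice of the biased estimator are all assumptions made for illustration.

```python
import math


def rbf_kernel(x, y, sigma):
    """RBF kernel k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * sigma ** 2))


def mean_kernel(xs, ys, sigma):
    """Average kernel value over all pairs; this is the O(n^2) step."""
    total = sum(rbf_kernel(x, y, sigma) for x in xs for y in ys)
    return total / (len(xs) * len(ys))


def mmd2(source, target, sigma=1.0):
    """Biased estimate of MMD^2 between two batches of feature vectors.

    Illustrative only: a real trainer would compute this on latent
    features with tensors and scale it by a weight such as mmd_weight.
    """
    return (
        mean_kernel(source, source, sigma)
        + mean_kernel(target, target, sigma)
        - 2.0 * mean_kernel(source, target, sigma)
    )
```

Two identical batches give an estimate of zero, and the value grows as the target batch moves away from the source batch, which is the quantity a feature-alignment loss pushes down.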


Shared history format

Every fit() call returns a list[dict] with one entry per epoch:

| Key | Type | Description |
|---|---|---|
| epoch | int | Epoch index (1-based) |
| ce_loss | float | Cross-entropy loss |
| mmd_loss | float | MMD² loss (0.0 if not applicable) |
| domain_loss | float | Adversarial domain loss (0.0 if not applicable) |
| da_loss | float | Sinkhorn DA loss (0.0 if not applicable) |
| eta1 | float | Learned CE weight η₁ (SIDDA only) |
| eta2 | float | Learned DA weight η₂ (SIDDA only) |
| sigma | float | Sinkhorn blur used (SIDDA only) |
| total_loss | float | Total combined loss |
| src_acc | float | Source domain accuracy |
| tgt_acc | float | Target domain accuracy (tracked, not directly optimised) |
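
Since every history uses the same keys, final-epoch metrics can be tabulated the same way for any method. The helper below is a sketch rather than part of shiftkit: `summarize` and the derived `gap` field (source accuracy minus target accuracy) are names introduced here for illustration, and the input is assumed to be a dict mapping a method label to the list returned by `fit()`.

```python
def summarize(histories):
    """Build one summary row per method from the last epoch of each history.

    `histories` maps a label (for example "coral") to a list[dict] in the
    shared history format. Hypothetical helper, not a shiftkit function.
    """
    rows = []
    for method, history in histories.items():
        last = history[-1]  # entry for the final epoch
        rows.append({
            "method": method,
            "epochs": last["epoch"],
            "src_acc": last["src_acc"],
            "tgt_acc": last["tgt_acc"],
            # Remaining domain gap as seen on the tracked accuracies.
            "gap": last["src_acc"] - last["tgt_acc"],
        })
    # Highest target accuracy first.
    rows.sort(key=lambda row: row["tgt_acc"], reverse=True)
    return rows
```

Sorting by `tgt_acc` puts the strongest method on the target domain first, while `gap` shows how much of the source-only domain gap remains.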