# SIDDATrainer
SIDDA (SInkhorn Dynamic Domain Adaptation) trains a model with an optimal-transport domain-alignment loss combined with learnable, dynamically weighted loss terms. Two key ideas distinguish it from MMD and DANN:
- Sinkhorn divergence as the DA loss: a debiased, entropy-regularised optimal transport distance that interpolates between the Wasserstein distance (small blur) and MMD (large blur).
- Automatic loss balancing: two scalar parameters η₁ (CE weight) and η₂ (DA weight) are learned jointly with the model, removing the need to hand-tune a fixed DA weight λ.
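The two losses are combined into a single objective ℒ in which η₁ weights the cross-entropy term, η₂ weights the DA term \(\mathcal{S}_\sigma(z_\text{src}, z_\text{tgt})\), and a log term in η₁ and η₂ is added. Here \(\mathcal{S}_\sigma\) is the Sinkhorn divergence: the entropy-regularised OT cost \(\mathrm{OT}_\sigma\) (with blur σ), debiased by subtracting the two self-transport terms:

\[
\mathcal{S}_\sigma(\alpha, \beta) = \mathrm{OT}_\sigma(\alpha, \beta) - \tfrac{1}{2}\,\mathrm{OT}_\sigma(\alpha, \alpha) - \tfrac{1}{2}\,\mathrm{OT}_\sigma(\beta, \beta)
\]

The log term regularises η₁ and η₂, preventing them from collapsing to zero or growing without bound.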
```text
encoder(x) ──► z_src ──► classify(z_src) ──► CE loss ────────┐
encoder(x) ──► z_tgt ──►                                     ├─► ℒ (weighted)
               Sinkhorn_σ(z_src, z_tgt) ──► DA loss ─────────┘
                        ↑ σ recomputed each batch from feature distances
```
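The `Sinkhorn_σ` box in the diagram can be illustrated with a plain log-domain Sinkhorn loop in NumPy. This is a sketch under stated assumptions, not how `SIDDATrainer` computes the loss: it assumes uniform weights on each batch, the squared-Euclidean cost ½‖x − y‖², ε = σ², and a fixed iteration count. A training implementation would also need gradients (for example, a PyTorch version). The helper names `entropic_ot` and `sinkhorn_divergence` are hypothetical.

```python
import numpy as np


def _logsumexp(m: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(m))) along one axis."""
    top = m.max(axis=axis, keepdims=True)
    return (top + np.log(np.exp(m - top).sum(axis=axis, keepdims=True))).squeeze(axis)


def entropic_ot(x: np.ndarray, y: np.ndarray, eps: float, n_iter: int = 500) -> float:
    """Entropy-regularised OT cost between uniform point clouds x (n, d) and y (m, d)."""
    cost = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # C_ij = ½‖x_i − y_j‖²
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # uniform weights on x, in log space
    log_b = np.full(m, -np.log(m))  # uniform weights on y, in log space
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):  # log-domain Sinkhorn updates of the dual potentials
        f = -eps * _logsumexp(log_b[None, :] + (g[None, :] - cost) / eps, axis=1)
        g = -eps * _logsumexp(log_a[:, None] + (f[:, None] - cost) / eps, axis=0)
    # After a Sinkhorn update the dual objective reduces to <a, f> + <b, g>
    return float(f.mean() + g.mean())


def sinkhorn_divergence(x: np.ndarray, y: np.ndarray, blur: float) -> float:
    """Debiased divergence S(x, y) = OT(x, y) − ½ OT(x, x) − ½ OT(y, y)."""
    eps = blur ** 2  # assumption: ε = σ² for the squared-Euclidean cost
    return (entropic_ot(x, y, eps)
            - 0.5 * entropic_ot(x, x, eps)
            - 0.5 * entropic_ot(y, y, eps))


rng = np.random.default_rng(0)
z_src = rng.normal(scale=0.5, size=(32, 2))
z_tgt = z_src + np.array([1.0, 0.5])  # a pure translation of the source cloud
print(sinkhorn_divergence(z_src, z_src, blur=1.0))  # identical clouds
print(sinkhorn_divergence(z_src, z_tgt, blur=1.0))  # compare with ½‖t‖² = 0.625
```

The debiasing is what makes `S(x, x)` vanish, unlike the raw entropic cost. For a pure translation t the divergence converges to exactly ½‖t‖², since the cross term averages to zero and the self-transport terms cancel, which makes it a convenient sanity check.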
Reference: Ciprijanovic, A., Lewis, A., Pedro, K., Downey, E., Nord, B., & Stark, A. (2025). SIDDA: SInkhorn Dynamic Domain Adaptation for Image Classification with Equivariant Neural Networks. Mach. Learn.: Sci. Technol., 6, 035032. [Paper]
## Usage
```python
from shiftkit.methods import SIDDATrainer

# model, train_src, train_tgt and test_tgt are your own model and DataLoaders
trainer = SIDDATrainer(
    model=model,
    source_loader=train_src,
    target_loader=train_tgt,
    lr=1e-2,
    warmup_epochs=10,  # source-only pre-training before DA begins
)
history = trainer.fit(epochs=50)
result = trainer.evaluate(test_tgt, domain="target-test")
print(f"Target accuracy: {result['accuracy']*100:.1f}%")
```
## Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | Network with `.encode()`, `.classify()`, and `.latent_dim` |
| `source_loader` | `DataLoader` | required | Labelled source `DataLoader` |
| `target_loader` | `DataLoader` | required | Target `DataLoader` (labels used only to track `tgt_acc`) |
| `lr` | `float` | `1e-2` | AdamW learning rate (model and η parameters) |
| `weight_decay` | `float` | `1e-3` | AdamW weight decay |
| `warmup_epochs` | `int` | `0` | Epochs of source-only pre-training before DA begins |
| `sigma_scale` | `float` | `0.05` | Scale factor for the dynamic blur: σ = max(`sigma_scale` · max‖z‖, `sigma_floor`) |
| `sigma_floor` | `float` | `0.01` | Minimum blur value (prevents degenerate OT plans) |
| `grad_clip` | `float` | `10.0` | Max norm for gradient clipping |
| `device` | str \| None | `None` | `'cuda'`, `'mps'`, or `'cpu'`; auto-detected if `None` |
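The constructor only requires that `model` expose `.encode()`, `.classify()`, and `.latent_dim`. A minimal compatible model might look like the sketch below. The class name `SmallNet`, the layer sizes, and the 28×28 input are hypothetical, and we assume `latent_dim` is the width of the features returned by `encode()`. Whether the trainer ever calls `forward()` is not documented here; it is included only for convenience.

```python
import torch
from torch import nn


class SmallNet(nn.Module):
    """Hypothetical model exposing the encode/classify/latent_dim interface."""

    def __init__(self, in_features: int = 784, latent_dim: int = 64, n_classes: int = 10):
        super().__init__()
        self.latent_dim = latent_dim  # assumed to be the width of encode()'s output
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.head = nn.Linear(latent_dim, n_classes)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map inputs to latent features z (the features the DA loss aligns)."""
        return self.encoder(x)

    def classify(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent features to class logits."""
        return self.head(z)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classify(self.encode(x))


model = SmallNet()
x = torch.randn(8, 1, 28, 28)  # a batch of 8 single-channel 28×28 images
z = model.encode(x)
print(z.shape, model.classify(z).shape)
```

Keeping `encode` and `classify` separate is what lets the trainer apply the Sinkhorn loss to `z` while the cross-entropy loss uses the logits.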
## `fit(epochs=50)`

Train for `epochs` epochs and return the per-epoch training history.

Returns: `list[dict]`, one dict per epoch:
| Key | Description |
|---|---|
| `epoch` | Epoch number (1-indexed) |
| `ce_loss` | Mean cross-entropy loss |
| `da_loss` | Mean Sinkhorn divergence loss (0.0 during warmup) |
| `mmd_loss` | Always 0.0 (kept for history-format compatibility) |
| `total_loss` | Mean total weighted loss |
| `src_acc` | Source-domain training accuracy |
| `tgt_acc` | Target-domain accuracy (tracked, not directly optimised) |
| `eta1` | η₁ (CE weight) at the end of the epoch |
| `eta2` | η₂ (DA weight) at the end of the epoch |
| `sigma` | Mean Sinkhorn blur σ over the epoch (0.0 during warmup) |
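Because `sigma` and `da_loss` are documented as 0.0 during warmup, the history can be split into its warmup and adaptation phases. The helpers below (`split_phases`, `describe`) are hypothetical conveniences, not part of shiftkit; they assume a positive `sigma_floor` so that every adaptation epoch reports σ > 0.

```python
from typing import Any

Epoch = dict[str, Any]  # one entry of the list returned by fit()


def split_phases(history: list[Epoch]) -> tuple[list[Epoch], list[Epoch]]:
    """Split a fit() history into warmup epochs and SIDDA (adaptation) epochs.

    Relies on sigma being 0.0 during warmup and at least sigma_floor afterwards.
    """
    warmup = [h for h in history if h["sigma"] == 0.0]
    adapt = [h for h in history if h["sigma"] > 0.0]
    return warmup, adapt


def describe(history: list[Epoch]) -> str:
    """One-line summary of both phases, using the last epoch of each."""
    warmup, adapt = split_phases(history)
    parts = [f"{len(warmup)} warmup epochs"]
    if warmup:
        parts.append(f"src_acc after warmup = {warmup[-1]['src_acc']:.3f}")
    if adapt:
        last = adapt[-1]
        parts.append(f"{len(adapt)} SIDDA epochs, final src_acc = {last['src_acc']:.3f}, "
                     f"tgt_acc = {last['tgt_acc']:.3f}")
    return "; ".join(parts)


# Usage, with history = trainer.fit(epochs=50) from the examples on this page:
# print(describe(history))
```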
## `evaluate(loader, domain="source")`

Compute classification accuracy on any labelled `DataLoader`.

Returns: `dict` with keys `domain` (`str`), `accuracy` (`float`, a fraction in [0, 1]), and `n_samples` (`int`).
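For reference, the sketch below shows the kind of computation `evaluate()` performs, written as a standalone hypothetical helper named `accuracy` (not part of shiftkit). It assumes each batch is an `(inputs, labels)` pair and that the prediction is the argmax of `classify(encode(x))`; the real method may differ in details such as device handling.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def accuracy(model, loader: DataLoader, domain: str = "source", device: str = "cpu") -> dict:
    """Fraction of samples whose argmax logit matches the label."""
    model.eval()  # disable dropout and use running batch-norm statistics
    correct, total = 0, 0
    for x, y in loader:  # assumes (inputs, labels) batches
        x, y = x.to(device), y.to(device)
        logits = model.classify(model.encode(x))
        correct += (logits.argmax(dim=1) == y).sum().item()
        total += y.numel()
    return {"domain": domain, "accuracy": correct / max(total, 1), "n_samples": total}
```

A production version would typically switch the model back to training mode afterwards if training is going to continue.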
## Warmup phase

During the warmup phase the model trains on source classification only, with no DA loss and no η updates. This ensures the latent space carries class-discriminative structure before Sinkhorn alignment begins.
```python
trainer = SIDDATrainer(
    model=model,
    source_loader=train_src,
    target_loader=train_tgt,
    warmup_epochs=10,  # first 10 epochs: source-only CE
    lr=1e-2,
)
history = trainer.fit(epochs=50)
# Epochs 1–10:  [warmup] CE only
# Epochs 11–50: [SIDDA ] CE + Sinkhorn DA + learnable η
```
**Choosing warmup length**

The paper found that equivariant networks require a shorter warmup than standard CNNs (e.g. 5–10 vs 10–30 epochs). As a rule of thumb, warmup should be long enough that source accuracy is reasonable but short enough to avoid overfitting the source domain before adaptation starts.
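One way to turn this rule of thumb into a number is our own heuristic, not a procedure from the paper: run a source-only probe (setting `warmup_epochs` equal to `epochs`, which we assume makes every epoch a warmup epoch), then take the first epoch whose `src_acc` clears a threshold. The helper `suggest_warmup`, the placeholder `make_model`, and the 0.9 threshold are all hypothetical; a sensible threshold depends on the task.

```python
from typing import Optional


def suggest_warmup(history: list[dict], min_src_acc: float = 0.9) -> Optional[int]:
    """Return the first epoch whose source accuracy reaches min_src_acc.

    Stopping warmup as soon as source accuracy is "reasonable" keeps it short,
    which limits overfitting to the source domain. Returns None if never reached.
    """
    for h in history:
        if h["src_acc"] >= min_src_acc:
            return h["epoch"]
    return None


# Hypothetical probe run (make_model is a placeholder for building a fresh model):
# probe = SIDDATrainer(model=make_model(), source_loader=train_src,
#                      target_loader=train_tgt, warmup_epochs=30).fit(epochs=30)
# warmup = suggest_warmup(probe, min_src_acc=0.9) or 10
```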
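## Dynamic Sinkhorn blur

The Sinkhorn regularisation strength (blur) σ is recomputed for every batch from the latent features z:

\[
\sigma = \max\left(s \cdot \max \lVert z \rVert,\ \sigma_{\text{floor}}\right)
\]

where \(s\) is `sigma_scale` and \(\sigma_{\text{floor}}\) is `sigma_floor`.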
Layer normalisation is applied to the latent features before computing pairwise distances, preventing a small number of outlier features from inflating σ. As training progresses and source/target distributions align, σ shrinks naturally, giving a progressively sharper (more exact) OT plan.
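The blur rule can be sketched in NumPy as follows. The exact reading of "max‖z‖" is an assumption here: because the text mentions pairwise distances, the sketch takes the largest pairwise Euclidean distance among the layer-normalised source and target features of the batch (pooled together), and uses layer normalisation without learnable scale or shift. `layer_norm` and `dynamic_blur` are hypothetical names.

```python
import numpy as np


def layer_norm(z: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalise each feature vector to zero mean and unit variance (no affine params)."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps)


def dynamic_blur(z_src: np.ndarray, z_tgt: np.ndarray,
                 sigma_scale: float = 0.05, sigma_floor: float = 0.01) -> float:
    """sigma = max(sigma_scale * largest pairwise distance, sigma_floor)."""
    z = layer_norm(np.concatenate([z_src, z_tgt], axis=0))
    dists = np.linalg.norm(z[:, None, :] - z[None, :, :], axis=-1)  # (N, N) distances
    return max(sigma_scale * float(dists.max()), sigma_floor)
```

After layer normalisation every feature vector has norm at most √d, so under this reading σ can never exceed `2 · sigma_scale · √d`, however large the raw activations are. Identical features give zero distances, and the floor takes over.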
## Learnable loss weights η

η₁ and η₂ are scalar `nn.Parameter` values, optimised jointly with the model using the same AdamW learning rate (`lr`).
Hard constraints are enforced after every gradient step:
- η₁ ≥ 1e-3 (prevents CE from being silenced)
- η₂ ≥ 0.25 · η₁ (prevents DA from completely dominating)
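The two hard constraints amount to an in-place projection applied after each optimiser step. The PyTorch sketch below uses hypothetical stand-ins for the trainer's parameters; the initial value of 1.0 and the helper name `project_etas` are assumptions. The order matters: η₁ is clamped first, so the η₂ bound is computed from the already-clamped η₁.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the trainer's learnable loss weights;
# the initial value of 1.0 is an assumption, not taken from SIDDATrainer.
eta1 = nn.Parameter(torch.tensor(1.0))
eta2 = nn.Parameter(torch.tensor(1.0))


@torch.no_grad()
def project_etas(eta1: nn.Parameter, eta2: nn.Parameter,
                 eta1_min: float = 1e-3, ratio_min: float = 0.25) -> None:
    """Enforce eta1 >= eta1_min and eta2 >= ratio_min * eta1, in place."""
    eta1.clamp_(min=eta1_min)                 # floor on the CE weight
    eta2.clamp_(min=ratio_min * eta1.item())  # bound uses the already-clamped eta1


# In a training loop this would run right after optimizer.step().
# Simulate a step that pushed both weights out of the feasible region:
with torch.no_grad():
    eta1.fill_(1e-5)
    eta2.fill_(-0.5)
project_etas(eta1, eta2)
print(eta1.item(), eta2.item())
```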
You can inspect how the balance evolves during training:
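A small sketch that reads the documented `eta1` and `eta2` history keys and prints them per epoch, together with the ratio η₂/η₁. The helpers `eta_trace` and `print_eta_trace` are hypothetical. The ratio assumes η₁ is positive, which the η₁ ≥ 1e-3 constraint guarantees once adaptation starts (during warmup the values stay at their initial, presumably positive, settings).

```python
def eta_trace(history: list[dict]) -> list[tuple[int, float, float, float]]:
    """(epoch, eta1, eta2, eta2 / eta1) for every epoch of a fit() history."""
    return [(h["epoch"], h["eta1"], h["eta2"], h["eta2"] / h["eta1"]) for h in history]


def print_eta_trace(history: list[dict]) -> None:
    """Print a small table of the learned loss weights, one row per epoch."""
    print(f"{'epoch':>5}  {'eta1':>8}  {'eta2':>8}  {'eta2/eta1':>9}")
    for epoch, eta1, eta2, ratio in eta_trace(history):
        print(f"{epoch:>5}  {eta1:>8.4f}  {eta2:>8.4f}  {ratio:>9.3f}")


# Usage: print_eta_trace(trainer.fit(epochs=50))
```

Since the constraints are applied after every step and the history records the values at epoch end, the ratio column should not fall below 0.25 in adaptation epochs.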