LMMDTrainer¶
LMMD (Local Maximum Mean Discrepancy) extends global MMD by aligning
per-class subdomains instead of the overall marginal distributions.
Whereas standard MMD matches p(z_src) to p(z_tgt) globally, LMMD matches
p(z_src | y=c) to p(z_tgt | y=c) for every class c — preventing
semantically different classes from being forced together during alignment.
Class-conditional weights:
- Source weights are derived from one-hot ground-truth labels y_src.
- Target weights are derived from the model's current softmax predictions (treated as soft pseudo-labels; stop-gradient).
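The two weighting rules can be sketched in a few lines of PyTorch. This is a minimal illustration rather than shiftkit's internal code: the helper name `class_weights` is hypothetical, and the per-class normalisation (each class column sums to 1 over the batch) follows the DSAN paper and is assumed here, not confirmed for this library.

```python
import torch
import torch.nn.functional as F


def class_weights(y_src, y_tgt_prob, num_classes):
    """Hypothetical helper: per-class sample weights for one batch.

    Returns (w_src, w_tgt) with shapes (n_s, C) and (n_t, C).
    """
    # Source: hard one-hot weights from ground-truth labels.
    w_src = F.one_hot(y_src, num_classes).float()
    # Target: soft pseudo-labels; detach() plays the role of the stop-gradient.
    w_tgt = y_tgt_prob.detach()
    # Assumed normalisation: each class column sums to 1. A class that is
    # absent from the batch keeps an all-zero column.
    w_src = w_src / w_src.sum(dim=0, keepdim=True).clamp(min=1e-8)
    w_tgt = w_tgt / w_tgt.sum(dim=0, keepdim=True).clamp(min=1e-8)
    return w_src, w_tgt


y_src = torch.tensor([0, 0, 1])
y_tgt_prob = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
w_src, w_tgt = class_weights(y_src, y_tgt_prob, num_classes=2)
```

With these inputs the two class-0 source samples each receive weight 0.5, while every target sample contributes to every class in proportion to its predicted probability.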
Reference: Zhu, Y., Zhuang, F., Wang, J., et al. (2020). Deep Subdomain Adaptation Network for Image Classification. IEEE Transactions on Neural Networks and Learning Systems, 32(4), 1713–1722. [Paper]
Usage¶
```python
from shiftkit.methods import LMMDTrainer

trainer = LMMDTrainer(
    model=model,
    source_loader=train_src,
    target_loader=train_tgt,
    num_classes=10,
    lmmd_weight=1.0,
    lr=1e-3,
)
history = trainer.fit(epochs=10)
result = trainer.evaluate(test_tgt, domain="target-test")
print(f"Target accuracy: {result['accuracy']*100:.1f}%")
```
Constructor¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | — | Network with `.encode()` and `.classify()` methods |
| `source_loader` | `DataLoader` | — | Labelled source DataLoader |
| `target_loader` | `DataLoader` | — | Target DataLoader (labels used for `tgt_acc` tracking only) |
| `num_classes` | `int` | — | Number of output classes C |
| `lmmd_weight` | `float` | `1.0` | λ, the weight on the LMMD regularisation term |
| `lr` | `float` | `1e-3` | Adam learning rate |
| `warmup_epochs` | `int` | `0` | Epochs of source-only CE pre-training before LMMD begins |
| `device` | `str \| None` | `None` | `'cuda'`, `'mps'`, or `'cpu'`; auto-detected if `None` |
| `lmmd_sigmas` | `list[float] \| None` | `None` | RBF kernel bandwidths; defaults to `[0.1, 1, 5, 10, 50]` |
fit(epochs=10)¶
Train for `epochs` epochs and return the history.
Returns: list[dict] — one dict per epoch:
| Key | Description |
|---|---|
| `epoch` | Epoch number (1-indexed) |
| `ce_loss` | Mean cross-entropy loss |
| `lmmd_loss` | Mean LMMD loss (0.0 during warmup) |
| `mmd_loss` | Always 0.0 (history-format compatibility) |
| `total_loss` | Mean total loss (CE + λ·LMMD; CE only during warmup) |
| `src_acc` | Source domain training accuracy |
| `tgt_acc` | Target domain accuracy (tracked, not optimised directly) |
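Because the history is a plain list of dicts, it can be inspected with standard Python. The sketch below uses a hand-written two-epoch history in the format above; the numbers are placeholders chosen for illustration, not measured results, and the variable names `best` and `warmup` are assumptions.

```python
# Placeholder history in the documented format (not real training output).
# Epoch 1 is a warmup epoch, so lmmd_loss is 0.0 and total_loss equals ce_loss.
history = [
    {"epoch": 1, "ce_loss": 0.9, "lmmd_loss": 0.0, "mmd_loss": 0.0,
     "total_loss": 0.9, "src_acc": 0.70, "tgt_acc": 0.55},
    {"epoch": 2, "ce_loss": 0.6, "lmmd_loss": 0.2, "mmd_loss": 0.0,
     "total_loss": 0.8, "src_acc": 0.80, "tgt_acc": 0.65},
]

# Epoch with the highest tracked target accuracy.
best = max(history, key=lambda row: row["tgt_acc"])

# Epochs that recorded no LMMD term, i.e. the warmup epochs.
warmup = [row["epoch"] for row in history if row["lmmd_loss"] == 0.0]
```

Note that `tgt_acc` depends on target labels, so picking a checkpoint by it is suitable for analysis but is not a label-free model-selection rule.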
evaluate(loader, domain="source")¶
Compute classification accuracy on any labelled DataLoader.
Returns: dict with keys domain (str), accuracy (float), n_samples (int).
MMD vs LMMD¶
| | MMD | LMMD |
|---|---|---|
| What is aligned | Global marginals p(z) | Per-class conditionals p(z \| y=c) |
| Source supervision | Not used in DA loss | One-hot labels weight the kernel |
| Target labels | Not needed | Soft pseudo-labels from model predictions |
| Risk | Negative transfer if class structure differs | Requires reasonable pseudo-labels |
| `num_classes` parameter | Not required | Required |
LMMD typically outperforms global MMD when the source and target class structures are well-separated, because it avoids matching source class A features to target class B features.
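A toy 1-D example makes the difference concrete. It is a deliberately simplified sketch: it compares means only (the linear-kernel special case, not the RBF kernels used by the library) and assumes the target class assignments are known, whereas LMMD would use pseudo-labels. All names and values are made up for illustration.

```python
from statistics import mean

# 1-D features per class. In the target domain the two classes have
# swapped positions relative to the source domain.
src = {0: [-1.1, -0.9], 1: [0.9, 1.1]}
tgt = {0: [0.9, 1.1], 1: [-1.1, -0.9]}


def flatten(per_class):
    """Pool the features of all classes into one list."""
    return [x for xs in per_class.values() for x in xs]


# Global view: squared distance between the overall domain means.
global_gap = (mean(flatten(src)) - mean(flatten(tgt))) ** 2

# Per-class view: squared distance between class means, averaged over classes.
local_gap = mean((mean(src[c]) - mean(tgt[c])) ** 2 for c in src)
```

Here the global gap is zero, so a purely global criterion sees the domains as already aligned even though every class is mismatched; the per-class gap stays large and would keep pushing the encoder to fix this.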
LMMDLoss¶
The raw LMMD module, exposed for use in custom training loops.
```python
import torch
from shiftkit.methods import LMMDLoss

lmmd = LMMDLoss(num_classes=10, sigmas=[0.1, 1.0, 5.0, 10.0, 50.0])

# z_src, z_tgt: latent vectors for the source and target batches
# y_tgt_prob: softmax predictions for target batch, shape (n_t, C)
with torch.no_grad():
    y_tgt_prob = torch.softmax(model.classify(model.encode(x_tgt)), dim=1)

loss = lmmd(z_src, z_tgt, y_src, y_tgt_prob)  # scalar tensor
```
Constructor¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `num_classes` | `int` | — | Number of classes C |
| `sigmas` | `list[float] \| None` | `None` | Kernel bandwidths; defaults to `[0.1, 1.0, 5.0, 10.0, 50.0]` |
forward(z_src, z_tgt, y_src, y_tgt_prob)¶
| Parameter | Type | Description |
|---|---|---|
| `z_src` | `Tensor (n_s, d)` | Source latent vectors |
| `z_tgt` | `Tensor (n_t, d)` | Target latent vectors |
| `y_src` | `Tensor (n_s,)` | Source ground-truth labels (long) |
| `y_tgt_prob` | `Tensor (n_t, C)` | Target softmax probabilities |
Returns: Scalar LMMD loss averaged over C classes.
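For orientation, here is one way such a loss can be computed, following the formulation in the DSAN paper. It is a sketch under stated assumptions, not the code behind `LMMDLoss`: the function names `rbf_kernel` and `lmmd_sketch` are hypothetical, and the kernel form exp(-||a - b||² / (2σ²)) summed over bandwidths and the per-class weight normalisation are both assumed.

```python
import torch
import torch.nn.functional as F


def rbf_kernel(a, b, sigmas):
    """Sum of Gaussian kernels over all bandwidths; returns shape (n_a, n_b)."""
    d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
    return sum(torch.exp(-d2 / (2.0 * s ** 2)) for s in sigmas)


def lmmd_sketch(z_src, z_tgt, y_src, y_tgt_prob,
                sigmas=(0.1, 1.0, 5.0, 10.0, 50.0)):
    """Illustrative class-weighted MMD, averaged over the C classes."""
    num_classes = y_tgt_prob.shape[1]
    # Class-conditional weights, each class column normalised to sum to 1.
    w_s = F.one_hot(y_src, num_classes).float()
    w_t = y_tgt_prob.detach()  # stop-gradient on the pseudo-labels
    w_s = w_s / w_s.sum(dim=0, keepdim=True).clamp(min=1e-8)
    w_t = w_t / w_t.sum(dim=0, keepdim=True).clamp(min=1e-8)

    k_ss = rbf_kernel(z_src, z_src, sigmas)
    k_tt = rbf_kernel(z_tgt, z_tgt, sigmas)
    k_st = rbf_kernel(z_src, z_tgt, sigmas)

    # For each class c: w_s[:, c]^T K_ss w_s[:, c] + w_t[:, c]^T K_tt w_t[:, c]
    #                   - 2 * w_s[:, c]^T K_st w_t[:, c]
    per_class = (
        ((w_s.T @ k_ss) * w_s.T).sum(dim=1)
        + ((w_t.T @ k_tt) * w_t.T).sum(dim=1)
        - 2.0 * ((w_s.T @ k_st) * w_t.T).sum(dim=1)
    )
    return per_class.mean()


torch.manual_seed(0)
z = torch.randn(4, 3)
y = torch.tensor([0, 0, 1, 1])
probs = F.one_hot(y, 2).float()

loss_same = lmmd_sketch(z, z, y, probs)           # identical subdomains
loss_shifted = lmmd_sketch(z, z + 3.0, y, probs)  # target shifted away
```

In this sketch a class missing from the batch has all-zero weights and contributes nothing to the average, and identical source and target subdomains give a loss of (numerically) zero.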