Models

shiftkit.models provides neural network architectures that share a common encoder/classifier interface. Keeping the two components separate is central to how domain adaptation (DA) methods operate: the encoder produces a latent representation z, and DA losses such as MMD (maximum mean discrepancy) are computed on z directly.
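To make the role of z concrete, here is a minimal sketch of a squared-MMD estimate with a Gaussian kernel, computed on two batches of latent vectors. The helper name `rbf_mmd2` and the fixed bandwidth `sigma` are illustrative assumptions and not part of shiftkit; the library's own MMD loss may use a different kernel or estimator.

```python
import torch


def rbf_mmd2(z_src: torch.Tensor, z_tgt: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Biased estimate of squared MMD between two batches of latent vectors.

    Hypothetical helper for illustration only; not shiftkit's implementation.
    """
    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise squared Euclidean distances via broadcasting: (Na, Nb).
        d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
        return torch.exp(-d2 / (2.0 * sigma ** 2))

    return (kernel(z_src, z_src).mean()
            + kernel(z_tgt, z_tgt).mean()
            - 2.0 * kernel(z_src, z_tgt).mean())


torch.manual_seed(0)
z_a = torch.randn(64, 16)          # stand-in for encode(source batch)
z_b = torch.randn(64, 16) + 2.0    # shifted stand-in for encode(target batch)

same = rbf_mmd2(z_a, z_a.clone())  # identical batches: estimate is zero
shifted = rbf_mmd2(z_a, z_b)       # shifted batches: estimate is positive
```

Because the loss depends only on the encoder output, the classifier head is left out of the alignment term and is trained by the supervised loss on source labels.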

All models implement:

| Method | Signature | Description |
|---|---|---|
| encode | (x: Tensor) → Tensor | Map input to latent vector z ∈ ℝᵈ |
| classify | (z: Tensor) → Tensor | Map latent vector to class logits |
| forward | (x: Tensor) → Tensor | classify(encode(x)); the standard nn.Module interface |

CNN

A small convolutional network designed for 1×28×28 inputs (MNIST-like). Two conv-pool blocks feed into a fully-connected bottleneck that produces the latent vector.

Input (1×28×28)
  → Conv2d(1→32, k=3) + BN + ReLU + MaxPool  →  32×14×14
  → Conv2d(32→64, k=3) + BN + ReLU + MaxPool  →  64×7×7
  → Flatten → Linear(3136→256) → ReLU → Dropout
  → Linear(256→latent_dim) → ReLU              →  z ∈ ℝᵈ
  → Linear(latent_dim→num_classes)              →  logits

import torch
from shiftkit.models import CNN

model = CNN(latent_dim=128, num_classes=10, dropout=0.3)

x      = torch.randn(8, 1, 28, 28)  # dummy batch of MNIST-sized inputs, B = 8
z      = model.encode(x)    # (B, 128)
logits = model.classify(z)  # (B, 10)
logits = model(x)           # equivalent
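The shapes in the CNN diagram can be checked with the standard output-size formula for convolution and pooling layers. This sketch assumes `padding=1` on each 3×3 convolution and a 2×2 max-pool with stride 2; neither is stated in the diagram, but both are implied by the 28 → 14 → 7 progression.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28
shapes = []
for channels in (32, 64):
    size = conv_out(size, kernel=3, padding=1)  # assumed padding=1: 3x3 conv keeps the size
    size = conv_out(size, kernel=2, stride=2)   # assumed 2x2 max-pool, stride 2: halves the size
    shapes.append((channels, size, size))

c, h, w = shapes[-1]
flat_features = c * h * w

print(shapes)         # [(32, 14, 14), (64, 7, 7)]
print(flat_features)  # 3136
```

The final value matches the input width of the first fully-connected layer, Linear(3136→256).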

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| latent_dim | int | 128 | Dimensionality of the bottleneck embedding |
| num_classes | int | 10 | Number of output classes |
| dropout | float | 0.3 | Dropout probability before the final FC layer |

MLP

A fully-connected network that flattens the input and passes it through configurable hidden layers before the bottleneck.

Input (1×28×28 → flattened 784)
  → Linear(784→h₁) + ReLU + Dropout
  → Linear(h₁→h₂)  + ReLU + Dropout
  → ...
  → Linear(hₙ→latent_dim) + ReLU    →  z ∈ ℝᵈ
  → Linear(latent_dim→num_classes)   →  logits

import torch
from shiftkit.models import MLP

model = MLP(latent_dim=128, num_classes=10, hidden_dims=(512, 256), dropout=0.3)

x      = torch.randn(8, 1, 28, 28)  # dummy batch; flattened to (B, 784) per the diagram
z      = model.encode(x)    # (B, 128)
logits = model.classify(z)  # (B, 10)
logits = model(x)           # equivalent

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| latent_dim | int | 128 | Dimensionality of the bottleneck embedding |
| num_classes | int | 10 | Number of output classes |
| hidden_dims | Tuple[int, ...] | (512, 256) | Sizes of hidden layers before the bottleneck |
| dropout | float | 0.3 | Dropout probability after each hidden layer |
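As a rough illustration of how `hidden_dims` could translate into layers, the sketch below builds an encoder that follows the MLP diagram: one Linear + ReLU + Dropout block per hidden size, then a Linear + ReLU bottleneck. The helper `build_mlp_encoder` is hypothetical and not shiftkit's actual code.

```python
import torch
import torch.nn as nn


def build_mlp_encoder(in_dim=784, hidden_dims=(512, 256), latent_dim=128, dropout=0.3):
    """Hypothetical helper mirroring the MLP diagram; not shiftkit's implementation."""
    layers = [nn.Flatten()]
    prev = in_dim
    for h in hidden_dims:
        # One block per entry in hidden_dims.
        layers += [nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)]
        prev = h
    # Bottleneck that produces the latent vector z.
    layers += [nn.Linear(prev, latent_dim), nn.ReLU()]
    return nn.Sequential(*layers)


encoder = build_mlp_encoder(hidden_dims=(512, 256, 128), latent_dim=64)
z = encoder(torch.randn(4, 1, 28, 28))
n_linear = sum(isinstance(m, nn.Linear) for m in encoder)

print(z.shape)   # torch.Size([4, 64])
print(n_linear)  # 4
```

Adding an entry to `hidden_dims` adds one block; the bottleneck width is controlled separately by `latent_dim`.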

Choosing between CNN and MLP

| | CNN | MLP |
|---|---|---|
| Input type | 2-D images (preserves spatial structure) | Any flattened vector |
| Inductive bias | Translation equivariance | None |
| Parameters (default) | ~856 K | ~560 K |
| Speed | Slightly faster on GPU | Slightly faster on CPU |

For image inputs, CNN is recommended. MLP is useful when inputs are already feature vectors.
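The default parameter counts can be estimated from the two architecture diagrams with simple arithmetic. This sketch assumes every Linear and Conv2d layer has a bias, each BatchNorm layer has two learnable parameters per channel, and the default sizes (latent_dim=128, num_classes=10, hidden_dims=(512, 256)) apply; the real models may differ in such details. The totals land on or close to the rounded figures in the table.

```python
def linear(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out          # weights + biases


def conv2d(c_in: int, c_out: int, k: int) -> int:
    return c_in * c_out * k * k + c_out  # kernels + biases


def batchnorm(c: int) -> int:
    return 2 * c                         # learnable scale and shift per channel


cnn = (conv2d(1, 32, 3) + batchnorm(32)
       + conv2d(32, 64, 3) + batchnorm(64)
       + linear(3136, 256) + linear(256, 128) + linear(128, 10))

mlp = linear(784, 512) + linear(512, 256) + linear(256, 128) + linear(128, 10)

print(f"CNN: {cnn:,}")  # CNN: 856,266
print(f"MLP: {mlp:,}")  # MLP: 567,434
```

In the CNN, almost all of the parameters sit in the Linear(3136→256) layer (803,072), not in the convolutions.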


SimpleGCN

A two-layer Graph Convolutional Network (GCN) for graph classification. It is designed to work with SyntheticGraphDataset and the packed (B, N, N+feat_dim) tensor format, so PyTorch Geometric is not required.

Input x: (B, N, N+feat_dim)
  split → adj (B,N,N) + h₀ (B,N,feat_dim)
  â = D̂⁻¹/²(A+I)D̂⁻¹/²              (symmetric normalisation with self-loops)
  h₁ = ReLU(â · GCN₁(h₀))           (B, N, hidden_dim)
  h₂ = ReLU(â · GCN₂(h₁))           (B, N, latent_dim)
  z  = h₂.mean(dim=1)                (B, latent_dim)  — graph-level embedding
  logits = Linear(latent_dim → num_classes)

import torch
from shiftkit.models import SimpleGCN

model = SimpleGCN(n_nodes=10, feat_dim=4, latent_dim=64, num_classes=2)

# x shape: (B, N, N+feat_dim); the first N columns are the adjacency, the rest are node features
x      = torch.rand(8, 10, 14)  # dummy packed batch: B = 8, N = 10, N + feat_dim = 14
z      = model.encode(x)    # (B, 64)
logits = model.classify(z)  # (B, 2)
logits = model(x)           # equivalent
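The packing, splitting and normalisation steps from the diagram can be sketched in plain PyTorch. The helpers `split_packed` and `normalise_adj` are hypothetical names used for illustration, and the sketch omits the learned GCN weights, so it shows only the propagation and pooling structure, not SimpleGCN itself.

```python
import torch


def split_packed(x: torch.Tensor, n_nodes: int):
    """Split a packed (B, N, N+feat_dim) tensor into adjacency and node features."""
    return x[..., :n_nodes], x[..., n_nodes:]


def normalise_adj(adj: torch.Tensor) -> torch.Tensor:
    """Symmetric normalisation with self-loops: D^-1/2 (A + I) D^-1/2."""
    a_hat = adj + torch.eye(adj.size(-1), dtype=adj.dtype, device=adj.device)
    # Degrees are at least 1 for non-negative adjacencies, thanks to the self-loops.
    deg_inv_sqrt = a_hat.sum(dim=-1).pow(-0.5)              # (B, N)
    return deg_inv_sqrt.unsqueeze(-1) * a_hat * deg_inv_sqrt.unsqueeze(-2)


# Pack one 3-node path graph (0-1-2) with 2 features per node.
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
feats = torch.ones(3, 2)
x = torch.cat([adj, feats], dim=-1).unsqueeze(0)            # (1, 3, 5)

a, h0 = split_packed(x, n_nodes=3)
a_norm = normalise_adj(a)                                    # (1, 3, 3)
z_unweighted = (a_norm @ h0).mean(dim=1)                     # (1, 2): propagate, then mean-pool over nodes
```

Mean pooling over the node dimension is what turns per-node embeddings into a single graph-level vector, which is why z has shape (B, latent_dim) regardless of N.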

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_nodes | int | 10 | Number of nodes per graph (must match dataset) |
| feat_dim | int | 4 | Number of node feature dimensions |
| latent_dim | int | 64 | Graph-level embedding dimensionality |
| num_classes | int | 2 | Number of output classes |
| hidden_dim | int | 64 | Hidden dimensionality of the first GCN layer |
| dropout | float | 0.0 | Dropout probability between GCN layers |

End-to-end example

from shiftkit.data import DataManager
from shiftkit.models import SimpleGCN
from shiftkit.methods import MMDTrainer

dm = DataManager(batch_size=32)
train_src, train_tgt = dm.load("synthetic_graphs", train=True)
test_src,  test_tgt  = dm.load("synthetic_graphs", train=False)

model = SimpleGCN(n_nodes=10, feat_dim=4, latent_dim=64, num_classes=2)

trainer = MMDTrainer(
    model=model,
    source_loader=train_src,
    target_loader=train_tgt,
    mmd_weight=1.0,
    warmup_epochs=5,
    lr=1e-3,
)
history = trainer.fit(epochs=30)
result  = trainer.evaluate(test_tgt, domain="target-test")
print(f"Target accuracy: {result['accuracy']*100:.1f}%")

Using a custom model

Any model that exposes .encode(x) and .classify(z) can be used with MMDTrainer and SourceOnlyTrainer:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder    = nn.Sequential(nn.Flatten(), nn.Linear(784, 64), nn.ReLU())
        self.classifier = nn.Linear(64, 10)

    def encode(self, x):
        return self.encoder(x)

    def classify(self, z):
        return self.classifier(z)

    def forward(self, x):
        return self.classify(self.encode(x))
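Before handing a custom model to a trainer, it can help to confirm that it satisfies the interface. The sketch below is one way to do that; `check_da_interface` and `TinyModel` are hypothetical names introduced here and are not part of shiftkit.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Small stand-in with the same encode/classify/forward layout as the example above."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 64), nn.ReLU())
        self.classifier = nn.Linear(64, 10)

    def encode(self, x):
        return self.encoder(x)

    def classify(self, z):
        return self.classifier(z)

    def forward(self, x):
        return self.classify(self.encode(x))


def check_da_interface(model: nn.Module, x: torch.Tensor) -> torch.Size:
    """Hypothetical sanity check: verify the encode/classify/forward contract."""
    model.eval()  # disable dropout so both passes are deterministic
    with torch.no_grad():
        z = model.encode(x)
        logits = model.classify(z)
        assert z.dim() == 2, "encode should return a (B, d) tensor"
        assert torch.allclose(model(x), logits), "forward should equal classify(encode(x))"
    return z.shape


latent_shape = check_da_interface(TinyModel(), torch.randn(4, 1, 28, 28))
print(latent_shape)  # torch.Size([4, 64])
```

The check runs in eval mode because a model with dropout would otherwise give different outputs on the two forward passes.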