
Data

shiftkit.data loads datasets and creates paired source/target DataLoaders.


DataManager

Central hub for loading source/target domain data. Maintains a registry of dataset-pair factories and returns paired DataLoader objects.

```python
from shiftkit.data import DataManager

dm = DataManager(root="./data", batch_size=64)
train_src, train_tgt = dm.load("mnist_noisy_mnist", train=True)
test_src,  test_tgt  = dm.load("mnist_noisy_mnist", train=False)
```

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| root | str | "./data" | Root directory where datasets are downloaded |
| batch_size | int | 64 | Batch size for both DataLoaders |
| num_workers | int | 0 | Number of DataLoader worker processes |

Methods

load(name, train=True, **kwargs)

Return (source_loader, target_loader) for the named dataset pair.

| Parameter | Type | Default | Description |
|---|---|---|---|
| name | str | required | Registered dataset key, e.g. "mnist_noisy_mnist" |
| train | bool | True | Load the training split if True, the test split if False |
| `**kwargs` | | | Forwarded to the factory (e.g. noise_std=0.5) |

Returns: (DataLoader, DataLoader) — source loader, target loader

Raises: ValueError if name is not registered.
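The lookup that load performs can be pictured as a dictionary keyed by dataset name. The following is a minimal, self-contained model of that behaviour. It assumes a class-level dict as the registry. PairRegistry and its internals are hypothetical and are not the actual DataManager source. The sketch mirrors only the documented surface: register and available as static methods, constructor settings passed through to the factory, and ValueError for an unknown name.

```python
from typing import Any, Callable, Dict, List, Tuple

# Factory signature from the docs: (root, batch_size, train, num_workers, **kwargs) -> pair
Factory = Callable[..., Tuple[Any, Any]]


class PairRegistry:
    """Hypothetical stand-in for the registry behaviour described above."""

    # Class-level dict: every instance sees the same registered factories.
    _factories: Dict[str, Factory] = {}

    def __init__(self, root: str = "./data", batch_size: int = 64, num_workers: int = 0):
        self.root = root
        self.batch_size = batch_size
        self.num_workers = num_workers

    @staticmethod
    def register(name: str, factory: Factory) -> None:
        PairRegistry._factories[name] = factory

    @staticmethod
    def available() -> List[str]:
        return list(PairRegistry._factories)

    def load(self, name: str, train: bool = True, **kwargs: Any) -> Tuple[Any, Any]:
        if name not in PairRegistry._factories:
            raise ValueError(
                f"Unknown dataset pair {name!r}; available: {PairRegistry.available()}"
            )
        factory = PairRegistry._factories[name]
        # Constructor settings plus call-time kwargs are forwarded to the factory.
        return factory(self.root, self.batch_size, train, self.num_workers, **kwargs)
```

Keeping the dict on the class rather than on each instance is what lets a static register call affect every instance created later.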

register(name, factory) (static)

Register a custom dataset-pair factory.

```python
from torch.utils.data import DataLoader
from shiftkit.data import DataManager

def my_factory(root, batch_size, train, num_workers, **kwargs):
    source_ds = ...  # your source dataset
    target_ds = ...  # your target dataset
    # Shuffle only the training split and honour the num_workers argument
    source_loader = DataLoader(source_ds, batch_size=batch_size, shuffle=train, num_workers=num_workers)
    target_loader = DataLoader(target_ds, batch_size=batch_size, shuffle=train, num_workers=num_workers)
    return source_loader, target_loader

DataManager.register("my_pair", my_factory)

# then use it:
dm = DataManager()
src, tgt = dm.load("my_pair")
```

| Parameter | Type | Description |
|---|---|---|
| name | str | Key to register under |
| factory | callable | Function with signature `(root, batch_size, train, num_workers, **kwargs) -> (DataLoader, DataLoader)` |
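The `...` placeholders in the example above leave the datasets open. Here is one fully runnable factory that satisfies the documented signature. It is a sketch under stated assumptions: toy_shift_factory, its shift and n_samples keyword arguments, and the 2-D Gaussian data are all hypothetical. They were chosen only so the example needs nothing beyond PyTorch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def toy_shift_factory(root, batch_size, train, num_workers, shift=1.0, n_samples=256, **kwargs):
    """Hypothetical factory: 2-D Gaussian source vs. a translated target."""
    # `root` is unused because the data are generated in memory (nothing to download).
    gen = torch.Generator().manual_seed(0 if train else 1)  # different draws per split
    x_src = torch.randn(n_samples, 2, generator=gen)
    y_src = (x_src.sum(dim=1) > 0).long()
    # Covariate shift: target inputs are translated by `shift`,
    # but labels follow the same rule applied to the untranslated points.
    base = torch.randn(n_samples, 2, generator=gen)
    y_tgt = (base.sum(dim=1) > 0).long()
    x_tgt = base + shift

    def make_loader(x, y):
        return DataLoader(TensorDataset(x, y), batch_size=batch_size,
                          shuffle=train, num_workers=num_workers)

    return make_loader(x_src, y_src), make_loader(x_tgt, y_tgt)
```

After DataManager.register("toy_shift", toy_shift_factory), a call such as dm.load("toy_shift", shift=2.0) would forward shift=2.0 through `**kwargs`, as described for load.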

available() (static)

Return a list of all registered dataset-pair names.

```python
print(DataManager.available())
# ['mnist_noisy_mnist', 'my_pair', ...]
```

NoisyMNIST

A torch.utils.data.Dataset that wraps torchvision.datasets.MNIST and adds per-sample Gaussian noise. It serves as the synthetic target domain of the built-in "mnist_noisy_mnist" pair.

```python
from shiftkit.data.datasets import NoisyMNIST

ds = NoisyMNIST(root="./data", train=True, noise_std=0.3)
img, label = ds[0]   # img is a clipped noisy tensor in [0, 1]
```

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| root | str | required | Path to dataset directory |
| train | bool | True | Training split if True, test split if False |
| noise_std | float | 0.3 | Standard deviation of additive Gaussian noise |
| transform | callable | None | Additional transforms applied after noise injection |
| download | bool | True | Download the dataset if not present |

Note

Noise is injected as img = (img + N(0, noise_std²)).clamp(0, 1) each time __getitem__ is called, so each epoch sees different noise realisations.
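The per-call noise injection described in the note can be sketched as a thin Dataset wrapper. This is an illustration, not the NoisyMNIST source. NoisyWrapper and add_clipped_noise are hypothetical names. The wrapper assumes the base dataset already returns image tensors in [0, 1]; for MNIST, that would mean a ToTensor transform on the base dataset. As in the constructor table, the optional transform runs after noise injection.

```python
import torch
from torch.utils.data import Dataset


def add_clipped_noise(img: torch.Tensor, noise_std: float = 0.3) -> torch.Tensor:
    # Fresh N(0, noise_std^2) noise on every call, then clip back into [0, 1].
    noise = torch.randn_like(img) * noise_std
    return (img + noise).clamp(0.0, 1.0)


class NoisyWrapper(Dataset):
    """Hypothetical wrapper: adds clipped Gaussian noise to any (image, label) dataset."""

    def __init__(self, base: Dataset, noise_std: float = 0.3, transform=None):
        self.base = base
        self.noise_std = noise_std
        self.transform = transform

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx):
        img, label = self.base[idx]
        # Noise is drawn here, so repeated reads of the same index differ.
        img = add_clipped_noise(img, self.noise_std)
        if self.transform is not None:
            img = self.transform(img)  # applied after noise injection
        return img, label
```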


SyntheticGraphDataset

A purely synthetic graph classification dataset designed as a minimal benchmark for graph-based domain adaptation. No download required — graphs are generated on the fly.

Each sample is a small graph with N=10 nodes (the n_nodes default) and belongs to one of 2 classes:

| Class | Generator | Structure |
|---|---|---|
| 0 | Stochastic Block Model (SBM) | Two tightly connected communities (p_in=0.7, p_out=0.05) |
| 1 | Erdős–Rényi (ER) | Uniform random edges (p=0.25) |
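Here is a minimal sketch of the two generators. It assumes undirected graphs without self-loops and an SBM that splits the nodes into two equal halves. The split, the helper names, and the use of torch.Generator are assumptions, not details taken from SyntheticGraphDataset.

```python
import torch


def _sample_symmetric(probs: torch.Tensor, gen: torch.Generator) -> torch.Tensor:
    # Sample each node pair once (strict upper triangle), then mirror it:
    # the result is an undirected 0/1 adjacency matrix without self-loops.
    upper = (torch.rand(probs.shape, generator=gen) < probs).float().triu(diagonal=1)
    return upper + upper.T


def sbm_adjacency(n: int, p_in: float, p_out: float, gen: torch.Generator) -> torch.Tensor:
    # Assumed community split: nodes [0, n // 2) vs [n // 2, n).
    community = torch.arange(n) < n // 2
    same = community[:, None] == community[None, :]
    probs = torch.where(same, torch.tensor(p_in), torch.tensor(p_out))
    return _sample_symmetric(probs, gen)


def er_adjacency(n: int, p: float, gen: torch.Generator) -> torch.Tensor:
    # Every node pair is connected independently with the same probability p.
    return _sample_symmetric(torch.full((n, n), p), gen)


gen = torch.Generator().manual_seed(42)
adj_class0 = sbm_adjacency(10, 0.7, 0.05, gen)  # class 0
adj_class1 = er_adjacency(10, 0.25, gen)        # class 1
```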

Domain shift is introduced by two mechanisms in the target domain:

  • Feature noise: node features have higher Gaussian noise (σ=0.5 vs σ=0.1 in the source)
  • Edge flips: each edge is independently flipped with probability 5%
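Both shift mechanisms can be sketched as small tensor operations. This sketch assumes that "flipping" an edge means toggling the corresponding node pair in the adjacency matrix: an existing edge is removed, a missing one is added. The flip is mirrored so the graph stays undirected. flip_edges and add_feature_noise are hypothetical helpers, not the library's implementation.

```python
import torch


def flip_edges(adj: torch.Tensor, flip_prob: float, gen: torch.Generator) -> torch.Tensor:
    n = adj.shape[0]
    # One Bernoulli(flip_prob) draw per node pair (strict upper triangle), mirrored.
    flips = (torch.rand((n, n), generator=gen) < flip_prob).float().triu(diagonal=1)
    flips = flips + flips.T
    # XOR on 0/1 values: present edges are removed, absent edges are added.
    return (adj - flips).abs()


def add_feature_noise(feats: torch.Tensor, sigma: float, gen: torch.Generator) -> torch.Tensor:
    # Additive Gaussian noise; sigma = 0.1 (source) or 0.5 (target) in the text above.
    return feats + sigma * torch.randn(feats.shape, generator=gen)
```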

Tensor format: each sample x has shape (N, N + feat_dim). The first N columns hold the adjacency matrix; the remaining feat_dim columns hold the node features [degree / (N-1), class-offset Gaussian noise]. Because every sample is a fixed-size dense tensor, the format batches with a standard DataLoader and does not require PyTorch Geometric.
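Packing and unpacking amount to concatenation and slicing along the last dimension. The helpers below are hypothetical. They assume the default feat_dim = 4 is split as one normalised-degree column plus three noise columns, which is one reading of the feature description above.

```python
import torch


def degree_feature(adj: torch.Tensor) -> torch.Tensor:
    # Normalised degree, degree / (N - 1), as a single (..., N, 1) column.
    n = adj.shape[-1]
    return adj.sum(dim=-1, keepdim=True) / (n - 1)


def pack_graph(adj: torch.Tensor, feats: torch.Tensor) -> torch.Tensor:
    # (N, N) adjacency + (N, feat_dim) features -> (N, N + feat_dim)
    return torch.cat([adj, feats], dim=-1)


def unpack_graph(x: torch.Tensor, n_nodes: int):
    # Works on one sample (N, N + F) or a batch (B, N, N + F).
    return x[..., :n_nodes], x[..., n_nodes:]
```

Slicing with `...` keeps unpack_graph valid both for a single sample and for a DataLoader batch of shape (B, N, N + feat_dim).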

```python
from shiftkit.data.datasets import SyntheticGraphDataset
from torch.utils.data import DataLoader

src_ds = SyntheticGraphDataset(feature_noise=0.1, edge_flip_prob=0.0)
tgt_ds = SyntheticGraphDataset(feature_noise=0.5, edge_flip_prob=0.05, seed=99)

src_loader = DataLoader(src_ds, batch_size=32, shuffle=True)
x, y = next(iter(src_loader))
# x: (32, 10, 14)   10 nodes: 10-col adjacency + 4-col features
# y: (32,)          class labels {0, 1}
```

Or via DataManager:

```python
from shiftkit.data import DataManager

dm = DataManager(batch_size=32)
train_src, train_tgt = dm.load("synthetic_graphs", train=True)
test_src,  test_tgt  = dm.load("synthetic_graphs", train=False)
```

Constructor

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_graphs | int | 1000 | Total number of graphs (split 80/20 into train/test) |
| n_nodes | int | 10 | Number of nodes per graph |
| feat_dim | int | 4 | Number of node feature dimensions |
| feature_noise | float | 0.1 | Standard deviation of the additive Gaussian feature noise |
| edge_flip_prob | float | 0.0 | Probability of flipping each edge |
| p_in | float | 0.7 | SBM within-community edge probability (class 0) |
| p_out | float | 0.05 | SBM between-community edge probability (class 0) |
| p_er | float | 0.25 | Erdős–Rényi edge probability (class 1) |
| train | bool | True | If True, return the training split (first 80%); otherwise, return the test split |
| seed | int | 42 | Random seed for reproducibility |
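Because train=True returns the first 80% of n_graphs, the split sizes follow from simple arithmetic. For the default n_graphs=1000, that is 800 training and 200 test graphs. The sketch assumes integer truncation at the cut point. split_indices is a hypothetical helper, not part of shiftkit. If both splits are sliced from one seeded sequence of graphs, the train and test datasets are disjoint only when they are built with the same seed and n_graphs.

```python
def split_indices(n_graphs: int, train: bool) -> range:
    # Hypothetical: the first 80% of graphs (truncated) are train, the rest are test.
    cut = n_graphs * 4 // 5
    return range(0, cut) if train else range(cut, n_graphs)
```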

Built-in dataset pairs

| Key | Source | Target | Extra kwargs |
|---|---|---|---|
| "mnist_noisy_mnist" | torchvision.datasets.MNIST | NoisyMNIST | noise_std (default 0.3) |
| "synthetic_graphs" | SyntheticGraphDataset (noise=0.1) | SyntheticGraphDataset (noise=0.5, flip=0.05) | n_graphs, n_nodes, feat_dim, feature_noise_src, feature_noise_tgt, edge_flip_prob |