Model Selection
DEV
- simshift.model_selection.dev.dev(weights: torch.Tensor, source_val_loss: torch.Tensor, **kwargs) → torch.Tensor
Deep Embedded Validation (DEV) algorithm for unsupervised model selection.
Computes a corrected risk estimate by modeling the dependency between learned sample weights (estimated density ratios) and validation losses:
\[\eta = -\frac{\mathrm{Cov}(w, \ell)}{\mathrm{Var}(w)}, \quad R_{\mathrm{DEV}} = \mathbb{E}[w \ell] + \eta\,\mathbb{E}[w] - \eta\]
where:
- \(w\) (weights): shape \([M, N]\) for \(M\) models and \(N\) validation samples
- \(\ell\) (source_val_loss): shape \([M, N]\), validation losses
- Expectations, covariances, and variances are computed over the sample dimension, separately for each model
Selects the model with minimal \(R_{\mathrm{DEV}}\).
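Rearranging the DEV estimate shows what \(\eta\) does. Assuming the weights are exact density ratios, so that \(\mathbb{E}[w] = 1\) under the source distribution, the extra term has zero mean for any fixed \(\eta\); it adjusts the variance of the estimate, not its target:

\[R_{\mathrm{DEV}} = \mathbb{E}[w \ell] + \eta\,\bigl(\mathbb{E}[w] - 1\bigr), \qquad \mathbb{E}[w] = 1 \;\Rightarrow\; R_{\mathrm{DEV}} = \mathbb{E}[w \ell]\]

With finite samples, the expectations are means over the \(N\) validation samples, so \(\frac{1}{N}\sum_j w_{ij} - 1\) fluctuates around zero and \(w\) acts as a control variate for the weighted loss.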
Reference: You et al., “Towards Accurate Model Selection in Deep Unsupervised Domain Adaptation” https://proceedings.mlr.press/v97/you19a.html
- Parameters:
weights (Tensor) – Tensor of shape \([M, N]\). Learned relative weights for \(M\) candidate models across \(N\) validation samples.
source_val_loss (Tensor) – Tensor of shape \([M, N]\). Validation losses where lower values indicate better model performance.
**kwargs – Additional keyword arguments (ignored, maintained for API compatibility)
- Returns:
Tensor of shape \([M]\). One-hot encoded selection vector where 1 indicates the chosen model. Example: for 3 models with the minimal \(R_{\mathrm{DEV}}\) at index 1, returns tensor([0., 1., 0.]).
- Return type:
Tensor
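A minimal plain-Python sketch of the DEV score, assuming nested lists in place of torch tensors and following the formula in this section term by term. The names `dev_scores` and `one_hot_argmin` are hypothetical illustrations, not simshift functions, and the zero-variance guard is an assumption. Covariance and variance share the same normalization, so the choice of denominator cancels in \(\eta\). Note that You et al. express the covariance between \(w\) and the weighted loss \(w\ell\); the sketch keeps \(\mathrm{Cov}(w, \ell)\) as written here.

```python
from typing import List, Sequence


def dev_scores(weights: Sequence[Sequence[float]],
               losses: Sequence[Sequence[float]]) -> List[float]:
    """Per-model R_DEV, following the formula in this section term by term.

    weights[i][j] and losses[i][j] refer to model i and validation sample j.
    """
    scores = []
    for w, l in zip(weights, losses):
        n = len(w)
        mean_w = sum(w) / n
        mean_l = sum(l) / n
        mean_wl = sum(wj * lj for wj, lj in zip(w, l)) / n
        # Covariance and variance over the sample dimension, both divided by n,
        # so the normalization cancels in eta.
        cov_w_l = sum((wj - mean_w) * (lj - mean_l) for wj, lj in zip(w, l)) / n
        var_w = sum((wj - mean_w) ** 2 for wj in w) / n
        # Hypothetical guard: with constant weights the control variate carries
        # no information, so eta is set to 0 (not necessarily simshift's choice).
        eta = -cov_w_l / var_w if var_w > 0 else 0.0
        scores.append(mean_wl + eta * mean_w - eta)
    return scores


def one_hot_argmin(scores: Sequence[float]) -> List[float]:
    """One-hot vector marking the lowest score (first index wins ties)."""
    best = min(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]


# Two models sharing the same weights over N = 4 validation samples.
weights = [[0.5, 1.5, 0.5, 1.5], [0.5, 1.5, 0.5, 1.5]]
losses = [[1.0, 1.0, 1.0, 1.0], [0.25, 0.75, 0.25, 0.75]]
scores = dev_scores(weights, losses)
print(scores)                  # [1.0, 0.625]
print(one_hot_argmin(scores))  # [0.0, 1.0]
```

In this toy example the weights already average exactly 1, so \(\eta(\bar{w} - 1) = 0\) and each score equals the importance-weighted mean loss; the correction only changes the result when the sample mean of the weights drifts away from 1.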
IWV
- simshift.model_selection.iwv.iwv(weights: torch.Tensor, source_val_loss: torch.Tensor, **kwargs) → torch.Tensor
Importance-Weighted Validation (IWV) algorithm for unsupervised model selection.
Estimates model risk through a weighted average of validation losses:
\[R_{\mathrm{IWV}}^{(i)} = \frac{1}{N} \sum_{j=1}^{N} w_{ij}\,\ell_{ij}\]
where, for \(M\) candidate models and \(N\) validation samples:
- \(w_{ij}\): learned weight (estimated density ratio) for model \(i\) on sample \(j\), whose mean over samples is ideally close to 1, i.e. \(\frac{1}{N}\sum_j w_{ij} \approx 1\)
- \(\ell_{ij}\): validation loss for model \(i\) on sample \(j\)
Selects the model with minimal \(R_{\mathrm{IWV}}\).
Reference: Sugiyama et al., “Covariate Shift Adaptation by Importance Weighted Cross Validation” https://jmlr.org/papers/v8/sugiyama07a.html
- Parameters:
weights (Tensor) – Tensor of shape \([M, N]\). Learned relative weights for \(M\) candidate models across \(N\) validation samples.
source_val_loss (Tensor) – Tensor of shape \([M, N]\). Validation losses where lower values indicate better model performance.
**kwargs – Ignored arguments (maintained for API compatibility)
- Returns:
Tensor of shape \([M]\). One-hot encoded selection vector where 1 indicates the chosen model. Example: for 3 models with the minimal \(R_{\mathrm{IWV}}\) at index 1, returns tensor([0., 1., 0.]).
- Return type:
Tensor
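The IWV score is a row-wise weighted mean followed by an argmin. The plain-Python sketch below assumes nested lists instead of torch tensors; `iwv_scores` and `select_one_hot` are hypothetical helpers, not part of simshift, and the losses are made up. Both models have the same unweighted mean loss (0.5), but the weights favour the model that performs better on the up-weighted samples.

```python
from typing import List, Sequence


def iwv_scores(weights: Sequence[Sequence[float]],
               losses: Sequence[Sequence[float]]) -> List[float]:
    """R_IWV for each model: (1/N) * sum_j w_ij * l_ij over its row."""
    return [sum(wj * lj for wj, lj in zip(w, l)) / len(w)
            for w, l in zip(weights, losses)]


def select_one_hot(scores: Sequence[float]) -> List[float]:
    """One-hot vector marking the lowest score (first index wins ties)."""
    best = min(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]


# Model 1 does better on the samples that receive weight 1.5.
weights = [[0.5, 1.5, 0.5, 1.5], [0.5, 1.5, 0.5, 1.5]]
losses = [[0.25, 0.75, 0.25, 0.75], [0.75, 0.25, 0.75, 0.25]]
scores = iwv_scores(weights, losses)
print(scores)                  # [0.625, 0.375]
print(select_one_hot(scores))  # [0.0, 1.0]
```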
SB
- simshift.model_selection.source_best.source_best(source_val_loss, **kwargs)
Naive model selection strategy that selects the model with the lowest mean validation loss in the source domain, ignoring any importance weights.
- Parameters:
source_val_loss (torch.Tensor) – A tensor of shape [n_models, n_samples] representing the validation losses of each model on the source-domain validation data.
**kwargs – Additional keyword arguments (ignored in this method).
- Returns:
A one-hot tensor of shape [n_models], where the index corresponding to the model with the lowest mean validation loss in the source domain is set to 1, and all others are 0.
- Return type:
torch.Tensor
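Source-best amounts to a mean over the sample dimension followed by a one-hot argmin. This plain-Python sketch assumes nested lists in place of a torch tensor; `source_best_sketch` is a hypothetical name and the losses are made up.

```python
from typing import List, Sequence


def source_best_sketch(source_val_loss: Sequence[Sequence[float]]) -> List[float]:
    """One-hot vector for the model with the lowest mean source validation loss."""
    means = [sum(row) / len(row) for row in source_val_loss]
    best = min(range(len(means)), key=lambda i: means[i])  # first index wins ties
    return [1.0 if i == best else 0.0 for i in range(len(means))]


# Three models, two validation samples each; mean losses are 0.375, 0.25 and 0.5.
print(source_best_sketch([[0.5, 0.25], [0.25, 0.25], [1.0, 0.0]]))  # [0.0, 1.0, 0.0]
```

With torch tensors, the same idea could be written as `torch.nn.functional.one_hot(source_val_loss.mean(dim=1).argmin(), num_classes=source_val_loss.shape[0]).float()`; whether simshift implements it this way is not stated here.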
TB
- simshift.model_selection.target_best.target_best(target_test_loss, **kwargs)
Selects the model with the lowest mean loss on the target test set. Because it uses the target test set, it should only be used as an oracle reference (a theoretical lower bound on the loss achievable by other model selection strategies), not as a practical selection method.
- Parameters:
target_test_loss (torch.Tensor) – A tensor of shape [n_models, n_samples] representing the loss of each model on each sample from the target domain.
**kwargs – Additional keyword arguments (ignored in this method).
- Returns:
A one-hot tensor of shape [n_models], where the index corresponding to the model with the lowest mean loss on the target test set is set to 1, and all others are 0.
- Return type:
torch.Tensor
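Computationally, the target-best oracle performs the same reduction as source-best; only the input differs. The plain-Python sketch below uses a hypothetical helper `lowest_mean_one_hot` and made-up losses for two models to show how the source-based choice and the oracle choice can disagree under a domain shift, which is the gap that unsupervised strategies such as DEV and IWV try to close.

```python
from typing import List, Sequence


def lowest_mean_one_hot(loss: Sequence[Sequence[float]]) -> List[float]:
    """One-hot vector for the row (model) with the lowest mean loss."""
    means = [sum(row) / len(row) for row in loss]
    best = min(range(len(means)), key=lambda i: means[i])  # first index wins ties
    return [1.0 if i == best else 0.0 for i in range(len(means))]


# Model 0 looks best on the source domain,
# model 1 is actually best on the target test set.
source_val_loss = [[0.25, 0.25], [0.5, 0.5]]
target_test_loss = [[1.0, 0.75], [0.5, 0.5]]
print(lowest_mean_one_hot(source_val_loss))   # [1.0, 0.0]  (source-best choice)
print(lowest_mean_one_hot(target_test_loss))  # [0.0, 1.0]  (oracle choice)
```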