Source code for simshift.model_selection.iwv

import torch

from simshift.model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("IWV")
def iwv(weights: torch.Tensor, source_val_loss: torch.Tensor, **kwargs) -> torch.Tensor:
    r"""
    Importance-Weighted Validation (IWV) algorithm for unsupervised model selection.

    Estimates model risk through a weighted average of validation losses:

    .. math::

        R_{\mathrm{IWV}}^{(i)} = \frac{1}{N} \sum_{j=1}^{N} w_{ij}\,\ell_{ij}

    Where for :math:`M` candidate models and :math:`N` validation samples:

    - :math:`w_{ij}`: Learned weight for model :math:`i` on sample :math:`j`
      (typically :math:`\sum_j w_{ij} = 1`)
    - :math:`\ell_{ij}`: Validation loss for model :math:`i` on sample :math:`j`

    Selects the model with minimal :math:`R_{\mathrm{IWV}}`. Because every model
    is averaged over the same :math:`N` samples, the :math:`1/N` factor does not
    change which model is selected.

    **Reference**:
        Sugiyama, Krauledat & Müller, *"Covariate Shift Adaptation by
        Importance Weighted Cross Validation"*, JMLR 8 (2007).
        https://jmlr.org/papers/v8/sugiyama07a.html

    Args:
        weights (Tensor): Tensor of shape :math:`[M, N]` (or broadcastable to the
            shape of ``source_val_loss``). Learned relative weights for :math:`M`
            candidate models across :math:`N` validation samples.
        source_val_loss (Tensor): Tensor of shape :math:`[M, N]`. Validation
            losses where lower values indicate better model performance.
        **kwargs: Ignored arguments (maintained for API compatibility).

    Returns:
        Tensor: Tensor of shape :math:`[M]`, with the dtype and device of the
        computed risk. One-hot encoded selection vector where 1 indicates the
        chosen model. If several models tie for the minimum risk, the one with
        the lowest index is chosen.

    Example:
        For 3 models with minimum at index 1, returns :code:`tensor([0., 1., 0.])`
    """
    # Accepted only so every registered selection algorithm shares one call signature.
    del kwargs

    # Per-model IWV risk: mean of the weighted losses over the last (sample) axis.
    iwv_risk = torch.mean(weights * source_val_loss, dim=-1)

    # One-hot selection of the minimum-risk model. argmin returns the first
    # minimal index on ties, matching torch.min(..., dim=0).indices.
    model_weights = torch.zeros_like(iwv_risk)
    model_weights[torch.argmin(iwv_risk, dim=0)] = 1
    return model_weights