Source code for simshift.model_selection.source_best

import torch

from simshift.model_selection import register_model_selection_algorithm


@register_model_selection_algorithm("SB")
def source_best(source_val_loss, **kwargs):
    """Naive model selection strategy that selects the model based on
    validation loss in the source domain.

    Parameters:
        source_val_loss (torch.Tensor): A tensor of shape
            [n_models, n_samples] representing validation losses of each
            model on source domain validation data.
        **kwargs: Additional keyword arguments (ignored in this method).

    Returns:
        torch.Tensor: A one-hot tensor of shape [n_models], where the index
        corresponding to the model with the lowest mean validation loss in
        the source domain is set to 1, and all others are 0. Models whose
        mean loss is NaN are never preferred over a model with a real
        (non-NaN) mean loss; if every mean loss is NaN, the first model is
        selected. Ties are resolved in favour of the lowest index.
    """
    _ = kwargs  # intentionally unused

    # source_val_loss: [n_models, n_samples] -> mean_loss: [n_models]
    mean_loss = torch.mean(source_val_loss, dim=1)

    # torch.argmin propagates NaN (a NaN entry is reported as the minimum),
    # which would select a diverged model. Rank NaN losses last instead.
    ranking_loss = mean_loss.masked_fill(torch.isnan(mean_loss), float("inf"))

    # Select the model with minimum source-domain validation risk.
    min_index = torch.argmin(ranking_loss)

    # One-hot weights with the same dtype/device as the mean losses; shape [n_models].
    model_weights = torch.zeros_like(mean_loss)
    model_weights[min_index] = 1
    return model_weights