select_best_model_ranking
ampligraph.evaluation.select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid, max_combinations=None, param_grid_random_seed=0, use_filter=True, early_stopping=False, early_stopping_params=None, use_test_for_selection=False, entities_subset=None, corrupt_side='s,o', use_default_protocol=False, retrain_best_model=False, verbose=False)
Model selection routine for embedding models via either grid search or random search.
For grid search, pass a fixed param_grid and leave max_combinations as None so that all combinations will be explored.
For random search, limit max_combinations to your computational budget and optionally set some parameters to be callables instead of lists (see the documentation for param_grid).
Note
Random search is more efficient than grid search as the number of parameters grows [BB12]. It is also a strong baseline against more advanced methods such as Bayesian optimization [LJ18].
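The difference between the two search modes can be sketched in plain Python. This is a minimal illustration and not AmpliGraph's internal implementation: the helpers expand_grid and sample_grid are hypothetical names, and the standard-library random module stands in for np.random.

```python
import itertools
import random


def expand_grid(param_grid):
    """Grid search: enumerate every combination of a grid of lists."""
    keys = sorted(param_grid)
    for values in itertools.product(*(param_grid[k] for k in keys)):
        yield dict(zip(keys, values))


def sample_grid(param_grid, max_combinations, seed=0):
    """Random search: draw max_combinations configurations.

    List values are sampled uniformly; callables are invoked to draw a value.
    Duplicate configurations are possible in this simplified sketch.
    """
    rng = random.Random(seed)
    for _ in range(max_combinations):
        yield {
            key: value() if callable(value) else rng.choice(value)
            for key, value in param_grid.items()
        }


# Grid search: all values are lists, so the grid is finite (2 * 3 = 6).
fixed_grid = {"k": [100, 200], "eta": [5, 10, 15]}
grid_configs = list(expand_grid(fixed_grid))

# Random search: a callable makes the grid unbounded, so a budget is required.
random.seed(0)  # seeds the module-level generator used by the lambda below
random_grid = {"k": [100, 200], "lr": lambda: random.uniform(0.01, 0.1)}
random_configs = list(sample_grid(random_grid, max_combinations=4))
```

The sketch also shows why a callable parameter requires max_combinations: without a budget there is no finite set of combinations to enumerate.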
When retrain_best_model=True, the function also retrains the best performing model on the concatenation of training and validation sets.
Note we generate negatives at runtime according to the strategy described in [BUGD+13].
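As a rough illustration of generating negatives by corruption, the sketch below replaces either the subject or the object of a positive triple with a randomly chosen entity. It is a simplified stand-in written for this page, not AmpliGraph's code: the function generate_corruptions, its n_negatives parameter and the toy entities are all hypothetical.

```python
import random


def generate_corruptions(triple, entities, n_negatives, rng):
    """Return n_negatives corrupted copies of one (s, p, o) triple.

    Each copy replaces either the subject or the object (never both)
    with an entity drawn uniformly at random. In this simplified sketch
    a corruption may happen to coincide with the original triple.
    """
    s, p, o = triple
    negatives = []
    for _ in range(n_negatives):
        replacement = rng.choice(entities)
        if rng.random() < 0.5:
            negatives.append((replacement, p, o))  # corrupt the subject
        else:
            negatives.append((s, p, replacement))  # corrupt the object
    return negatives


rng = random.Random(0)
entities = ["a", "b", "c", "d"]
positive = ("a", "likes", "b")
negatives = generate_corruptions(positive, entities, n_negatives=5, rng=rng)
```

Because the corruptions are drawn on the fly, no negative triples need to be stored with the dataset.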
Note
Setting use_filter=False runs model selection with raw MRR, which gives better runtime performance; by default (use_filter=True) the filtered MRR is used.
Parameters:
- model_class (class) – The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc.).
- X_train (ndarray, shape [n, 3]) – An array of training triples.
- X_valid (ndarray, shape [n, 3]) – An array of validation triples.
- X_test (ndarray, shape [n, 3]) – An array of test triples.
- param_grid (dict) –
A grid of hyperparameters to use in model selection. The routine will train a model for each combination of these hyperparameters.
Parameters can be either callables or lists. If callable, it must take no parameters and return a constant value. If any parameter is a callable, max_combinations must be set to some value.
For example, the learning rate could either be "lr": [0.1, 0.01] or "lr": lambda: np.random.uniform(0.01, 0.1).
- max_combinations (int) – Maximum number of combinations to explore. By default (None) all combinations will be explored, which makes it incompatible with random parameters for random search.
- param_grid_random_seed (int) – Random seed for the parameters that are callables and random.
- use_filter (bool) – If True, will use the entire input dataset X to compute filtered MRR (default: True).
- early_stopping (bool) –
Flag to enable early stopping (default: False).
If set to True, the training loop adopts the following early stopping heuristic:
- The model will be trained regardless of early stopping for burn_in epochs.
- Every check_interval epochs the method will compute the metric specified in criteria.
If such metric decreases for stop_interval checks, we stop training early.
Note the metric is computed on x_valid. This is usually a validation set that you held out.
Also, because criteria is a ranking metric, it requires generating negatives. Entities used to generate corruptions can be specified, as well as the side(s) of a triple to corrupt. The method supports filtered metrics, by passing an array of positives to x_filter. This will be used to filter the negatives generated on the fly (i.e. the corruptions).
Note
Keep in mind the early stopping criteria may introduce a certain overhead (caused by the metric computation). The goal is to strike a good trade-off between such overhead and saving training epochs.
A common approach is to use MRR unfiltered:
early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
Note the size of the validation set also contributes to such overhead. In most cases a smaller validation set would be enough.
- early_stopping_params (dict) –
Dictionary of parameters for early stopping.
The following keys are supported:
- x_valid: ndarray, shape [n, 3]: Validation set to be used for early stopping. Uses X['valid'] by default.
- criteria: Criterion for early stopping: hits10, hits3, hits1 or mrr (default).
- x_filter: ndarray, shape [n, 3]: Filter to be used (no filter by default).
- burn_in: Number of epochs to pass before kicking in early stopping (default: 100).
- check_interval: Early stopping interval after burn-in (default: 10).
- stop_interval: Stop if criteria is performing worse over n consecutive checks (default: 3).
- use_test_for_selection (bool) – Use test set for model selection. If False, uses validation set (default: False).
- entities_subset (array-like) – List of entities to use for corruptions. If None, will generate corruptions using all distinct entities (default: None).
- corrupt_side (string) – Specifies which side of the triple to corrupt:
  - s: corrupt only the subject.
  - o: corrupt only the object.
  - s+o: corrupt both subject and object.
  - s,o: corrupt both subject and object, but compute the ranks separately (default).
- use_default_protocol (bool) – Flag to indicate whether to evaluate head and tail corruptions separately (default: False). If this is set to True, the corrupt_side argument is ignored and both head and tail are corrupted and ranked separately, i.e. corrupt_side='s,o' mode.
- retrain_best_model (bool) – Flag to indicate whether best model should be re-trained at the end with the validation set used in the search. Default: False.
- verbose (bool) –
Verbose mode for the model selection procedure (which is independent of the verbose mode in the model fit).
Verbose mode includes display of the progress bar, logging info for each iteration, evaluation information, and exception details.
If you need verbosity inside the model training itself, change the verbose parameter within the param_grid.
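The early stopping heuristic described under early_stopping and early_stopping_params can be sketched as a small loop. This is an illustration under stated assumptions rather than the library's training loop: the function stop_epoch is hypothetical, the list metric_per_epoch stands in for training and evaluating on x_valid, the first check is assumed to happen at epoch burn_in, and a check counts as "worse" when the metric fails to improve on the best value seen so far.

```python
def stop_epoch(metric_per_epoch, burn_in=100, check_interval=10, stop_interval=3):
    """Return the epoch at which training would stop.

    metric_per_epoch[e - 1] is the validation metric (higher is better)
    that would be measured after epoch e.
    """
    best = None
    worse_checks = 0
    total_epochs = len(metric_per_epoch)
    for epoch in range(1, total_epochs + 1):
        # One training epoch would run here.
        if epoch < burn_in or (epoch - burn_in) % check_interval != 0:
            continue  # no check during burn-in or between check intervals
        metric = metric_per_epoch[epoch - 1]
        if best is None or metric > best:
            best = metric
            worse_checks = 0  # improvement resets the counter
        else:
            worse_checks += 1
            if worse_checks >= stop_interval:
                return epoch  # stop training early
    return total_epochs


# Toy metric that improves until epoch 30 and degrades afterwards.
toy_metric = [e if e <= 30 else 60 - e for e in range(1, 101)]
stopped_at = stop_epoch(toy_metric, burn_in=10, check_interval=10, stop_interval=2)
```

With these toy values the checks at epochs 40 and 50 are both worse than the best value seen at epoch 30, so training stops at epoch 50 instead of running all 100 epochs. A larger check_interval reduces the evaluation overhead but delays the stop.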
Returns:
- best_model (EmbeddingModel) – The best trained embedding model obtained in model selection.
- best_params (dict) – The hyperparameters of the best embedding model best_model.
- best_mrr_train (float) – The MRR (unfiltered) of the best model computed over the validation set in the model selection loop.
- ranks_test (ndarray, shape [n] or [n, 2] depending on the value of corrupt_side) – An array of ranks of test triples.
When corrupt_side='s,o' the function returns [n, 2]. The first column represents the rank against subject corruptions and the second column represents the rank against object corruptions. In other cases, it returns [n], i.e. the rank against the specified corruptions.
- mrr_test (float) – The MRR (filtered) of the best model, retrained on the concatenation of training and validation sets, computed over the test set.
- experimental_history (list of dict) – A list containing all the intermediate experimental results: the model parameters and the corresponding validation metrics.
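To make the relation between ranks and the reported metrics concrete, the following pure-Python sketch computes MRR and hits@n from an array of ranks. It is not the library's implementation; the helper names are hypothetical, and it assumes that for ranks of shape [n, 2] the subject and object columns are pooled before averaging.

```python
def flatten_ranks(ranks):
    """Flatten ranks of shape [n] or [n, 2] into a single list."""
    flat = []
    for r in ranks:
        if isinstance(r, (list, tuple)):
            flat.extend(r)  # [n, 2] case: subject rank and object rank
        else:
            flat.append(r)  # [n] case: a single rank per triple
    return flat


def mrr(ranks):
    """Mean reciprocal rank: the average of 1 / rank."""
    flat = flatten_ranks(ranks)
    return sum(1.0 / r for r in flat) / len(flat)


def hits_at_n(ranks, n):
    """Fraction of ranks that are less than or equal to n."""
    flat = flatten_ranks(ranks)
    return sum(1 for r in flat if r <= n) / len(flat)


# Two test triples ranked with corrupt_side='s,o': [subject rank, object rank].
ranks_so = [[1, 2], [4, 1]]
mrr_value = mrr(ranks_so)         # (1 + 1/2 + 1/4 + 1) / 4 = 0.6875
hits1 = hits_at_n(ranks_so, 1)    # 2 of 4 ranks equal 1, so 0.5
```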
Examples
>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>> import numpy as np
>>>
>>> X = load_wn18()
>>>
>>> model_class = ComplEx
>>> param_grid = {
>>>     "batches_count": [50],
>>>     "seed": 0,
>>>     "epochs": [100],
>>>     "k": [100, 200],
>>>     "eta": [5, 10, 15],
>>>     "loss": ["pairwise", "nll"],
>>>     "loss_params": {
>>>         "margin": [2]
>>>     },
>>>     "embedding_model_params": {
>>>     },
>>>     "regularizer": ["LP", None],
>>>     "regularizer_params": {
>>>         "p": [1, 3],
>>>         "lambda": [1e-4, 1e-5]
>>>     },
>>>     "optimizer": ["adagrad", "adam"],
>>>     "optimizer_params": {
>>>         "lr": lambda: np.random.uniform(0.0001, 0.01)
>>>     },
>>>     "verbose": False
>>> }
>>> select_best_model_ranking(model_class, X['train'], X['valid'], X['test'],
>>>                           param_grid,
>>>                           max_combinations=100,
>>>                           use_filter=True,
>>>                           verbose=True,
>>>                           early_stopping=True)