select_best_model_ranking
ampligraph.evaluation.select_best_model_ranking(model_class, X, param_grid, use_filter=False, early_stopping=False, early_stopping_params={}, use_test_for_selection=True, rank_against_ent=None, corrupt_side='s+o', use_default_protocol=False, verbose=False)
Model selection routine for embedding models.
Note
By default, model selection is done with raw MRR for better runtime performance (use_filter=False).
The function also retrains the best-performing model on the concatenation of the training and validation sets.
Negatives are generated at runtime according to the strategy described in [BUGD+13].
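As a rough illustration of what generating negatives at runtime involves, the sketch below corrupts a positive triple by replacing its subject or object with another entity, which is the general idea behind the strategy in [BUGD+13]. The helpers corrupt and generate_negatives and the toy entities are hypothetical and not part of AmpliGraph; the library's own implementation may differ in detail.

```python
import random


def corrupt(triple, entities, side, rng):
    """Return a copy of `triple` with its subject or object replaced."""
    s, p, o = triple
    if side == "s+o":
        # Assumption: 's+o' picks one of the two sides at random per negative.
        side = rng.choice(["s", "o"])
    if side == "s":
        return (rng.choice([e for e in entities if e != s]), p, o)
    if side == "o":
        return (s, p, rng.choice([e for e in entities if e != o]))
    raise ValueError("side must be 's', 'o' or 's+o'")


def generate_negatives(triple, entities, eta, side="s+o", seed=0):
    """Generate `eta` corruptions of one positive triple."""
    rng = random.Random(seed)
    return [corrupt(triple, entities, side, rng) for _ in range(eta)]


positive = ("alice", "knows", "bob")
entities = ["alice", "bob", "carol", "dave"]
negatives = generate_negatives(positive, entities, eta=5)
```

A corruption produced this way may happen to be a true triple; a filtered metric removes such known positives before ranking.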
Parameters:
- model_class (class) – The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc.).
- X (dict) – A dictionary of triples to use in model selection. Must include three keys: train, valid, test. Values are ndarrays of shape [n, 3].
- param_grid (dict) – A grid of hyperparameters to use in model selection. The routine will train a model for each combination of these hyperparameters.
- use_filter (bool) – If True, uses the entire input dataset X to compute filtered MRR (default: False).
- early_stopping (bool) –
Flag to enable early stopping (default: False).
If set to True, the training loop adopts the following early stopping heuristic:
  - The model will be trained regardless of early stopping for burn_in epochs.
  - Every check_interval epochs the method will compute the metric specified in criteria.
If this metric decreases for stop_interval checks, training stops early.
Note that the metric is computed on x_valid. This is usually a validation set that you held out.
Also, because criteria is a ranking metric, it requires generating negatives. The entities used to generate corruptions can be specified, as well as the side(s) of a triple to corrupt. The method supports filtered metrics, by passing an array of positives to x_filter. This will be used to filter the negatives generated on the fly (i.e. the corruptions).
Note
Keep in mind that the early stopping criteria may introduce a certain overhead (caused by the metric computation). The goal is to strike a good trade-off between this overhead and saving training epochs.
A common approach is to use unfiltered MRR:
early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}
Note that the size of the validation set also contributes to this overhead. In most cases a smaller validation set would be enough.
- early_stopping_params (dict) –
Dictionary of parameters for early stopping.
The following keys are supported:
  - x_valid: ndarray, shape [n, 3]: Validation set to be used for early stopping. Uses X['valid'] by default.
  - criteria: Criterion for early stopping: hits10, hits3, hits1 or mrr (default).
  - x_filter: ndarray, shape [n, 3]: Filter to be used (no filter by default).
  - burn_in: Number of epochs to pass before early stopping kicks in (default: 100).
  - check_interval: Early stopping interval after burn-in (default: 10).
  - stop_interval: Stop if criteria is performing worse over n consecutive checks (default: 3).
- use_test_for_selection (bool) – If True, uses the test set for model selection. If False, uses the validation set (default: True).
- rank_against_ent (array-like) – List of entities to use for corruptions. If None, will generate corruptions using all distinct entities. Default is None.
- corrupt_side (string) – Specifies which side of the triple to corrupt:
  - s corrupts only the subject.
  - o corrupts only the object.
  - s+o corrupts both subject and object.
- use_default_protocol (bool) – Flag to indicate whether to evaluate head and tail corruptions separately (default: False). If set to True, the corrupt_side argument is ignored, and head and tail are corrupted separately to rank triples.
- verbose (bool) – Verbose mode during evaluation of the trained model (default: False).
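The early stopping heuristic controlled by burn_in, check_interval and stop_interval can be sketched roughly as follows. This is a simplified simulation under stated assumptions, not AmpliGraph's training loop: simulate_early_stopping is a hypothetical helper, the validation metric is supplied as a precomputed list, and a check counts as "worse" when it fails to improve on the best value seen so far, with the counter reset on improvement.

```python
def simulate_early_stopping(scores, burn_in=100, check_interval=10, stop_interval=3):
    """Return the epoch at which training would stop.

    scores[e - 1] is the validation metric (higher is better) after epoch e.
    """
    best = float("-inf")
    bad_checks = 0
    for epoch in range(1, len(scores) + 1):
        # No checks during burn-in, and only every `check_interval` epochs after it.
        if epoch < burn_in or epoch % check_interval != 0:
            continue
        current = scores[epoch - 1]
        if current > best:
            # Improvement: remember it and reset the counter.
            best = current
            bad_checks = 0
        else:
            bad_checks += 1
            if bad_checks == stop_interval:
                return epoch
    # The metric never degraded for long enough: train for all epochs.
    return len(scores)


# Metric peaks at epoch 4, then degrades at the checks on epochs 6 and 8.
scores = [0.10, 0.20, 0.25, 0.30, 0.28, 0.29, 0.27, 0.26, 0.30, 0.30]
stopped_at = simulate_early_stopping(scores, burn_in=2, check_interval=2, stop_interval=2)
```

A larger check_interval reduces the metric-computation overhead mentioned above, at the cost of reacting later to a degrading metric.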
Returns:
- best_model (EmbeddingModel) – The best trained embedding model obtained in model selection.
- best_params (dict) – The hyperparameters of the best embedding model best_model.
- best_mrr_train (float) – The MRR (unfiltered) of the best model computed over the validation set in the model selection loop.
- ranks_test (ndarray, shape [n]) – The ranks of each triple in the test set X['test'].
- mrr_test (float) – The MRR (filtered) of the best model, retrained on the concatenation of training and validation sets, computed over the test set.
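To make the relationship between ranks_test and an MRR value concrete, the sketch below computes the mean reciprocal rank and hits@n from a list of ranks. The function names mrr_from_ranks and hits_at_n and the sample ranks are illustrative only and are not taken from the library.

```python
def mrr_from_ranks(ranks):
    """Mean reciprocal rank: the average of 1/rank over all test triples."""
    return sum(1.0 / r for r in ranks) / len(ranks)


def hits_at_n(ranks, n):
    """Fraction of test triples ranked within the top n."""
    return sum(1 for r in ranks if r <= n) / len(ranks)


ranks = [1, 2, 4]                 # made-up ranks for three test triples
mrr = mrr_from_ranks(ranks)       # (1 + 1/2 + 1/4) / 3
hits3 = hits_at_n(ranks, 3)       # two of the three ranks are <= 3
```

Whether such a value is raw or filtered depends only on how the ranks were produced, that is, on whether known positives were removed from the corruptions before ranking.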
Examples
>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>>
>>> X = load_wn18()
>>> model_class = ComplEx
>>> param_grid = {
...     "batches_count": [50],
...     "seed": 0,
...     "epochs": [4000],
...     "k": [100, 200],
...     "eta": [5, 10, 15],
...     "loss": ["pairwise", "nll"],
...     "loss_params": {
...         "margin": [2]
...     },
...     "embedding_model_params": {},
...     "regularizer": ["LP", None],
...     "regularizer_params": {
...         "p": [1, 3],
...         "lambda": [1e-4, 1e-5]
...     },
...     "optimizer": ["adagrad", "adam"],
...     "optimizer_params": {
...         "lr": [0.01, 0.001, 0.0001]
...     },
...     "verbose": False
... }
>>> select_best_model_ranking(model_class, X, param_grid, use_filter=True, verbose=True, early_stopping=True)
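The routine trains one model per combination of the hyperparameter values in param_grid. The sketch below shows one way such a grid could be expanded. expand_grid is a hypothetical helper, not AmpliGraph's implementation: it assumes that lists are alternatives, scalars are fixed values, and nested dictionaries are expanded recursively. The library may prune or treat some combinations differently.

```python
from itertools import product


def expand_grid(grid):
    """Yield one dict per combination of the values in `grid`."""
    keys = list(grid)
    options = []
    for key in keys:
        value = grid[key]
        if isinstance(value, dict):
            # A nested grid expands into a list of nested dicts.
            options.append(list(expand_grid(value)))
        elif isinstance(value, (list, tuple)):
            # A list holds the alternatives to try.
            options.append(list(value))
        else:
            # A scalar is a single fixed choice.
            options.append([value])
    for combo in product(*options):
        yield dict(zip(keys, combo))


small_grid = {
    "k": [100, 200],
    "eta": [5, 10, 15],
    "seed": 0,
    "loss_params": {"margin": [2]},
}
combinations = list(expand_grid(small_grid))  # 2 * 3 * 1 * 1 combinations
```

Because the number of combinations is the product of the list lengths, adding one more value to any list multiplies the number of models to train.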