ampligraph.evaluation.select_best_model_ranking(model_class, X_train, X_valid, X_test, param_grid, max_combinations=None, param_grid_random_seed=0, use_filter=True, early_stopping=False, early_stopping_params=None, use_test_for_selection=False, entities_subset=None, corrupt_side='s,o', use_default_protocol=False, retrain_best_model=False, verbose=False)

Model selection routine for embedding models via either grid search or random search.

For grid search, pass a fixed param_grid and leave max_combinations as None so that all combinations will be explored.

For random search, set max_combinations to match your computational budget and optionally make some parameters callables instead of lists (see the documentation for param_grid).
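To make the grid-versus-random distinction concrete, here is a minimal sketch of how a flat param_grid could be expanded. The helper sample_configurations is hypothetical: it only mimics the semantics described here (lists are enumerated, callables are sampled, callables require max_combinations). It is not AmpliGraph's implementation, which also handles nested dictionaries such as loss_params and may deduplicate or order draws differently.

```python
import itertools
import random

import numpy as np


def sample_configurations(param_grid, max_combinations=None, seed=0):
    """Hypothetical helper: expand a flat param_grid into configurations.

    Lists are treated as discrete choices; callables are invoked to draw a value.
    """
    keys = list(param_grid)
    has_callable = any(callable(param_grid[k]) for k in keys)
    if has_callable and max_combinations is None:
        raise ValueError("max_combinations must be set when a parameter is callable")

    if max_combinations is None:
        # Grid search: the full Cartesian product of all listed values.
        value_lists = [param_grid[k] for k in keys]
        return [dict(zip(keys, combo)) for combo in itertools.product(*value_lists)]

    # Random search: draw max_combinations configurations.
    rng = random.Random(seed)
    np.random.seed(seed)  # callables like lambda: np.random.uniform(...) use NumPy's global RNG
    configs = []
    for _ in range(max_combinations):
        configs.append({k: param_grid[k]() if callable(param_grid[k]) else rng.choice(param_grid[k])
                        for k in keys})
    return configs


grid = {"k": [100, 200], "eta": [5, 10, 15]}
print(len(sample_configurations(grid)))  # 6 combinations

random_grid = {"k": [100, 200], "lr": lambda: np.random.uniform(0.0001, 0.01)}
for config in sample_configurations(random_grid, max_combinations=3, seed=0):
    print(config)
```

In this sketch, reusing the same seed reproduces the same random configurations, which is presumably the role param_grid_random_seed plays for callable parameters.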


Random search is more efficient than grid search as the number of parameters grows [BB12]. It is also a strong baseline against more advanced methods such as Bayesian optimization [LJ18].
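The intuition behind [BB12] can be shown with a toy count: with a budget of 9 trials over two hyperparameters, a 3 x 3 grid only ever tries 3 distinct values of each one, while 9 random draws try 9. This is an illustrative sketch under the assumption that only the first hyperparameter matters; it is not a benchmark and says nothing about which values perform well.

```python
import itertools

import numpy as np

budget = 9  # total number of trials we can afford
rng = np.random.default_rng(0)

# Grid search: a 3 x 3 grid over two hyperparameters scaled to [0, 1].
axis = np.linspace(0.0, 1.0, 3)
grid_trials = np.array(list(itertools.product(axis, axis)))

# Random search: the same budget spent on independent uniform draws.
random_trials = rng.uniform(0.0, 1.0, size=(budget, 2))

# If only the first hyperparameter really matters, count how many distinct
# values of it each strategy actually tried.
grid_distinct = len(np.unique(grid_trials[:, 0]))
random_distinct = len(np.unique(random_trials[:, 0]))
print(grid_distinct, random_distinct)  # 3 9
```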

If retrain_best_model=True, the function also retrains the best performing model on the concatenation of training and validation sets.

Note that negatives are generated at runtime according to the strategy described in [BUGD+13].
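A minimal NumPy sketch of the corruption strategy of [BUGD+13], where each negative is obtained from a positive triple by replacing either its subject or its object (never both) with an entity drawn uniformly at random. The function sample_negatives and its arguments are hypothetical; eta mirrors the eta hyperparameter used in param_grid, but this is not AmpliGraph's internal code, and the sketch does not exclude draws that reproduce the original entity or another known positive.

```python
import numpy as np


def sample_negatives(X, entities, eta, rng):
    """Hypothetical helper: eta corruptions per positive triple.

    Each corruption replaces either the subject or the object (chosen with
    probability 0.5) with an entity sampled uniformly from `entities`.
    """
    # object dtype avoids truncation when entity strings differ in length
    negatives = np.repeat(np.asarray(X, dtype=object), eta, axis=0)
    n = len(negatives)
    side = np.where(rng.random(n) < 0.5, 0, 2)  # column 0 = subject, column 2 = object
    negatives[np.arange(n), side] = rng.choice(np.asarray(entities, dtype=object), size=n)
    return negatives


rng = np.random.default_rng(0)
X_train = np.array([["alice", "knows", "bob"],
                    ["bob", "likes", "carol"]])
entities = np.unique(X_train[:, [0, 2]])
print(sample_negatives(X_train, entities, eta=2, rng=rng))
```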


For better runtime performance, model selection can be done with raw (unfiltered) MRR by setting use_filter=False.
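Raw and filtered MRR differ only in how a test triple's rank is counted. The sketch below, built on a hypothetical rank_of_positive helper, scores every candidate (higher is better), counts the candidates that score strictly higher than the true triple, and in the filtered setting skips candidates that are other known positives. Resolving ties optimistically is an assumption of this sketch; AmpliGraph's own tie-breaking may differ.

```python
import numpy as np


def rank_of_positive(scores, positive_idx, known_positive_idx=()):
    """Hypothetical helper: 1-based rank of the true candidate.

    scores: one score per candidate (the true triple and its corruptions).
    known_positive_idx: candidates that are other true triples; they are
    ignored when computing the filtered rank (leave empty for the raw rank).
    """
    scores = np.asarray(scores, dtype=float)
    competitors = np.ones(len(scores), dtype=bool)
    competitors[positive_idx] = False
    competitors[np.asarray(known_positive_idx, dtype=int)] = False
    return 1 + int(np.sum(scores[competitors] > scores[positive_idx]))


def mrr(ranks):
    """Mean reciprocal rank."""
    return float(np.mean(1.0 / np.asarray(ranks, dtype=float)))


# Candidate 2 is the test triple; candidate 0 is another known positive.
scores = [0.9, 0.8, 0.7, 0.1]
raw = rank_of_positive(scores, positive_idx=2)
filtered = rank_of_positive(scores, positive_idx=2, known_positive_idx=[0])
print(raw, filtered)  # 3 2
print(mrr([raw, 1]), mrr([filtered, 1]))
```

Filtering can only lower (improve) a rank, so filtered MRR is never below raw MRR; the extra cost comes from looking up known positives for every corruption.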

Parameters:

  • model_class (class) – The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc.).

  • X_train (ndarray, shape [n, 3]) – An array of training triples.

  • X_valid (ndarray, shape [n, 3]) – An array of validation triples.

  • X_test (ndarray, shape [n, 3]) – An array of test triples.

  • param_grid (dict) –

    A grid of hyperparameters to use in model selection. The routine will train a model for each combination of these hyperparameters.

    Parameters can be either callables or lists. If callable, it must take no parameters and return a constant value. If any parameter is a callable, max_combinations must be set to some value.

    For example, the learning rate could either be "lr": [0.1, 0.01] or "lr": lambda: np.random.uniform(0.01, 0.1).

  • max_combinations (int) – Maximum number of combinations to explore. By default (None) all combinations are explored, which is incompatible with callable (random) parameters used for random search.

  • param_grid_random_seed (int) – Random seed used when sampling the parameters that are callables (random search).

  • use_filter (bool) – If True, uses the concatenation of X_train, X_valid and X_test as the filter to compute filtered MRR (default: True).

  • early_stopping (bool) –

    Flag to enable early stopping (default: False).

    If set to True, the training loop adopts the following early stopping heuristic:

    • The model will be trained regardless of early stopping for burn_in epochs.

    • Every check_interval epochs the method will compute the metric specified in criteria.

    • If such metric decreases for stop_interval checks, training is stopped early.

    Note that the metric is computed on x_valid, which is usually a held-out validation set.

    Also, because criteria is a ranking metric, it requires generating negatives. The entities used to generate corruptions can be specified, as well as the side(s) of a triple to corrupt. The method supports filtered metrics by passing an array of positives to x_filter, which is used to filter the negatives generated on the fly (i.e. the corruptions).


    Keep in mind that the early stopping criteria may introduce some overhead (caused by the metric computation). The goal is to strike a good trade-off between this overhead and the training epochs saved.

    A common approach is to use unfiltered MRR:

    early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}

    Note that the size of the validation set also contributes to this overhead. In most cases a smaller validation set is enough.

  • early_stopping_params (dict) –

    Dictionary of parameters for early stopping.

    The following keys are supported:

    • x_valid: ndarray, shape [n, 3]: Validation set to be used for early stopping. Uses X['valid'] by default.

    • criteria: criteria for early stopping: hits10, hits3, hits1 or mrr (default).

    • x_filter: ndarray, shape [n, 3]: Filter to be used (no filter by default).

    • burn_in: Number of epochs to pass before kicking in early stopping (default: 100).

    • check_interval: Early stopping interval after burn-in (default: 10).

    • stop_interval: Stop if criteria is performing worse over n consecutive checks (default: 3).

  • use_test_for_selection (bool) – Use test set for model selection. If False, uses validation set (default: False).

  • entities_subset (array-like) – List of entities to use for corruptions. If None, will generate corruptions using all distinct entities (default: None).

  • corrupt_side (string) – Specifies which side(s) of the triple to corrupt: s corrupts only the subject, o corrupts only the object, s+o corrupts both subject and object, and s,o corrupts both subject and object but computes the ranks separately (default).

  • use_default_protocol (bool) – Flag to indicate whether to evaluate head and tail corruptions separately (default: False). If set to True, the corrupt_side argument is ignored, and head and tail are corrupted and ranked separately, i.e. corrupt_side='s,o' mode.

  • retrain_best_model (bool) – Flag to indicate whether the best model should be retrained at the end, using the validation set from the search as additional training data (default: False).

  • verbose (bool) –

    Verbose mode for the model selection procedure (which is independent of the verbose mode in the model fit).

    Verbose mode includes display of the progress bar, logging info for each iteration, evaluation information, and exception details.

    If you need verbosity inside the model training itself, change the verbose parameter within the param_grid.
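The early stopping heuristic described under early_stopping can be sketched as a plain loop. train_with_early_stopping, train_one_epoch and evaluate are hypothetical stand-ins (evaluate would compute the criteria metric, e.g. MRR, on x_valid, higher being better). Two details are assumptions rather than documented behavior: checks happen at epochs that are multiples of check_interval once burn_in is reached, and a check counts as worse when it does not beat the best value seen so far.

```python
def train_with_early_stopping(train_one_epoch, evaluate, epochs,
                              burn_in=100, check_interval=10, stop_interval=3):
    """Hypothetical sketch of the burn_in / check_interval / stop_interval heuristic.

    Returns the number of epochs actually run.
    """
    best = float("-inf")
    worse_checks = 0
    for epoch in range(1, epochs + 1):
        train_one_epoch()
        # No checks during burn-in, then one check every check_interval epochs.
        if epoch < burn_in or epoch % check_interval != 0:
            continue
        metric = evaluate()
        if metric > best:
            best, worse_checks = metric, 0
        else:
            worse_checks += 1
            if worse_checks >= stop_interval:
                return epoch  # stop early
    return epochs


# Simulated validation MRR at each check: improves twice, then stalls.
metrics = iter([0.10, 0.30, 0.20, 0.25, 0.28, 0.40])
stopped_at = train_with_early_stopping(lambda: None, lambda: next(metrics),
                                       epochs=100, burn_in=10,
                                       check_interval=10, stop_interval=3)
print(stopped_at)  # 50
```

Each evaluate() call is where the overhead mentioned above is paid, so a smaller x_valid or a larger check_interval makes every check cheaper or rarer.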


Returns:

  • best_model (EmbeddingModel) – The best trained embedding model obtained in model selection.

  • best_params (dict) – The hyperparameters of the best embedding model best_model.

  • best_mrr_train (float) – The MRR of the best model computed in the model selection loop over the validation set (or over the test set if use_test_for_selection=True). It is filtered if use_filter=True and raw otherwise.

  • ranks_test (ndarray, shape [n] or [n, 2] depending on the value of corrupt_side) – An array of ranks of test triples. When corrupt_side='s,o', the function returns an array of shape [n, 2]: the first column holds the ranks against subject corruptions and the second column the ranks against object corruptions. In other cases, it returns an array of shape [n], i.e. the ranks against the specified corruptions.

  • mrr_test (float) – The MRR (filtered) of the best model computed over the test set. If retrain_best_model=True, the best model is first retrained on the concatenation of training and validation sets.

  • experimental_history (list of dict) – A list containing all the intermediate experimental results: the model parameters and the corresponding validation metrics.
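Given ranks_test, the headline metrics are simple reductions. The sketch below pools both columns when corrupt_side='s,o' (so subject and object corruptions count as separate observations) and also accepts a 1-D array. summarize_ranks is a hypothetical helper and the ranks are made-up illustrative values; ampligraph.evaluation also offers mrr_score and hits_at_n_score, and this plain NumPy version only spells out the arithmetic.

```python
import numpy as np


def summarize_ranks(ranks, hits_at=(1, 3, 10)):
    """Hypothetical helper: MRR and Hits@N from an array of ranks.

    Accepts shape [n] or [n, 2]; a 2-D array is flattened so subject-side
    and object-side ranks are pooled.
    """
    ranks = np.asarray(ranks, dtype=float).reshape(-1)
    summary = {"mrr": float(np.mean(1.0 / ranks))}
    for n in hits_at:
        summary[f"hits@{n}"] = float(np.mean(ranks <= n))
    return summary


# Made-up ranks for three test triples (subject column, object column).
ranks_test = np.array([[1, 4], [2, 12], [1, 1]])
print(summarize_ranks(ranks_test))
print(summarize_ranks(ranks_test[:, 1]))  # object-side ranks only
```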


>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>> import numpy as np
>>> X = load_wn18()
>>> model_class = ComplEx
>>> param_grid = {
...     "batches_count": [50],
...     "seed": 0,
...     "epochs": [100],
...     "k": [100, 200],
...     "eta": [5, 10, 15],
...     "loss": ["pairwise", "nll"],
...     "loss_params": {
...         "margin": [2]
...     },
...     "embedding_model_params": {
...     },
...     "regularizer": ["LP", None],
...     "regularizer_params": {
...         "p": [1, 3],
...         "lambda": [1e-4, 1e-5]
...     },
...     "optimizer": ["adagrad", "adam"],
...     "optimizer_params": {
...         "lr": lambda: np.random.uniform(0.0001, 0.01)
...     },
...     "verbose": False
... }
>>> # "lr" is a callable, so max_combinations must be set (random search).
>>> results = select_best_model_ranking(model_class, X['train'], X['valid'], X['test'],
...                                     param_grid,
...                                     max_combinations=100,
...                                     use_filter=True,
...                                     verbose=True,
...                                     early_stopping=True)
>>> best_model, best_params, best_mrr_train, ranks_test, mrr_test, experimental_history = results
>>> print(mrr_test)