
ampligraph.evaluation.select_best_model_ranking(model_class, X, param_grid, use_filter=False, early_stopping=False, early_stopping_params={}, use_test_for_selection=True, rank_against_ent=None, corrupt_side='s+o', use_default_protocol=False, verbose=False)

Model selection routine for embedding models.


By default, model selection is done with raw MRR for better runtime performance (use_filter=False).

The function also retrains the best performing model on the concatenation of training and validation sets.

Note that negatives are generated at runtime according to the strategy described in [BUGD+13].
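
To illustrate what generating negatives at runtime means, here is a minimal NumPy sketch that corrupts the subject or object of each positive triple with a randomly drawn entity. The helper name generate_corruptions and its arguments are hypothetical and are not the library's implementation; the sketch only assumes triples stored as an ndarray of shape [n, 3].

```python
import numpy as np

def generate_corruptions(X, entities, eta=1, corrupt_side="s+o", seed=0):
    """Hypothetical helper: return eta corruptions for each positive triple in X."""
    rng = np.random.default_rng(seed)
    # Repeat each positive eta times; np.repeat returns a copy we can modify.
    negatives = np.repeat(X, eta, axis=0)
    n = negatives.shape[0]
    replacements = rng.choice(entities, size=n)
    if corrupt_side == "s":
        col = np.zeros(n, dtype=int)       # always replace the subject
    elif corrupt_side == "o":
        col = np.full(n, 2)                # always replace the object
    else:
        col = rng.choice([0, 2], size=n)   # "s+o": pick a side per row
    negatives[np.arange(n), col] = replacements
    # No filtering is done here: a corruption may coincide with a true triple.
    return negatives

X = np.array([["a", "likes", "b"], ["b", "likes", "c"]])
entities = np.unique(np.concatenate([X[:, 0], X[:, 2]]))
negatives = generate_corruptions(X, entities, eta=3)
```

Each row of negatives differs from its source triple in at most one position (subject or object), and the predicate is never changed.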

  • model_class (class) – The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc).
  • X (dict) – A dictionary of triples to use in model selection. Must include three keys: train, valid, test. Values are ndarray of shape [n, 3].
  • param_grid (dict) – A grid of hyperparameters to use in model selection. The routine will train a model for each combination of these hyperparameters.
  • use_filter (bool) – If True, will use the entire input dataset X to compute filtered MRR (default: False).
  • early_stopping (bool) –

    Flag to enable early stopping (default:False).

    If set to True, the training loop adopts the following early stopping heuristic:

    • The model will be trained regardless of early stopping for burn_in epochs.
    • Every check_interval epochs the method will compute the metric specified in criteria.

    If the metric decreases for stop_interval consecutive checks, training stops early.

    Note the metric is computed on x_valid. This is usually a validation set that you held out.

    Also, because criteria is a ranking metric, it requires generating negatives. The entities used to generate corruptions can be specified, as well as the side(s) of a triple to corrupt. The method supports filtered metrics: pass an array of positives to x_filter, and it will be used to filter the negatives generated on the fly (i.e. the corruptions).


    Keep in mind the early stopping criteria may introduce a certain overhead (caused by the metric computation). The goal is to strike a good trade-off between such overhead and saving training epochs.

    A common approach is to use MRR unfiltered:

    early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}

    Note the size of validation set also contributes to such overhead. In most cases a smaller validation set would be enough.

  • early_stopping_params (dict) –

    Dictionary of parameters for early stopping.

    The following keys are supported:

    x_valid: ndarray, shape [n, 3] : Validation set to be used for early stopping. Uses X['valid'] by default.

    criteria: Criterion for early stopping: hits10, hits3, hits1 or mrr (default).

    x_filter: ndarray, shape [n, 3] : Filter to be used (no filter by default).

    burn_in: Number of epochs to pass before early stopping kicks in (default: 100).

    check_interval: Early stopping interval after burn-in (default: 10).

    stop_interval: Stop if criteria is performing worse over n consecutive checks (default: 3).

  • use_test_for_selection (bool) – Use the test set for model selection. If False, uses the validation set (default: True).
  • rank_against_ent (array-like) – List of entities to use for corruptions. If None, will generate corruptions using all distinct entities. Default is None.
  • corrupt_side (string) – Specifies which side of a triple to corrupt: s corrupts only the subject, o corrupts only the object, s+o corrupts both subject and object.
  • use_default_protocol (bool) – Flag to indicate whether to evaluate head and tail corruptions separately (default: False). If set to True, the routine ignores the corrupt_side argument, corrupts head and tail separately, and ranks the triples.
  • verbose (bool) – Verbose mode during evaluation of trained model
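
The early stopping heuristic described under early_stopping can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's training loop: the function early_stopping_epoch and the callable metric_at are hypothetical, checks are assumed to happen on epochs that are multiples of check_interval once burn_in is reached, and "worse" is assumed to mean "not better than the best value seen so far".

```python
def early_stopping_epoch(metric_at, epochs, burn_in=100,
                         check_interval=10, stop_interval=3):
    """Return the epoch at which training would stop (hypothetical helper).

    metric_at: callable mapping an epoch to the validation metric
               (higher is better, e.g. MRR on x_valid).
    """
    best = float("-inf")
    worse_checks = 0
    for epoch in range(1, epochs + 1):
        # ... one training epoch would run here ...
        if epoch < burn_in or epoch % check_interval != 0:
            continue  # no metric computation on this epoch
        current = metric_at(epoch)
        if current > best:
            best = current
            worse_checks = 0
        else:
            worse_checks += 1
            if worse_checks >= stop_interval:
                return epoch  # metric was worse for stop_interval checks
    return epochs

# Toy metric that peaks at epoch 150 and degrades afterwards.
stopped_at = early_stopping_epoch(lambda e: -abs(e - 150), epochs=4000)
print(stopped_at)  # 180: checks at 160, 170 and 180 are all worse than 150
```

A larger check_interval reduces the overhead of computing the metric, at the cost of detecting a degradation later.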

  • best_model (EmbeddingModel) – The best trained embedding model obtained in model selection.
  • best_params (dict) – The hyperparameters of the best embedding model best_model.
  • best_mrr_train (float) – The MRR (unfiltered) of the best model computed over the validation set in the model selection loop.
  • ranks_test (ndarray, shape [n]) – The ranks of each triple in the test set X['test'].
  • mrr_test (float) – The MRR (filtered) of the best model, retrained on the concatenation of training and validation sets, computed over the test set.
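
For orientation, the relation between an array of ranks such as ranks_test and a mean reciprocal rank such as mrr_test can be sketched with plain NumPy. The ranks below are made up for illustration, and the sketch does not call the library's own metric functions.

```python
import numpy as np

# Made-up ranks of four test triples (1 is the best possible rank).
ranks = np.array([1, 2, 4, 10])

# Mean reciprocal rank: the average of 1/rank over all triples.
mrr = np.mean(1.0 / ranks)

# Hits@3: the fraction of triples ranked in the top 3.
hits_at_3 = np.mean(ranks <= 3)

print(mrr, hits_at_3)
```

With these ranks the MRR is (1 + 0.5 + 0.25 + 0.1) / 4 = 0.4625 and Hits@3 is 0.5.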


>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>> X = load_wn18()
>>> model_class = ComplEx
>>> # Each list holds the candidate values for one hyperparameter.
>>> param_grid = {
...     "batches_count": [50],
...     "seed": 0,
...     "epochs": [4000],
...     "k": [100, 200],
...     "eta": [5, 10, 15],
...     "loss": ["pairwise", "nll"],
...     "loss_params": {
...         "margin": [2]
...     },
...     "embedding_model_params": {
...     },
...     "regularizer": ["LP", None],
...     "regularizer_params": {
...         "p": [1, 3],
...         "lambda": [1e-4, 1e-5]
...     },
...     "optimizer": ["adagrad", "adam"],
...     "optimizer_params": {
...         "lr": [0.01, 0.001, 0.0001]
...     },
...     "verbose": False
... }
>>> select_best_model_ranking(model_class, X, param_grid, use_filter=True, verbose=True, early_stopping=True)
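
Since the routine trains one model for each combination of hyperparameters, the size of the grid drives the total runtime. The following sketch shows how a flat grid expands into combinations using itertools.product. The helper expand_grid is hypothetical, handles only flat grids (no nested dictionaries such as loss_params), and is not the expansion code used by the library.

```python
from itertools import product

def expand_grid(grid):
    """Hypothetical helper: yield one dict per combination of a flat grid."""
    keys = list(grid)
    # Wrap scalars in a list so fixed values (e.g. a seed) are allowed.
    value_lists = [grid[k] if isinstance(grid[k], list) else [grid[k]]
                   for k in keys]
    for combo in product(*value_lists):
        yield dict(zip(keys, combo))

flat_grid = {"k": [100, 200], "eta": [5, 10, 15], "seed": 0}
combinations = list(expand_grid(flat_grid))
print(len(combinations))  # 6 = 2 values of k x 3 values of eta
```

Every extra candidate value multiplies the number of training runs, so it is worth keeping each list short.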