select_best_model_ranking

ampligraph.evaluation.select_best_model_ranking(model_class, X, param_grid, use_filter=False, early_stopping=False, early_stopping_params={}, use_test_for_selection=True, rank_against_ent=None, corrupt_side='s+o', use_default_protocol=False, verbose=False)

Model selection routine for embedding models.

Note

Model selection is done with raw MRR for better runtime performance.

The function also retrains the best performing model on the concatenation of training and validation sets.

(Note that we generate negatives at runtime according to the strategy described in [BUGD+13].)

Parameters:
  • model_class (class) – The class of the EmbeddingModel to evaluate (TransE, DistMult, ComplEx, etc).
  • X (dict) – A dictionary of triples to use in model selection. Must include three keys: train, valid, test. Values are ndarray of shape [n, 3].
  • param_grid (dict) – A grid of hyperparameters to use in model selection. The routine will train a model for each combination of these hyperparameters.
  • use_filter (bool) – If True, will use the entire input dataset X to compute filtered MRR.
  • early_stopping (bool) – Flag to enable early stopping (default: False).
  • early_stopping_params (dict) –

    Dictionary of parameters for early stopping (see the sketch after this parameter list).

    The following keys are supported:

    x_valid: ndarray, shape [n, 3] : Validation set to be used for early stopping. Uses X['valid'] by default.

    criteria: Criteria for early stopping: 'hits10', 'hits3', 'hits1' or 'mrr' (default).

    x_filter: ndarray, shape [n, 3] : Filter to be used (no filter by default).

    burn_in: Number of epochs to pass before kicking in early stopping (default: 100).

    check_interval: Early stopping check interval after burn-in (default: 10).

    stop_interval: Stop if the criteria is performing worse over n consecutive checks (default: 3).

  • use_test_for_selection (bool) – Use the test set for model selection. If False, uses the validation set (default: True).
  • rank_against_ent (array-like) – List of entities to use for corruptions. If None, will generate corruptions using all distinct entities. Default is None.
  • corrupt_side (string) – Specifies which side of the triple to corrupt: 's' corrupts only the subject, 'o' corrupts only the object, 's+o' corrupts both subject and object.
  • use_default_protocol (bool) – Flag to indicate whether to evaluate head and tail corruptions separately (default: False). If set to True, the corrupt_side argument is ignored, and head and tail corruptions are generated and ranked separately.
  • verbose (bool) – Verbose mode during evaluation of the trained model.
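
As a sketch of the early_stopping_params keys listed above (the values are illustrative placeholders, not recommended settings):

>>> early_stopping_params = {
>>>     'x_valid': X['valid'],   # validation triples used to monitor the stopping criteria
>>>     'criteria': 'mrr',       # one of 'hits10', 'hits3', 'hits1', 'mrr'
>>>     'burn_in': 100,          # epochs to run before early stopping kicks in
>>>     'check_interval': 10,    # check the criteria every 10 epochs after burn-in
>>>     'stop_interval': 3       # stop after 3 consecutive checks with no improvement
>>> }
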
Returns:

  • best_model (EmbeddingModel) – The best trained embedding model obtained in model selection.
  • best_params (dict) – The hyperparameters of the best embedding model best_model.
  • best_mrr_train (float) – The MRR (unfiltered) of the best model computed over the validation set in the model selection loop.
  • ranks_test (ndarray, shape [n]) – The ranks of each triple in the test set X['test'].
  • mrr_test (float) – The MRR (filtered) of the best model, retrained on the concatenation of training and validation sets, computed over the test set.

Examples

>>> from ampligraph.datasets import load_wn18
>>> from ampligraph.latent_features import ComplEx
>>> from ampligraph.evaluation import select_best_model_ranking
>>>
>>> X = load_wn18()
>>> model_class = ComplEx
>>> param_grid = {
>>>                     "batches_count": [50],
>>>                     "seed": 0,
>>>                     "epochs": [4000],
>>>                     "k": [100, 200],
>>>                     "eta": [5,10,15],
>>>                     "loss": ["pairwise", "nll"],
>>>                     "loss_params": {
>>>                         "margin": [2]
>>>                     },
>>>                     "embedding_model_params": {
>>>
>>>                     },
>>>                     "regularizer": ["LP", None],
>>>                     "regularizer_params": {
>>>                         "p": [1, 3],
>>>                         "lambda": [1e-4, 1e-5]
>>>                     },
>>>                     "optimizer": ["adagrad", "adam"],
>>>                     "optimizer_params":{
>>>                         "lr": [0.01, 0.001, 0.0001]
>>>                     },
>>>                     "verbose": false
>>>                 }
>>> select_best_model_ranking(model_class, X, param_grid, use_filter=True, verbose=True, early_stopping=True)
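
The routine returns the five values described under Returns; a minimal sketch of capturing them from the call above (assuming they come back as a tuple in that order):

>>> results = select_best_model_ranking(model_class, X, param_grid,
>>>                                      use_filter=True, verbose=True,
>>>                                      early_stopping=True)
>>> best_model, best_params, best_mrr_train, ranks_test, mrr_test = results
>>> print(best_params)  # hyperparameter combination of the best model
>>> print(mrr_test)     # filtered MRR of the retrained best model on X['test']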