ComplEx

class ampligraph.latent_features.ComplEx(k=100, eta=2, epochs=100, batches_count=100, seed=0, embedding_model_params={}, optimizer='adagrad', optimizer_params={'lr': 0.1}, loss='nll', loss_params={}, regularizer=None, regularizer_params={}, model_checkpoint_path='saved_model/', verbose=False, **kwargs)

Complex embeddings (ComplEx)
The ComplEx model [TWR+16] is an extension of the ampligraph.latent_features.DistMult bilinear diagonal model. The ComplEx scoring function is based on the trilinear Hermitian dot product in \(\mathcal{C}\):

\[f_{ComplEx}=Re(\langle \mathbf{r}_p, \mathbf{e}_s, \overline{\mathbf{e}_o} \rangle)\]

Note that because embeddings are in \(\mathcal{C}\), ComplEx uses twice as many parameters as DistMult, its counterpart in \(\mathcal{R}\).
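The scoring function above can be illustrated with a minimal pure-Python sketch. It is not AmpliGraph's implementation: the function name `complex_score` and the toy embeddings are hypothetical, and Python's built-in `complex` type stands in for the learned embeddings.

```python
from typing import Sequence


def complex_score(r_p: Sequence[complex],
                  e_s: Sequence[complex],
                  e_o: Sequence[complex]) -> float:
    """Real part of the trilinear Hermitian product <r_p, e_s, conj(e_o)>."""
    if not (len(r_p) == len(e_s) == len(e_o)):
        raise ValueError("all embeddings must have the same dimensionality k")
    # Element-wise product, with the object embedding conjugated, summed over k.
    total = sum(r * s * o.conjugate() for r, s, o in zip(r_p, e_s, e_o))
    return complex(total).real


# Toy embeddings with k = 2 (hand-picked values, not learned ones).
r_p = [1 + 1j, 2 + 0j]
e_s = [1 + 0j, 0 + 1j]
e_o = [0 + 1j, 1 + 1j]

score = complex_score(r_p, e_s, e_o)     # 3.0
swapped = complex_score(r_p, e_o, e_s)   # 1.0
```

Swapping subject and object changes the score here because the object embedding is conjugated and the relation embedding has a non-zero imaginary part.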
Examples
>>> import numpy as np
>>> from ampligraph.latent_features import ComplEx
>>>
>>> model = ComplEx(batches_count=1, seed=555, epochs=20, k=10,
...                 loss='pairwise', loss_params={'margin':1},
...                 regularizer='LP', regularizer_params={'lambda':0.1})
>>> X = np.array([['a', 'y', 'b'],
...               ['b', 'y', 'a'],
...               ['a', 'y', 'c'],
...               ['c', 'y', 'a'],
...               ['a', 'y', 'd'],
...               ['c', 'y', 'd'],
...               ['b', 'y', 'c'],
...               ['f', 'y', 'e']])
>>> model.fit(X)
>>> model.predict(np.array([['f', 'y', 'e'], ['b', 'y', 'd']]))
[0.96325016, -0.17629346]
>>> model.get_embeddings(['f','e'], type='entity')
array([[-0.11257   , -0.09226837,  0.2829331 , -0.02094189,  0.02826234,
        -0.3068198 , -0.41022655, -0.23714773, -0.00084166,  0.22521858,
        -0.48155236,  0.29627186,  0.29841757,  0.16540456,  0.45836073,
         0.14025007, -0.03458257, -0.03813137,  0.35438442, -0.4733188 ],
       [ 0.06088537,  0.13615245, -0.20476362,  0.20391239,  0.22199424,
         0.5762486 , -0.01087974,  0.39070424, -0.1372974 ,  0.39998057,
        -0.5944237 ,  0.506474  ,  0.1255992 , -0.06021457, -0.26678884,
        -0.18713273,  0.36862013,  0.07165384, -0.00845572, -0.16494963]],
      dtype=float32)
Methods
__init__([k, eta, epochs, batches_count, …])      Initialize an EmbeddingModel.
fit(X[, early_stopping, early_stopping_params])   Train a ComplEx model.
get_embeddings(entities[, type])                  Get the embeddings of entities or relations.
predict(X[, from_idx, get_ranks])                 Predict the score of triples using a trained embedding model.
__init__(k=100, eta=2, epochs=100, batches_count=100, seed=0, embedding_model_params={}, optimizer='adagrad', optimizer_params={'lr': 0.1}, loss='nll', loss_params={}, regularizer=None, regularizer_params={}, model_checkpoint_path='saved_model/', verbose=False, **kwargs)

Initialize an EmbeddingModel.

Also creates a new TensorFlow session for training.

Parameters:
- k (int) – Embedding space dimensionality.
- eta (int) – The number of negatives that must be generated at runtime during training for each positive.
- epochs (int) – The number of iterations of the training loop.
- batches_count (int) – The number of batches into which the training set is split during the training loop.
- seed (int) – The seed used by the internal random number generator.
- embedding_model_params (dict) – ComplEx-specific hyperparams: Currently ComplEx does not require any hyperparameters.
- optimizer (string) – The optimizer used to minimize the loss function. Choose between sgd, adagrad, adam, momentum.
- optimizer_params (dict) – Parameter values specific to the optimizer. Currently supported:
  - lr - learning rate (used by all the optimizers)
  - momentum - learning momentum (used by the momentum optimizer)
- loss (string) – The type of loss function to use during training.
  - pairwise: the model will use the pairwise margin-based loss function.
  - nll: the model will use the negative log-likelihood loss.
  - absolute_margin: the model will use the absolute margin loss.
  - self_adversarial: the model will use the adversarial sampling loss function.
- loss_params (dict) – Parameters dictionary specific to the loss. (Refer to the documentation of the specific loss functions for more details.)
- regularizer (string) – The regularization strategy to use with the loss function.
  - LP: the model will use L1, L2 or L3 regularization, based on the value passed to the parameter p.
  - None: the model will not use any regularizer.
- regularizer_params (dict) – Parameters dictionary specific to the regularizer. (Refer to the documentation of the regularizer for more details.)
- model_checkpoint_path (string) – Path to save the model.
- verbose (bool) – Verbose mode
- kwargs (dict) – Additional inputs, if any
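To make the pairwise option of the loss parameter more concrete, the following is a minimal sketch of a margin-based hinge loss in its common textbook form. It is an assumption that AmpliGraph's pairwise loss matches this form exactly. The function name `pairwise_margin_loss` and the scores are hypothetical, and the sketch pairs each positive score with a single negative score (that is, eta=1).

```python
from typing import Sequence


def pairwise_margin_loss(pos_scores: Sequence[float],
                         neg_scores: Sequence[float],
                         margin: float = 1.0) -> float:
    """Sum of hinge terms max(0, margin + f(negative) - f(positive))."""
    if len(pos_scores) != len(neg_scores):
        raise ValueError("expected one negative score per positive score")
    return sum(max(0.0, margin + neg - pos)
               for pos, neg in zip(pos_scores, neg_scores))


# Hypothetical scores: the first pair is already separated by more than the
# margin and contributes nothing; the second pair contributes 0.75.
loss = pairwise_margin_loss([2.0, 0.5], [0.5, 0.25], margin=1.0)  # 0.75
```

The margin argument plays the role of the 'margin' key passed through loss_params in the example at the top of this page.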
fit(X, early_stopping=False, early_stopping_params={})

Train a ComplEx model.

The model is trained on a training set X using the training protocol described in [TWR+16].

Parameters:
- X (ndarray, shape [n, 3]) – The training triples.
- early_stopping (bool) – Flag to enable early stopping (default: False).
- early_stopping_params (dictionary) – Dictionary of parameters for early stopping. The following keys are supported:
  - x_valid: ndarray, shape [n, 3] : Validation set to be used for early stopping.
  - criteria: string : Criterion for early stopping: hits10, hits3, hits1 or mrr (default).
  - x_filter: ndarray, shape [n, 3] : Filter to be used (no filter by default).
  - burn_in: int : Number of epochs to pass before early stopping kicks in (default: 100).
  - check_interval: int : Early stopping check interval after burn-in (default: 10).
  - stop_interval: int : Stop if the criterion performs worse over n consecutive checks (default: 3).
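The interplay of burn_in, check_interval and stop_interval can be sketched as follows. This is an illustration under stated assumptions, not AmpliGraph's code: the function `early_stopping_epoch` is hypothetical, the first check is assumed to happen at epoch burn_in, higher metric values are assumed to be better, and a check counts as "worse" when it fails to improve on the best value seen so far.

```python
from typing import Optional, Sequence


def early_stopping_epoch(metric_per_epoch: Sequence[float],
                         burn_in: int = 100,
                         check_interval: int = 10,
                         stop_interval: int = 3) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("-inf")
    worse_checks = 0
    for epoch, metric in enumerate(metric_per_epoch, start=1):
        # No checks during burn-in; afterwards only every check_interval epochs.
        if epoch < burn_in or (epoch - burn_in) % check_interval != 0:
            continue
        if metric > best:
            best = metric
            worse_checks = 0  # an improvement resets the counter
        else:
            worse_checks += 1
            if worse_checks == stop_interval:
                return epoch
    return None


# Hypothetical validation MRR: improves until epoch 4, then degrades.
mrr = [0.10, 0.20, 0.30, 0.40, 0.39, 0.38, 0.37, 0.36]
stopped_at = early_stopping_epoch(mrr, burn_in=2, check_interval=1,
                                  stop_interval=3)  # 7
```

With these toy values, the checks at epochs 5, 6 and 7 all fail to beat the best MRR of 0.40, so the third consecutive worse check triggers the stop at epoch 7.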
get_embeddings(entities, type='entity')

Get the embeddings of entities or relations.

Parameters:
- entities (array-like, dtype=int, shape=[n]) – The entities (or relations) of interest. Elements of the vector must be the original string literals, and not internal IDs.
- type (string) – If 'entity', the input is treated as KG entities. If 'relation', it is treated as KG predicates.
Returns: embeddings – An array of k-dimensional embeddings.
Return type: ndarray, shape [n, k]
predict(X, from_idx=False, get_ranks=False)

Predict the score of triples using a trained embedding model.

The function returns raw scores generated by the model. To obtain probability estimates, use a logistic sigmoid.

Parameters:
- X (ndarray, shape [n, 3]) – The triples to score.
- from_idx (bool) – If True, will skip conversion to internal IDs. (default: False).
- get_ranks (bool) – Flag to compute ranks by scoring against corruptions (default: False).
Returns:
- scores_predict (ndarray, shape [n]) – The predicted scores for input triples X.
- rank (ndarray, shape [n]) – Rank of the triple
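Since predict returns raw scores, the logistic sigmoid mentioned above can be applied as a post-processing step. The sketch below is plain Python and independent of AmpliGraph; the helper names `sigmoid` and `scores_to_probabilities` are hypothetical, and the input values are the raw scores shown in the Examples section.

```python
import math
from typing import Iterable, List


def sigmoid(score: float) -> float:
    """Numerically stable logistic sigmoid 1 / (1 + exp(-score))."""
    if score >= 0:
        return 1.0 / (1.0 + math.exp(-score))
    # For negative inputs, use the equivalent form exp(x) / (1 + exp(x))
    # so that exp() is never called on a large positive number.
    z = math.exp(score)
    return z / (1.0 + z)


def scores_to_probabilities(scores: Iterable[float]) -> List[float]:
    """Map raw triple scores to values in the interval [0, 1]."""
    return [sigmoid(float(s)) for s in scores]


# Raw scores copied from the predict() example on this page.
probabilities = scores_to_probabilities([0.96325016, -0.17629346])
```

A positive raw score maps to a value above 0.5 and a negative one to a value below 0.5, so the first triple in the example comes out as the more plausible of the two.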