Regularizer

class ampligraph.latent_features.Regularizer(hyperparam_dict, verbose=False)

Abstract base class for regularizers.

Methods

__init__(hyperparam_dict[, verbose]) – Initialize the regularizer.
get_state(param_name) – Get the state value.
_init_hyperparams(hyperparam_dict) – Initialize the hyperparameters needed by the algorithm.
apply(trainable_params) – Public entry point: check the inputs and return the regularization loss.
_apply(trainable_params) – Apply the regularization function.
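The methods split into a public surface (apply, get_state) and hooks that a concrete regularizer fills in (_init_hyperparams, _apply). Below is a minimal plain-Python sketch of that split; it is not AmpliGraph's code. It assumes Python's abc module enforces the "must implement" rule (the library may instead raise NotImplementedError), uses nested lists of floats in place of TensorFlow tensors, and the names RegularizerSketch, L1Sketch, IncompleteSketch and the "lambda" key are hypothetical.

```python
from abc import ABC, abstractmethod


class RegularizerSketch(ABC):
    """Hypothetical stand-in for the abstract Regularizer interface."""

    def __init__(self, hyperparam_dict, verbose=False):
        self._regularizer_parameters = {}
        # Hook: each subclass reads the keys it needs.
        self._init_hyperparams(hyperparam_dict)
        if verbose:
            print(type(self).__name__, self._regularizer_parameters)

    def get_state(self, param_name):
        return self._regularizer_parameters[param_name]

    def apply(self, trainable_params):
        # Public entry point; the actual computation lives in _apply.
        return self._apply(trainable_params)

    @abstractmethod
    def _init_hyperparams(self, hyperparam_dict):
        """Store the hyperparameters this regularizer needs."""

    @abstractmethod
    def _apply(self, trainable_params):
        """Return the regularization loss."""


class L1Sketch(RegularizerSketch):
    """Hypothetical concrete regularizer: lambda * sum(|w|)."""

    def _init_hyperparams(self, hyperparam_dict):
        self._regularizer_parameters["lambda"] = hyperparam_dict["lambda"]

    def _apply(self, trainable_params):
        lam = self._regularizer_parameters["lambda"]
        return lam * sum(abs(w) for param in trainable_params for w in param)


class IncompleteSketch(RegularizerSketch):
    """Forgets to implement _apply, so it cannot be instantiated."""

    def _init_hyperparams(self, hyperparam_dict):
        pass


reg = L1Sketch({"lambda": 0.5})
print(reg.apply([[1.0, -2.0], [3.0]]))  # 3.0
print(reg.get_state("lambda"))          # 0.5
```

Calling IncompleteSketch({}) raises TypeError, because _apply is still abstract; this is how the sketch turns "every inherited class must implement this function" into a check at construction time.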
__init__(hyperparam_dict, verbose=False)

Initialize the regularizer.

Parameters: hyperparam_dict (dict) – dictionary of hyperparameters (keys are described in the hyperparameters section)
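hyperparam_dict is a plain dict whose keys depend on the concrete regularizer. The sketch below assumes a constructor that forwards the dict to _init_hyperparams and turns a missing key into a readable error. The class name LambdaOnlySketch and the "lambda" key are hypothetical, and the exception type the library actually raises is not specified on this page.

```python
class LambdaOnlySketch:
    """Hypothetical regularizer that requires a single 'lambda' key."""

    def __init__(self, hyperparam_dict, verbose=False):
        self._regularizer_parameters = {}
        try:
            self._init_hyperparams(hyperparam_dict)
        except KeyError as missing:
            # Report which required hyperparameter was not supplied.
            raise ValueError(f"missing regularizer hyperparameter: {missing}") from missing
        if verbose:
            print("Regularizer hyperparameters:", self._regularizer_parameters)

    def _init_hyperparams(self, hyperparam_dict):
        self._regularizer_parameters["lambda"] = hyperparam_dict["lambda"]


LambdaOnlySketch({"lambda": 1e-5}, verbose=True)
# prints: Regularizer hyperparameters: {'lambda': 1e-05}

try:
    LambdaOnlySketch({})
except ValueError as err:
    print(err)  # missing regularizer hyperparameter: 'lambda'
```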
get_state(param_name)

Get the state value.

Parameters: param_name (string) – name of the state whose value to query
Returns: param_value – the value of the corresponding state
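A sketch of a dict-backed get_state, assuming state values live in a per-instance dictionary. The class StatefulSketch and the attribute _state are hypothetical; the library may store its state elsewhere. An unknown name raises KeyError listing the states that do exist.

```python
class StatefulSketch:
    """Hypothetical holder showing a dict-backed get_state."""

    def __init__(self, state):
        self._state = dict(state)

    def get_state(self, param_name):
        try:
            return self._state[param_name]
        except KeyError:
            # Point the caller at the names that are available.
            raise KeyError(
                f"unknown state {param_name!r}; available: {sorted(self._state)}"
            ) from None


reg = StatefulSketch({"p": 2, "lambda": 1e-5})
print(reg.get_state("p"))  # 2
```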
_init_hyperparams(hyperparam_dict)

Initializes the hyperparameters needed by the algorithm.

Parameters: hyperparam_dict (dictionary) – key-value pairs; the regularizer looks up the keys it needs to obtain the corresponding hyperparameters
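One way a subclass might implement _init_hyperparams: read optional keys with defaults, then validate them before storing. The keys "lambda" and "p", their defaults, and the validation rules are illustrative assumptions for an Lp-style penalty, not values taken from this page; LpHyperparamsSketch is a hypothetical name.

```python
class LpHyperparamsSketch:
    """Hypothetical subclass fragment: only the hyperparameter parsing."""

    def __init__(self, hyperparam_dict):
        self._regularizer_parameters = {}
        self._init_hyperparams(hyperparam_dict)

    def _init_hyperparams(self, hyperparam_dict):
        # Optional keys fall back to assumed defaults.
        lam = float(hyperparam_dict.get("lambda", 1e-5))
        p = hyperparam_dict.get("p", 2)
        if lam < 0:
            raise ValueError("'lambda' must be non-negative")
        if p < 1:
            raise ValueError("'p' must be >= 1")
        self._regularizer_parameters["lambda"] = lam
        self._regularizer_parameters["p"] = p


print(LpHyperparamsSketch({})._regularizer_parameters)
# {'lambda': 1e-05, 'p': 2}
print(LpHyperparamsSketch({"lambda": 0.01, "p": 3})._regularizer_parameters)
# {'lambda': 0.01, 'p': 3}
```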
apply(trainable_params)

Public entry point for computing the regularization loss. This function performs input checks and input pre-processing, then applies the regularization function.

Parameters: trainable_params (list, shape [n]) – list of trainable parameters to be regularized
Returns: loss – regularization loss
Return type: tf.Tensor
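Splitting apply from _apply lets the base class own input handling while subclasses own the math. The sketch below assumes the checks are "must be a non-empty list" and the pre-processing is a conversion to floats; the real checks are not listed on this page, and ApplySketch and SquaredSketch are hypothetical names.

```python
class ApplySketch:
    """Hypothetical base: public apply() validates, then calls _apply()."""

    def apply(self, trainable_params):
        # Input check (assumed): a non-empty list of parameter groups.
        if not isinstance(trainable_params, list) or not trainable_params:
            raise ValueError("trainable_params must be a non-empty list")
        # Pre-processing (assumed): copy each group to a list of floats.
        cleaned = [[float(w) for w in group] for group in trainable_params]
        return self._apply(cleaned)

    def _apply(self, trainable_params):
        raise NotImplementedError("subclasses must implement _apply")


class SquaredSketch(ApplySketch):
    def _apply(self, trainable_params):
        # Sum of squared weights across all groups.
        return sum(w * w for group in trainable_params for w in group)


print(SquaredSketch().apply([[1, 2], (3,)]))  # 14.0
```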
_apply(trainable_params)

Apply the regularization function. Every subclass must implement this method.

All TensorFlow code must go in this method.

Parameters: trainable_params (list, shape [n]) – list of trainable parameters to be regularized
Returns: loss – regularization loss
Return type: tf.Tensor
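As an example of what a concrete _apply might compute, the sketch below returns an Lp penalty, lambda * sum over all weights w of |w|^p, across a list of parameter groups. It uses plain Python floats so it runs without TensorFlow; in a real subclass the same expression would be built from TensorFlow ops (for instance tf.abs, tf.pow and tf.reduce_sum) so that the result is a tf.Tensor. The function name lp_penalty and the choice of an Lp penalty are illustrative assumptions.

```python
def lp_penalty(trainable_params, lam, p):
    """Hypothetical plain-Python Lp regularization loss.

    trainable_params: list of parameter groups, each an iterable of floats.
    Returns lam * sum over all weights of |w| ** p.
    """
    total = 0.0
    for group in trainable_params:
        for w in group:
            total += abs(w) ** p
    return lam * total


# L2 (p=2) on two groups: 0.1 * (1 + 4 + 9) = 1.4 up to float rounding
loss = lp_penalty([[1.0, -2.0], [3.0]], lam=0.1, p=2)
print(round(loss, 6))  # 1.4
```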