hopwise.model.knowledge_graph_embedding_recommender.complex

Reference:

Trouillon et al. “Complex Embeddings for Simple Link Prediction.” In ICML’16.

Reference code:

https://github.com/torchkge-team/torchkge

Classes

ComplEx

ComplEx extends DistMult by introducing complex-valued embeddings.

Module Contents

class hopwise.model.knowledge_graph_embedding_recommender.complex.ComplEx(config, dataset)

Bases: hopwise.model.abstract_recommender.KnowledgeRecommender

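ComplEx extends DistMult by introducing complex-valued embeddings. Because the tail embedding is conjugated in the score, a triple and its inverse can score differently, which lets the model represent asymmetric relations that DistMult cannot.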
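For reference, Trouillon et al. define the ComplEx score of a triple (h, r, t) as the real part of a trilinear product taken with the complex conjugate of the tail embedding. Expanding that product into real and imaginary parts gives four real-valued triple products. This expansion is probably why the class keeps separate _re_ and _im_ embedding tables. What follows is standard algebra, not a transcription of the hopwise source.

```latex
\phi(h, r, t) = \operatorname{Re}\Big(\sum_{k=1}^{d} h_k \, r_k \, \bar{t}_k\Big)
  = \langle \operatorname{Re} h, \operatorname{Re} r, \operatorname{Re} t \rangle
  + \langle \operatorname{Im} h, \operatorname{Re} r, \operatorname{Im} t \rangle
  + \langle \operatorname{Re} h, \operatorname{Im} r, \operatorname{Im} t \rangle
  - \langle \operatorname{Im} h, \operatorname{Im} r, \operatorname{Re} t \rangle,
\qquad
\langle x, y, z \rangle = \sum_{k=1}^{d} x_k \, y_k \, z_k
```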

Note

In this version, recommendation (user-item) data and knowledge-graph data are sampled separately and then combined for training.
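The sketch below shows what "sampled separately, then combined" can look like at the batch level. It assumes each source yields its own batch as a mapping of field names to tensors, and that the two batches are merged into one mapping before the loss is computed. The field names and the merge_batches helper are hypothetical; they are not the actual hopwise Interaction keys.

```python
import torch


def merge_batches(rec_batch: dict, kg_batch: dict) -> dict:
    """Combine an independently sampled recommendation batch and KG batch.

    Hypothetical helper: the real hopwise data pipeline may differ.
    """
    # Refuse to merge if both sources use the same field name.
    overlap = rec_batch.keys() & kg_batch.keys()
    if overlap:
        raise ValueError(f"field name clash: {sorted(overlap)}")
    return {**rec_batch, **kg_batch}


# The two sources are sampled independently, so their batch sizes may differ.
rec_batch = {
    "user_id": torch.tensor([0, 1]),
    "item_id": torch.tensor([3, 4]),
    "neg_item_id": torch.tensor([7, 2]),
}
kg_batch = {
    "head_id": torch.tensor([5, 6, 8]),
    "relation_id": torch.tensor([0, 1, 1]),
    "tail_id": torch.tensor([9, 2, 3]),
    "neg_tail_id": torch.tensor([1, 4, 0]),
}
batch = merge_batches(rec_batch, kg_batch)
```

Keeping the field names disjoint lets one training step read both the user-item triples and the KG triples from a single object.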

input_type
embedding_size
device
ui_relation
user_re_embedding
user_im_embedding
entity_re_embedding
entity_im_embedding
relation_re_embedding
relation_im_embedding
loss
forward(head_re_e, head_im_e, rec_r_re_e, rec_r_im_e, tail_re_e, tail_im_e)
triple_dot(x, y, z)
_get_rec_embeddings(user, positive_items, negative_items)
_get_kg_embeddings(head, relation, positive_tails, negative_tails)
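The members above have no docstrings. Below is a minimal sketch of what triple_dot and forward plausibly compute. It assumes triple_dot is the element-wise product of three tensors summed over the embedding dimension, and that forward returns the real part of the ComplEx score built from the six real and imaginary embedding tensors. The name complex_score is a hypothetical stand-in for forward, not a hopwise function. The sketch checks the result against PyTorch's native complex arithmetic.

```python
import torch


def triple_dot(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Trilinear product <x, y, z> = sum_k x_k * y_k * z_k over the last dimension.
    return (x * y * z).sum(dim=-1)


def complex_score(head_re_e, head_im_e, rel_re_e, rel_im_e, tail_re_e, tail_im_e):
    # Re(<h, r, conj(t)>) expanded into four real-valued triple products.
    return (
        triple_dot(head_re_e, rel_re_e, tail_re_e)
        + triple_dot(head_im_e, rel_re_e, tail_im_e)
        + triple_dot(head_re_e, rel_im_e, tail_im_e)
        - triple_dot(head_im_e, rel_im_e, tail_re_e)
    )


torch.manual_seed(0)
parts = [torch.randn(4, 8) for _ in range(6)]  # batch of 4 triples, d = 8
score = complex_score(*parts)

# Cross-check against PyTorch's native complex tensors.
h = torch.complex(parts[0], parts[1])
r = torch.complex(parts[2], parts[3])
t = torch.complex(parts[4], parts[5])
reference = (h * r * t.conj()).sum(dim=-1).real
assert torch.allclose(score, reference, atol=1e-5)
```

Storing real and imaginary parts as separate real tensors keeps every parameter an ordinary float embedding, at the cost of writing out the four-term expansion by hand.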
calculate_loss(interaction)

Calculate the training loss for a batch of data.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Training loss, shape: []

Return type:

torch.Tensor
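A minimal sketch of how a scalar training loss could be formed from the recommendation and KG scores. It makes two assumptions: each part uses a BPR-style pairwise objective over positive and negative scores, and the two parts are simply summed. The actual loss held in the loss attribute may be different. The helper names pairwise_loss and total_loss are hypothetical.

```python
import torch
import torch.nn.functional as F


def pairwise_loss(pos_score: torch.Tensor, neg_score: torch.Tensor) -> torch.Tensor:
    # BPR-style objective: -log(sigmoid(pos - neg)) == softplus(neg - pos).
    return F.softplus(neg_score - pos_score).mean()


def total_loss(rec_pos, rec_neg, kg_pos, kg_neg) -> torch.Tensor:
    # Assumption: the recommendation and KG losses are summed with equal weight.
    return pairwise_loss(rec_pos, rec_neg) + pairwise_loss(kg_pos, kg_neg)


# Toy scores; in the model these would come from the ComplEx scoring function.
rec_pos = torch.tensor([2.0, 1.5])
rec_neg = torch.tensor([0.5, 1.0])
kg_pos = torch.tensor([0.0, 3.0, 1.0])
kg_neg = torch.tensor([0.0, -1.0, 2.0])
loss = total_loss(rec_pos, rec_neg, kg_pos, kg_neg)  # 0-dim tensor, shape []
```

Using softplus instead of an explicit -log(sigmoid(...)) gives the same value but is numerically safer for large score gaps.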

predict(interaction)

Predict the scores between users and items.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Predicted scores for given users and items, shape: [batch_size]

Return type:

torch.Tensor
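A minimal sketch of per-pair prediction. It assumes three things: the user is the head, the item is the tail, and the user-item link is a single extra relation (the ui_relation attribute). It also assumes item ids index directly into the entity tables. The embedding tables mirror the attribute names listed above, but their sizes, the extra relation slot, and the predict function are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_users, n_entities, n_relations, d = 5, 10, 3, 8

# Hypothetical tables mirroring the attribute names of the class.
user_re_embedding = nn.Embedding(n_users, d)
user_im_embedding = nn.Embedding(n_users, d)
entity_re_embedding = nn.Embedding(n_entities, d)
entity_im_embedding = nn.Embedding(n_entities, d)
# Assumption: one extra relation slot is reserved for the user-item relation.
relation_re_embedding = nn.Embedding(n_relations + 1, d)
relation_im_embedding = nn.Embedding(n_relations + 1, d)
ui_relation = n_relations  # hypothetical index of the user-item relation


def predict(user: torch.Tensor, item: torch.Tensor) -> torch.Tensor:
    u_re, u_im = user_re_embedding(user), user_im_embedding(user)
    i_re, i_im = entity_re_embedding(item), entity_im_embedding(item)
    rel = torch.full_like(user, ui_relation)  # same relation for every pair
    r_re, r_im = relation_re_embedding(rel), relation_im_embedding(rel)
    # Real part of the ComplEx trilinear product: one score per (user, item) pair.
    return (
        (u_re * r_re * i_re).sum(-1)
        + (u_im * r_re * i_im).sum(-1)
        + (u_re * r_im * i_im).sum(-1)
        - (u_im * r_im * i_re).sum(-1)
    )


with torch.no_grad():
    scores = predict(torch.tensor([0, 1, 4]), torch.tensor([2, 7, 9]))  # shape [3]
```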

predict_kg(interaction)
full_sort_predict(interaction)

Full sort prediction function. Given users, calculate the scores between users and all candidate items.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Predicted scores for given users and all candidate items, shape: [n_batch_users * n_candidate_items]

Return type:

torch.Tensor
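Scoring every candidate does not require a loop over items. Folding the relation into the head side first, as a = Re(h * r) and b = Im(h * r), turns the ComplEx score into two matrix products, because Re(h * r * conj(t)) = a · Re(t) + b · Im(t). The sketch below assumes this strategy, and full_sort_scores is a hypothetical name. The output is flattened row-major, so entry u * n_items + i holds the score of user u and item i. A brute-force complex computation serves as a check.

```python
import torch


def full_sort_scores(h_re, h_im, r_re, r_im, all_re, all_im):
    # h_*: [n_heads, d]; r_*: [d] (broadcast) or [n_heads, d]; all_*: [n_candidates, d].
    a = h_re * r_re - h_im * r_im  # Re(h * r)
    b = h_re * r_im + h_im * r_re  # Im(h * r)
    scores = a @ all_re.T + b @ all_im.T  # Re(h * r * conj(t)) for every candidate t
    return scores.reshape(-1)  # [n_heads * n_candidates], row-major


torch.manual_seed(0)
n_users, n_items, d = 3, 6, 8
u_re, u_im = torch.randn(n_users, d), torch.randn(n_users, d)
r_re, r_im = torch.randn(d), torch.randn(d)  # single user-item relation
i_re, i_im = torch.randn(n_items, d), torch.randn(n_items, d)
flat = full_sort_scores(u_re, u_im, r_re, r_im, i_re, i_im)

# Brute-force check: broadcast [n_users, 1, d] against [n_items, d].
h = torch.complex(u_re, u_im)[:, None, :]
r = torch.complex(r_re, r_im)
t = torch.complex(i_re, i_im)
brute = (h * r * t.conj()).sum(-1).real.reshape(-1)
```

The matrix-product form costs two GEMMs per batch instead of materializing an [n_users, n_items, d] tensor, which matters once the candidate set is large.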

full_sort_predict_kg(interaction)

Full sort prediction KG function. Given heads, calculate the scores between heads and all candidate tails.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Predicted scores for given heads and all candidate tails, shape: [n_batch_heads * n_candidate_tails]

Return type:

torch.Tensor
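Because the output is flattened to [n_batch_heads * n_candidate_tails], a caller that wants per-head rankings has to reshape it first. This minimal sketch assumes row-major flattening: all tails for the first head come first, then those for the second, and so on, as produced by Tensor.view(-1). The helper top_k_tails is hypothetical, and the scores are toy values.

```python
import torch


def top_k_tails(flat_scores: torch.Tensor, n_heads: int, k: int):
    # Undo the flattening to [n_heads, n_candidate_tails], then rank tails per head.
    scores = flat_scores.view(n_heads, -1)
    return torch.topk(scores, k, dim=1)  # sorted, highest score first


# Toy scores for 2 heads x 4 candidate tails.
flat_scores = torch.tensor([0.1, 0.9, 0.4, 0.7, 0.2, 0.3, 0.8, 0.5])
values, indices = top_k_tails(flat_scores, n_heads=2, k=2)
# indices -> [[1, 3], [2, 3]]
```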