hopwise.model.context_aware_recommender.kd_dagfm

Reference:

Zhen Tian et al. “Directed Acyclic Graph Factorization Machines for CTR Prediction via Knowledge Distillation.” In WSDM 2023.

Reference code:

https://github.com/chenyuwuxin/DAGFM

Classes

KD_DAGFM

KD_DAGFM is a context-based recommendation model. The model is based on directed acyclic graphs and knowledge distillation.

DAGFM

CrossNet

CINComp

CIN

Module Contents

class hopwise.model.context_aware_recommender.kd_dagfm.KD_DAGFM(config, dataset)

Bases: hopwise.model.abstract_recommender.ContextRecommender

KD_DAGFM is a context-based recommendation model based on directed acyclic graphs and knowledge distillation. It can learn arbitrary feature interactions from complex teacher networks and achieve approximately lossless model performance while greatly reducing computational resource costs.

phase
alpha
beta
student_network
teacher_network
loss_fn
get_teacher_config(config)
FeatureInteraction(feature)
forward(interaction)
calculate_loss(interaction)

Calculate the training loss for a batch of data.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Training loss, shape: []

Return type:

torch.Tensor
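The docs above do not spell out how the batch loss is formed while the student is being distilled. As a hedged illustration, the sketch below assumes the student is trained on a weighted sum of a binary cross-entropy CTR loss and a mean squared error between teacher and student logits, weighted by `alpha` and `beta` (names taken from the attribute list above). The helper `distillation_loss` and this exact weighting are assumptions, not hopwise's confirmed implementation.

```python
import torch
import torch.nn as nn


def distillation_loss(student_logits, teacher_logits, labels, alpha=1.0, beta=1.0):
    """Hypothetical distillation-phase loss (assumed weighting: alpha * CTR + beta * KD)."""
    # Supervised CTR loss computed on the student's raw logits.
    ctr_loss = nn.functional.binary_cross_entropy_with_logits(student_logits, labels)
    # Pull the student toward the teacher's logits; detach so the teacher is not updated.
    kd_loss = torch.mean((teacher_logits.detach() - student_logits) ** 2)
    return alpha * ctr_loss + beta * kd_loss


torch.manual_seed(0)
s = torch.randn(8, requires_grad=True)   # student logits
t = torch.randn(8)                       # teacher logits
y = torch.randint(0, 2, (8,)).float()    # click labels
loss = distillation_loss(s, t, y, alpha=1.0, beta=0.5)
loss.backward()
print(loss.shape)  # torch.Size([])
```

The result is a 0-dimensional tensor, which matches the documented training-loss shape `[]`.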

predict(interaction)

Predict the scores between users and items.

Parameters:

interaction (Interaction) – Interaction class of the batch.

Returns:

Predicted scores for given users and items, shape: [batch_size]

Return type:

torch.Tensor
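The documented return shape `[batch_size]` implies one score per interaction. A minimal sketch of that final step, assuming the underlying network emits raw logits of shape `[batch_size, 1]` and that scores are sigmoid click probabilities (both assumptions; `predict_scores` is a hypothetical helper):

```python
import torch


def predict_scores(logits: torch.Tensor) -> torch.Tensor:
    # Assumed input: one raw logit per sample, shape [batch_size, 1].
    # Drop the trailing dimension and map logits to probabilities in (0, 1).
    return torch.sigmoid(logits.squeeze(-1))


logits = torch.tensor([[0.0], [2.0], [-2.0]])
scores = predict_scores(logits)
print(scores.shape)  # torch.Size([3])
```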

class hopwise.model.context_aware_recommender.kd_dagfm.DAGFM(config)

Bases: torch.nn.Module

type
depth
adj_matrix
connect_layer
linear
FeatureInteraction(feature)
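The DAGFM attributes above (`depth`, `adj_matrix`, `connect_layer`, `linear`) hint at its structure: fields act as nodes of a directed acyclic graph, and each layer passes messages only along allowed edges. The sketch below is a simplified, hypothetical version of an "inner" interaction variant as we understand the DAGFM idea; the class name `DAGInnerInteraction`, the per-layer edge weights `p`, and the initialization are assumptions, not hopwise's code.

```python
import torch
import torch.nn as nn


class DAGInnerInteraction(nn.Module):
    """Hypothetical, simplified DAG-style feature interaction ("inner" variant)."""

    def __init__(self, field_num: int, embedding_size: int, depth: int):
        super().__init__()
        self.depth = depth
        # One learnable edge-weight tensor per layer: [field_num, field_num, embedding_size].
        self.p = nn.ParameterList(
            [nn.Parameter(torch.randn(field_num, field_num, embedding_size) * 0.1) for _ in range(depth)]
        )
        # adj_matrix[j, i, :] = 1 iff j <= i, so edges only run from earlier to later fields.
        mask = torch.triu(torch.ones(field_num, field_num)).unsqueeze(-1).expand(-1, -1, embedding_size)
        self.register_buffer("adj_matrix", mask.clone())
        # Learnable field-mixing layer, initialised to the identity.
        self.connect_layer = nn.Parameter(torch.eye(field_num))
        # Reads out the per-field states of all depth + 1 layers.
        self.linear = nn.Linear(field_num * (depth + 1), 1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        # feature: [batch_size, field_num, embedding_size]
        h0 = self.connect_layer @ feature
        ht = h0
        states = [h0.sum(dim=-1)]
        for i in range(self.depth):
            # out[b, s, d] = sum_f ht[b, f, d] * p[f, s, d] * adj[f, s, d]
            aggr = torch.einsum("bfd,fsd->bsd", ht, self.p[i] * self.adj_matrix)
            # Element-wise product with the input raises the interaction order by one.
            ht = h0 * aggr
            states.append(ht.sum(dim=-1))
        return self.linear(torch.cat(states, dim=-1))  # logits: [batch_size, 1]


torch.manual_seed(0)
model = DAGInnerInteraction(field_num=4, embedding_size=8, depth=3)
out = model(torch.randn(2, 4, 8))
print(out.shape)  # torch.Size([2, 1])
```

The upper-triangular mask is what keeps the graph acyclic: field `s` can only aggregate from fields `f <= s`.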
class hopwise.model.context_aware_recommender.kd_dagfm.CrossNet(config)

Bases: torch.nn.Module

depth
embedding_size
feature_num
in_feature_num
cross_layer_w
bias
linear
FeatureInteraction(x_0)
forward(feature)
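Judging by its name and attributes, `CrossNet` is likely a DCN-style cross network that can serve as a teacher (an assumption). The sketch below assumes the DCN-V2 form with a full weight matrix per layer, x_{l+1} = x_0 ⊙ (W_l x_l + b_l) + x_l, applied to the flattened `feature_num * embedding_size` embedding, which would explain the `in_feature_num` attribute. The class name `CrossNetSketch` and the initialization are hypothetical.

```python
import torch
import torch.nn as nn


class CrossNetSketch(nn.Module):
    """Hypothetical DCN-V2-style cross network over flattened field embeddings."""

    def __init__(self, feature_num: int, embedding_size: int, depth: int):
        super().__init__()
        self.in_feature_num = feature_num * embedding_size
        # One full weight matrix and one bias per cross layer (DCN v1 would use a vector).
        self.cross_layer_w = nn.ParameterList(
            [nn.Parameter(torch.randn(self.in_feature_num, self.in_feature_num) * 0.01) for _ in range(depth)]
        )
        self.bias = nn.ParameterList([nn.Parameter(torch.zeros(self.in_feature_num)) for _ in range(depth)])
        self.linear = nn.Linear(self.in_feature_num, 1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        # [batch_size, feature_num, embedding_size] -> x_0: [batch_size, in_feature_num]
        x_0 = feature.reshape(feature.shape[0], -1)
        x_l = x_0
        for w, b in zip(self.cross_layer_w, self.bias):
            # x_{l+1} = x_0 * (W_l x_l + b_l) + x_l  (element-wise product plus residual)
            x_l = x_0 * (x_l @ w.T + b) + x_l
        return self.linear(x_l)  # logits: [batch_size, 1]


torch.manual_seed(0)
net = CrossNetSketch(feature_num=4, embedding_size=8, depth=2)
out = net(torch.randn(2, 4, 8))
print(out.shape)  # torch.Size([2, 1])
```

The residual term means that with all-zero weights and biases the network reduces to a linear model on the flattened embeddings.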
class hopwise.model.context_aware_recommender.kd_dagfm.CINComp(indim, outdim, config)

Bases: torch.nn.Module

conv
forward(feature, base)
class hopwise.model.context_aware_recommender.kd_dagfm.CIN(config)

Bases: torch.nn.Module

cinlist
cin
linear
backbone = ['cin', 'linear']
loss_fn
FeatureInteraction(feature)
forward(feature)
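`CIN` and `CINComp` appear to implement the Compressed Interaction Network from xDeepFM, another possible teacher. The sketch assumes each `CINComp` layer forms Hadamard products between every row of its input and every base field embedding, then compresses them with a 1x1 `Conv1d` (the `conv` attribute), and that `CIN` sum-pools each layer's feature maps over the embedding dimension before a linear readout. The class names, the `layer_sizes` argument, and the use of `nn.Linear` for the readout are assumptions; the `backbone` and `loss_fn` attributes are not modeled.

```python
from typing import Sequence

import torch
import torch.nn as nn


class CINCompSketch(nn.Module):
    """One compressed-interaction layer: pairwise Hadamard products, then a 1x1 conv."""

    def __init__(self, indim: int, outdim: int, feature_num: int):
        super().__init__()
        # Compress the indim * feature_num interaction maps into outdim maps.
        self.conv = nn.Conv1d(indim * feature_num, outdim, kernel_size=1)

    def forward(self, feature: torch.Tensor, base: torch.Tensor) -> torch.Tensor:
        # feature: [B, indim, D], base: [B, feature_num, D]
        z = feature[:, :, None, :] * base[:, None, :, :]         # [B, indim, feature_num, D]
        z = z.reshape(feature.shape[0], -1, feature.shape[-1])  # [B, indim * feature_num, D]
        return self.conv(z)                                     # [B, outdim, D]


class CINSketch(nn.Module):
    """Hypothetical CIN: stacked CINCompSketch layers, sum pooling, linear readout."""

    def __init__(self, feature_num: int, layer_sizes: Sequence[int]):
        super().__init__()
        self.cinlist = [feature_num] + list(layer_sizes)
        self.cin = nn.ModuleList(
            CINCompSketch(self.cinlist[i], self.cinlist[i + 1], feature_num)
            for i in range(len(self.cinlist) - 1)
        )
        self.linear = nn.Linear(sum(layer_sizes), 1)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        base, x, pooled = feature, feature, []
        for comp in self.cin:
            x = comp(x, base)
            # Sum-pool each feature map over the embedding dimension: [B, layer_size].
            pooled.append(x.sum(dim=-1))
        return self.linear(torch.cat(pooled, dim=-1))  # logits: [B, 1]


torch.manual_seed(0)
cin = CINSketch(feature_num=4, layer_sizes=[6, 5])
logits = cin(torch.randn(2, 4, 8))
print(logits.shape)  # torch.Size([2, 1])
```

Every layer interacts with the original fields (`base`), so layer k captures interactions of order k + 1 while the 1x1 convolution keeps the number of feature maps fixed.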