# Source code for torchctr.models.deep_factorization_machine

#!/usr/bin/env python
# encoding: utf-8

import torch
from torchctr.layers import LinearLayer, EmbeddingLayer, MultiLayerPerceptron
from torchctr.models.factorization_machine import FactorizationMachineLayer
from torchctr.models.checker import Checker


class DeepFactorizationMachine(torch.nn.Module):
    """DeepFM model: an FM component and an MLP sharing one embedding layer.

    The FM component is the sum of a linear term (``LinearLayer`` over the raw
    input) and a second-order term (``FactorizationMachineLayer`` over the
    field embeddings). The deep component is an MLP applied to all field
    embeddings flattened into one vector. The two components are added to
    form a single output per example.
    """

    @Checker.model_param_check
    def __init__(self, feature_dims, embed_dim, hidden_dims):
        """Build the shared embedding, the FM layers and the MLP.

        Args:
            feature_dims: per-field feature description, forwarded to
                ``LinearLayer`` and ``EmbeddingLayer``. Its length is used as
                the number of fields. (Presumably per-field cardinalities —
                TODO confirm against ``EmbeddingLayer``.)
            embed_dim: size of each field's embedding vector.
            hidden_dims: hidden layer sizes passed to ``MultiLayerPerceptron``.
        """
        super().__init__()
        self.fm_second_order = FactorizationMachineLayer()
        self.fm_linear = LinearLayer(feature_dims)
        self.embedding = EmbeddingLayer(feature_dims, embed_dim)
        # The MLP consumes every field's embedding concatenated into one vector.
        self.mlp_input_dim = embed_dim * len(feature_dims)
        self.mlp = MultiLayerPerceptron(
            input_dim=self.mlp_input_dim, hidden_dims=hidden_dims, output_dim=1
        )

    def forward(self, x, sigmoid=True):
        """Compute the DeepFM output for a batch.

        Args:
            x: model input, fed to both ``self.embedding`` and
                ``self.fm_linear``. Presumably a (batch, num_fields) tensor of
                feature indices — TODO confirm.
            sigmoid: if true, apply ``torch.sigmoid`` to the output. If false,
                return the raw logits (e.g. for use with a logits-based loss).

        Returns:
            The summed FM + deep output with dimension 1 squeezed away,
            passed through ``torch.sigmoid`` when ``sigmoid`` is true.
        """
        # Embed once: both the FM second-order term and the MLP read the same
        # embeddings, so recomputing them would only duplicate work.
        embedded = self.embedding(x)
        fm_part = self.fm_second_order(embedded) + self.fm_linear(x)
        deep_part = self.mlp(embedded.view(-1, self.mlp_input_dim))
        # The MLP has output_dim=1; squeeze that axis so both branches below
        # return tensors of the same shape.
        logits = (fm_part + deep_part).squeeze(1)
        if sigmoid:
            return torch.sigmoid(logits)
        # Previously this path fell through and returned None.
        return logits