#!/usr/bin/env python
# encoding: utf-8
import torch
from torchctr.layers import EmbeddingLayer, LinearLayer
from torchctr.models.checker import Checker
class FactorizationMachineLayer(torch.nn.Module):
    """Second-order factorization machine interaction layer.

    Computes the pairwise field interactions of a factorization machine with
    the linear-time identity

        sum_{i<j} v_i * v_j = 0.5 * ((sum_i v_i) ** 2 - sum_i v_i ** 2)

    where the products, squares and sums over ``i`` (dim 1 of the input) are
    taken element-wise per embedding dimension.

    Args:
        reduce_sum: If True (default), also sum over the embedding dimension
            and return one interaction score per sample. If False, return the
            per-dimension interaction vector.
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        """Return the second-order interaction term for ``x``.

        Args:
            x: Tensor whose dim 1 indexes the fields; presumably shaped
                ``(batch_size, num_fields, embed_dim)`` — verify against callers.

        Returns:
            For a 3-D input, a ``(batch_size, 1)`` tensor when ``reduce_sum``
            is True, otherwise a ``(batch_size, embed_dim)`` tensor.
        """
        square_of_sum = torch.sum(x, dim=1) ** 2
        sum_of_square = torch.sum(x ** 2, dim=1)
        # (sum v)^2 - sum v^2 counts every (i, j) pair twice; the 0.5 factor
        # must be applied in BOTH modes so reduce_sum=False also yields the
        # true per-dimension pairwise sum.
        interaction = 0.5 * (square_of_sum - sum_of_square)
        if self.reduce_sum:
            interaction = torch.sum(interaction, dim=1, keepdim=True)
        return interaction
class FactorizationMachine(torch.nn.Module):
    """Factorization Machine (FM) model.

    The logit is the sum of a first-order term (``LinearLayer`` on the input)
    and the second-order FM interaction of the input's embeddings.

    Args:
        feature_dims: Passed unchanged to both ``EmbeddingLayer`` and
            ``LinearLayer``; presumably per-field feature cardinalities
            (and presumably validated by ``Checker.model_param_check``) —
            confirm against those helpers.
        embed_dim: Embedding dimension passed to ``EmbeddingLayer``.
    """

    @Checker.model_param_check
    def __init__(self, feature_dims, embed_dim):
        super().__init__()
        self.embedding = EmbeddingLayer(feature_dims, embed_dim)
        self.linear = LinearLayer(feature_dims)
        self.fm = FactorizationMachineLayer(reduce_sum=True)

    def forward(self, x, sigmoid=True):
        """Score a batch of samples.

        Args:
            x: Model input, accepted by both ``EmbeddingLayer`` and
                ``LinearLayer``.
            sigmoid: If True (default), return ``torch.sigmoid`` of the
                logits; otherwise return the raw logits.

        Returns:
            The logits (or their sigmoid) with dim 1 squeezed away — assumes
            ``self.linear(x)`` is ``(batch_size, 1)``, giving a
            ``(batch_size,)`` result; TODO confirm.
        """
        # First-order term plus second-order interaction term.
        logits = (self.linear(x) + self.fm(self.embedding(x))).squeeze(1)
        if sigmoid:
            return torch.sigmoid(logits)
        # Return raw logits instead of implicitly returning None, e.g. for use
        # with a logits-based loss.
        return logits

    def __repr__(self):
        return self.__class__.__name__