Source code for torchctr.models.wide_and_deep

#!/usr/bin/env python
# encoding: utf-8

import torch

from torchctr.layers import LinearLayer, EmbeddingLayer, MultiLayerPerceptron

class WideAndDeepModel(torch.nn.Module):
    """Wide & Deep model: a linear ("wide") term plus an MLP ("deep") term.

    The wide part is ``LinearLayer(feature_dims)`` applied directly to the
    input. The deep part embeds the input with
    ``EmbeddingLayer(feature_dims, embed_dim)``, flattens the result to
    ``embed_dim * len(feature_dims)`` columns and feeds it to a
    ``MultiLayerPerceptron`` with a single output unit. The two parts are
    summed.

    Args:
        feature_dims: Sized collection with one entry per input field; it is
            passed unchanged to ``LinearLayer`` and ``EmbeddingLayer``
            (presumably the cardinality of each field -- confirm against
            ``torchctr.layers``).
        embed_dim: Embedding size per field.
        hidden_dims: Hidden layer sizes forwarded to ``MultiLayerPerceptron``.
    """

    def __init__(self, feature_dims, embed_dim, hidden_dims):
        super().__init__()
        self.linear = LinearLayer(feature_dims)
        self.embedding = EmbeddingLayer(feature_dims, embed_dim)
        # Width of the flattened embedding fed to the MLP: one embed_dim-sized
        # slice per input field.
        self.mlp_input_dim = embed_dim * len(feature_dims)
        self.mlp = MultiLayerPerceptron(
            input_dim=self.mlp_input_dim,
            hidden_dims=hidden_dims,
            output_dim=1,
        )

    def forward(self, x, sigmoid=True):
        """Score a batch of inputs.

        Args:
            x: Input batch accepted by both the linear and embedding layers
                (assumed to be an index tensor of shape
                ``(batch, num_fields)`` -- TODO confirm against
                ``torchctr.layers``).
            sigmoid: When true (default) return ``torch.sigmoid`` of the
                scores; when false return the raw, un-activated scores.

        Returns:
            The summed wide and deep outputs with dimension 1 squeezed out,
            optionally passed through a sigmoid.
        """
        linear_part = self.linear(x)
        # Flatten the embedding output to (-1, mlp_input_dim) for the MLP.
        mlp_part = self.mlp(self.embedding(x).view(-1, self.mlp_input_dim))
        scores = (linear_part + mlp_part).squeeze(1)
        if sigmoid:
            return torch.sigmoid(scores)
        # Previously this path fell off the end of the function and returned
        # None; callers asking for sigmoid=False now get the raw scores.
        return scores