Source code for torchctr.layers.linear

#!/usr/bin/env python
# encoding: utf-8

import numpy as np
import torch


class LinearLayer(torch.nn.Module):
    """Linear (first-order) term over categorical fields via an embedding lookup.

    Each field's local index is shifted by a per-field offset so that all
    fields share one embedding table; the looked-up rows are summed over the
    field axis and a learnable bias is added.

    Args:
        num_features: Sequence of integers, one per field. Their cumulative
            sum defines each field's starting row in the shared table
            (presumably the cardinality of each field -- confirm with callers).
        output_dim: Width of each embedding row and of the bias vector.
    """

    def __init__(self, num_features, output_dim=1):
        super().__init__()
        # NOTE(review): the table has two more rows than sum(num_features);
        # the reason for the extra rows is not visible in this file.
        self.weights_embed = torch.nn.Embedding(
            sum(num_features) + 2, output_dim
        )
        self.bias = torch.nn.Parameter(torch.zeros((output_dim,)))

        # Starting row of each field: 0, n0, n0 + n1, ...
        offsets = torch.tensor(
            np.array((0, *np.cumsum(num_features)[:-1])), dtype=torch.long
        )
        # Registered as a buffer so it follows the module across devices
        # (.to() / .cuda()); a plain tensor attribute would stay on the CPU
        # and break forward() on GPU inputs. Non-persistent keeps the
        # state_dict keys identical to checkpoints saved before this change.
        self.register_buffer("feature_loc_offsets", offsets, persistent=False)

    def forward(self, x):
        """Return the summed per-field weights plus bias.

        Args:
            x: Integer index tensor whose last dimension broadcasts against
                the per-field offsets (assumed shape ``(batch, num_fields)``
                -- confirm with callers).

        Returns:
            Tensor obtained by summing the looked-up embeddings over dim 1
            and adding the bias; ``(batch, output_dim)`` for 2-D input.
        """
        # Shift field-local indices into the shared embedding table.
        adjusted_x = x + self.feature_loc_offsets
        return torch.sum(self.weights_embed(adjusted_x), dim=1) + self.bias