Source code for torchctr.layers.embedding

#!/usr/bin/env python
# encoding: utf-8

import torch
import numpy as np


class EmbeddingLayer(torch.nn.Module):
    """Shared embedding table for several categorical feature fields.

    Same idea as the linear layer, but each index maps to an ``embed_dim``
    vector. All fields share a single ``torch.nn.Embedding``; field ``i`` is
    assigned the index range starting at ``sum(num_features[:i])``, so the
    per-field indices passed to :meth:`forward` are shifted by these offsets
    before the lookup.

    Args:
        num_features: Per-field vocabulary sizes (an iterable of ints that
            ``sum`` and ``np.cumsum`` accept).
        embed_dim: Size of each embedding vector.
    """

    def __init__(self, num_features, embed_dim):
        super().__init__()
        # The table holds one row more than the total vocabulary size.
        # NOTE(review): the purpose of the extra row is not evident from this
        # class (possibly a padding/unknown slot) -- confirm with callers.
        self.weights_embed = torch.nn.Embedding(
            sum(num_features) + 1, embed_dim
        )

        # Start index of each field inside the shared table:
        # (0, n0, n0 + n1, ...).
        offsets = np.array((0, *np.cumsum(num_features)[:-1]))

        # Registered as a buffer so that ``.to(device)`` / ``.cuda()`` moves
        # it together with the module; a plain tensor attribute would stay on
        # the CPU and make ``forward`` fail on a device mismatch.
        # ``persistent=False`` keeps it out of ``state_dict`` so the set of
        # checkpoint keys is unchanged.
        self.register_buffer(
            "feature_loc_offsets",
            torch.tensor(offsets, dtype=torch.long),
            persistent=False,
        )

        torch.nn.init.xavier_uniform_(self.weights_embed.weight.data)

    def forward(self, x):
        """Look up the embeddings for per-field indices.

        Args:
            x: Integer tensor of per-field indices whose last dimension
                broadcasts against the ``len(num_features)`` offsets
                (presumably ``(batch, num_fields)`` -- confirm with callers).

        Returns:
            Tensor with the broadcast shape of ``x`` and the offsets, plus a
            trailing dimension of size ``embed_dim``.
        """
        adjusted_x = x + self.feature_loc_offsets
        return self.weights_embed(adjusted_x)