Source code for caver.model.lstm

#!/usr/bin/env python
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import BaseModule

class InvalidInputException(Exception):
    """Raised when input text cannot be turned into a model batch.

    ``LSTM._get_model_output`` raises it, for example, when a sequence
    contains no whitespace-separated tokens.
    """

class LSTM(BaseModule):
    """Simple LSTM text classifier.

    Pipeline: text -> embedding -> LSTM -> linear layer. ``predict`` applies a
    sigmoid to the logits; ``predict_prob`` applies a softmax.

    :param config: ConfigLSTM which contains LSTM configures; this class reads
        ``layer_num``, ``bidirectional``, ``hidden_dim``, ``embedding_dim``
        and ``dropout`` from it.
    :param int vocab_size: vocabulary size (rows of the embedding table)
    :param int label_num: number of labels (outputs of the final linear layer)
    :param device: cpu or gpu; used for tensors created in ``init_hidden``
    :param bool batch_first: If True, then the input and output tensors are
        provided as (batch, seq, feature).
    """

    def __init__(self, config, vocab_size=1000, label_num=100, device="cpu",
                 batch_first=True):
        super().__init__()
        # Attributes starting with "_" are static hyper-parameters; the
        # others are nn layers.
        self._layer_num = config.layer_num
        self._bidirectional = config.bidirectional
        self._num_directions = 2 if self._bidirectional else 1
        self._device = device
        self._hidden_dim = config.hidden_dim
        self._vocab_size = vocab_size
        self._embedding_dim = config.embedding_dim
        self._label_num = label_num
        self._dropout = config.dropout
        self._batch_first = batch_first

        self.embedding = nn.Embedding(self._vocab_size, self._embedding_dim)
        self.dropout = nn.Dropout(self._dropout)
        # nn.LSTM only applies dropout *between* stacked layers; with a single
        # layer it has no effect and merely emits a warning, so pass 0 there.
        self.lstm = nn.LSTM(
            self._embedding_dim,
            self._hidden_dim,
            self._layer_num,
            batch_first=self._batch_first,
            bidirectional=self._bidirectional,
            dropout=self._dropout if self._layer_num > 1 else 0.0,
        )
        self.predictor = nn.Linear(self._hidden_dim * self._num_directions,
                                   self._label_num)

    def init_hidden(self, batch_size):
        """Return zero-filled ``(h_0, c_0)`` initial states for ``self.lstm``.

        Each tensor has shape
        ``[layer_num * num_directions, batch_size, hidden_dim]`` and lives on
        ``self._device``.
        """
        shape = (self._layer_num * self._num_directions, batch_size,
                 self._hidden_dim)
        return (torch.zeros(shape, device=self._device),
                torch.zeros(shape, device=self._device))

    def attention(self, rnn_out, state):
        """Dot-product attention of ``state`` over the time steps of ``rnn_out``.

        Assumes ``rnn_out`` is ``[batch, seq_len, hidden_dim * num_directions]``
        (batch-first) and ``state`` is ``[batch, hidden_dim * num_directions]``.
        Returns ``[batch, hidden_dim * num_directions]``.
        """
        # [batch, H*dirs] -> [batch, H*dirs, 1]
        merged_state = state.unsqueeze(2)
        # [batch, seq, H*dirs] x [batch, H*dirs, 1] -> [batch, seq, 1]
        weights = torch.bmm(rnn_out, merged_state)
        # Normalize over the sequence dimension; explicit dim avoids the
        # deprecated implicit-dim softmax.
        weights = F.softmax(weights.squeeze(2), dim=1).unsqueeze(2)
        # [batch, H*dirs, seq] x [batch, seq, 1] -> [batch, H*dirs]
        return torch.bmm(rnn_out.transpose(1, 2), weights).squeeze(2)

    def forward(self, sequence):
        """Map token ids to label logits.

        :param sequence: LongTensor of token ids, ``[batch, seq_len]`` when
            ``batch_first`` is True (``[seq_len, batch]`` otherwise)
        :return: logits of shape ``[batch, label_num]``
        """
        embedded = self.dropout(self.embedding(sequence))
        # Keep cuDNN weights in one contiguous chunk (e.g. after DataParallel).
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(embedded)
        # hidden = [layer_num * num_directions, batch, hidden_dim]; the last
        # layer's states are the final num_directions entries.
        if self._bidirectional:
            # Last layer: forward direction at -2, backward direction at -1.
            output_feature = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            output_feature = hidden[-1]
        # (self.attention(output, output_feature) is available but unused.)
        output_feature = self.dropout(output_feature)
        return self.predictor(output_feature)

    def predict(self, batch_sequence_text, device="cpu", top_k=5):
        """Return the top-``k`` labels for each whitespace-tokenized text.

        ``top_k`` is clamped to the number of labels so ``torch.topk`` cannot
        fail. Relies on ``self.vocab`` and ``self.predict_label``, which are not
        defined in this class (presumably provided by ``BaseModule`` - confirm).
        """
        batch_preds = self._get_model_output(
            batch_sequence_text=batch_sequence_text,
            vocab_dict=self.vocab,
            device=device,
        )
        k = min(top_k, batch_preds.size(1))
        _, batch_top_k_index = torch.topk(torch.sigmoid(batch_preds), k=k,
                                          dim=1)
        return self.predict_label(batch_top_k_index)

    def predict_prob(self, batch_sequence_text, device="cpu"):
        """Return a ``[batch, label_num]`` softmax distribution per text.

        NOTE(review): ``predict`` ranks by sigmoid (multi-label style) while
        this uses softmax (single-label style) - confirm which is intended.
        """
        batch_preds = self._get_model_output(
            batch_sequence_text=batch_sequence_text,
            vocab_dict=self.vocab,
            device=device,
        )
        return torch.softmax(batch_preds, dim=1)

    def _get_model_output(self, batch_sequence_text, vocab_dict, device="cpu"):
        """Tokenize, pad, index and run a batch of texts through the model.

        :param batch_sequence_text: iterable of strings, split on whitespace
        :param vocab_dict: token -> id mapping, e.g. ``{"word": 1, "<pad>": 0}``;
            must contain ``"<pad>"``. Unknown tokens raise ``KeyError``.
        :param device: device the index tensor is moved to
        :return: logits ``[batch, label_num]`` computed without autograd
        :raises InvalidInputException: if the batch is empty or any text has
            no tokens

        Dropout is active unless the caller has put the module in eval mode.
        """
        batch_tokenized = [seq.split() for seq in batch_sequence_text]
        if not batch_tokenized or any(not tokens for tokens in batch_tokenized):
            raise InvalidInputException("Invalid Input")

        # Right-pad every sample to the longest one in the batch.
        batch_longest = max(map(len, batch_tokenized))
        batch_indexed = [
            [vocab_dict[token]
             for token in tokens + ["<pad>"] * (batch_longest - len(tokens))]
            for tokens in batch_tokenized
        ]
        indexed = torch.tensor(batch_indexed, dtype=torch.long, device=device)
        if not self._batch_first:
            # The LSTM expects [seq_len, batch] when batch_first is False.
            indexed = indexed.t().contiguous()
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            return self.forward(indexed)