Source code for caver.process

import os
import torch
from sklearn.model_selection import train_test_split
from torch import nn, optim

from . import model
from .data import TextData
from .utils import make_batches, init_weight, zero_padding, transform2onehot
from .utils import get_top_label_with_logits, update_config, recall_at_k
from .utils import load_embedding
from .config import Config


class Trainer:
    """
    :param model_name: name of the model, case sensitive.
    :type model_name: str
    :param data_path: file path of the data
    :type data_path: str

    You can pass your own config values as keyword arguments to replace the
    defaults in :class:`caver.config.Config`. GPU will be used if available.
    """
    def __init__(self, model_name, data_path, **kwargs):
        self.config = update_config(Config(), **kwargs)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.data = TextData(data_path)

        # The model class is looked up by name in caver.model.
        assert hasattr(model, model_name)
        self.name = model_name
        self.model = getattr(model, model_name)(
            vocab_size=len(self.data.word2index),
            label_num=len(self.data.label2index)
        ).to(self.device)
        self.model.apply(init_weight)
        print(self.model)

        self.loss_func = self.config.loss_func \
            if isinstance(self.config.loss_func, torch.nn.Module) \
            else nn.BCELoss()
        self.optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, self.model.parameters())
        )

    def init_model_embedding(self):
        """
        Initialize the embedding layer from a pre-trained embedding file.
        This is used when :class:`caver.config.Config.embedding_file` is not None.
        """
        print('Initializing embedding layer from pre-trained file')
        embedding_weight = load_embedding(
            self.config.embedding_file,
            self.config.embedding_dim,
            len(self.index2word),
            self.index2word
        )
        embedding_weight = torch.from_numpy(embedding_weight).type(torch.FloatTensor)
        self.model.embedding.weight = nn.Parameter(embedding_weight)
        # Freeze the embedding layer unless it is configured to be fine-tuned.
        if not self.config.embedding_train:
            self.model.embedding.weight.requires_grad = False
        self.model.to(self.device)

    def train(self):
        x, y = self.data.prepare()
        self.index2label = {index: label for label, index in self.data.label2index.items()}
        self.index2word = {index: word for word, index in self.data.word2index.items()}

        if self.config.embedding_file:
            self.init_model_embedding()

        train_x, val_x, train_y, val_y = train_test_split(x, y, test_size=self.config.valid)

        for i in range(self.config.epoches):
            print('Running epoch {}/{}'.format(i + 1, self.config.epoches))
            self.train_step(train_x, train_y, val_x, val_y)
            self.model.save(os.path.join(self.config.save_path,
                                         self.name + '_{}.pth'.format(i + 1)))

    def train_step(self, train_x, train_y, val_x, val_y):
        batches = make_batches(len(train_x), self.config.batch_size)
        total_loss = 0.0
        for i, (start, end) in enumerate(batches):
            # Pad each batch to a fixed sentence length and convert labels to one-hot.
            input_data = torch.from_numpy(
                zero_padding(train_x[start:end], self.config.sentence_length)
            ).type(torch.long).to(self.device)
            input_label = torch.from_numpy(
                transform2onehot(train_y[start:end], len(self.index2label))
            ).type(torch.float).to(self.device)

            self.model.train()
            self.model.zero_grad()
            logits = self.model(input_data)
            loss = self.loss_func(logits, input_label)
            total_loss += loss.item()
            loss.backward()
            self.optimizer.step()

            # Run validation periodically during the epoch.
            if i % (self.config.valid_interval - 1) == 0:
                self.eval_step(val_x, val_y)
        self.eval_step(val_x, val_y)

    def eval_step(self, x, y):
        batches = make_batches(len(x), self.config.batch_size)
        self.model.eval()
        total_loss, total_recall = 0.0, 0.0
        total_preds = []

        with torch.no_grad():
            for start, end in batches:
                input_data = torch.from_numpy(
                    zero_padding(x[start:end], self.config.sentence_length)
                ).type(torch.long).to(self.device)
                input_label = torch.from_numpy(
                    transform2onehot(y[start:end], len(self.index2label))
                ).type(torch.float).to(self.device)

                logits = self.model(input_data)
                loss = self.loss_func(logits, input_label).item()
                total_loss += loss
                # Keep the top-k predicted labels for every sample in the batch.
                preds = [get_top_label_with_logits(p, self.index2label, self.config.recall_k)
                         for p in logits.cpu().numpy()]
                total_preds.extend(preds)

        total_labels = [[self.index2label.get(i) for i in item] for item in y]
        total_text = [' '.join([self.index2word.get(i, '') for i in item]) for item in x]

        # Write per-sample predictions and recall@k to an evaluation report.
        with open('{}_eval.txt'.format(self.name), 'w', encoding='utf-8') as f:
            for text, pred, label in zip(total_text, total_preds, total_labels):
                recall = recall_at_k(pred, label, self.config.recall_k)
                total_recall += recall
                f.write('{}\nLabel: {}\nPred: {}\n'.format(text, label, pred))
                f.write('Recall@{}: {}\n'.format(self.config.recall_k, recall))
                f.write('-' * 80 + '\n')

        print('[+] Total loss: {}\nAverage recall@{}: {}'.format(
            total_loss, self.config.recall_k, total_recall / len(y))
        )
        print('-' * 80 + '\n')
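

# Usage sketch: a minimal way to drive the Trainer above. The model name
# "CNN", the data file path, and the specific keyword values are illustrative
# assumptions, not part of the module; per the class docstring, any attribute
# of caver.config.Config (batch_size, epoches, save_path, embedding_file, ...)
# may be overridden through keyword arguments, and the first argument must
# match a class defined in caver.model.
if __name__ == '__main__':
    os.makedirs('checkpoints', exist_ok=True)   # ensure the save directory exists
    trainer = Trainer(
        'CNN',                        # assumed model class name in caver.model
        'train.tsv',                  # hypothetical training data file
        batch_size=64,                # keyword args override Config defaults
        epoches=5,
        save_path='checkpoints',
        embedding_file=None,          # set a file path to load pre-trained vectors
    )
    trainer.train()                   # train, evaluate, and save one checkpoint per epoch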