#!/usr/bin/env python
# encoding: utf-8
import os
import torch
import requests
from torch.utils.data import DataLoader
from torchctr.progressbar import ProgressBar
from torchctr.dashboard import MetricLogger


class Trainer:
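    """Minimal training helper: wraps a model and a dataset and runs a
    train/validation loop with the configured optimizer, loss, and metrics."""
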
    def __init__(self, model, dataset, param=None):
        # avoid a mutable default argument; fall back to an empty dict
        self.model = model
        self.dataset = dataset
        self.param = self.build_param(param if param is not None else {})
        self.trainer_setup = self.build_trainer()

    def build_param(self, param):
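        """Merge user-supplied parameters into the defaults; unknown keys are ignored."""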
print("| building parameters ...")
default_param = {
"batch_size": 128,
"num_workers": 4,
"device": "cpu",
"learning_rate": 0.01,
"weight_decay": 1e-6,
"epochs": 10,
"metrics": ["auc"],
}
for k, v in param.items():
if k in default_param:
default_param[k] = v
return default_param

    def build_trainer(self):
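        """Build the data loaders, loss, optimizer, and metric logger used for training."""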
print("| building trainer ...")
train_length = int(len(self.dataset) * 0.9)
valid_length = len(self.dataset) - train_length
train_dataset, valid_dataset = torch.utils.data.random_split(
self.dataset, (train_length, valid_length)
)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.param.get("batch_size"),
            num_workers=self.param.get("num_workers"),
            shuffle=True,  # reshuffle the training batches every epoch
        )
valid_loader = DataLoader(
valid_dataset,
batch_size=self.param.get("batch_size"),
num_workers=self.param.get("num_workers"),
)
        criterion = torch.nn.BCELoss()
        # make sure the model lives on the configured device before its
        # parameters are handed to the optimizer
        self.model.to(self.param.get("device"))
        optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=self.param.get("learning_rate"),
            weight_decay=self.param.get("weight_decay"),
        )
logger = MetricLogger()
trainer_setup = {
"train_loader": train_loader,
"valid_loader": valid_loader,
"criterion": criterion,
"optimizer": optimizer,
"logger": logger
}
return trainer_setup

    def train(self, dashboard_address=None):
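        """Run the full training loop, optionally reporting metrics to a dashboard."""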
        dashboard_status = False
        if dashboard_address is None:
            print("| No dashboard address given, training without dashboard")
        else:
            try:
                # short timeout so an unreachable dashboard does not block training
                requests.get(url="http://{}/ping".format(dashboard_address), timeout=5)
                dashboard_status = True
            except requests.exceptions.RequestException as e:
                print("| ERROR in Dashboard connection: {}".format(e))
                print("| Falling back to training without dashboard")
print("| Start training ...")
for e in range(self.param.get("epochs")):
self.train_step(e + 1)
self.valid_step()
            if dashboard_status:
self.trainer_setup.get("logger").send(dashboard_address)

    def train_step(self, epoch):
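        """Train the model for one epoch over the training loader."""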
self.model.train()
progress_bar = ProgressBar(
self.trainer_setup.get("train_loader"),
self.param.get("metrics"),
desc="| Training {}/{}".format(epoch, self.param.get("epochs")),
)
for (fields, target) in progress_bar:
fields, target = (
fields.to(self.param.get("device")),
target.to(self.param.get("device")),
)
y = self.model(fields)
loss = self.trainer_setup.get("criterion")(y, target.float())
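            # clear old gradients before backpropagating this batch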
self.model.zero_grad()
loss.backward()
self.trainer_setup.get("optimizer").step()
progress_bar.eval(
target.tolist(), y.tolist(), {"loss": loss.item()}
)
progress_bar.summarize()
self.trainer_setup.get("logger").log(trace="train",
stats=progress_bar.summary)

    def valid_step(self):
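        """Evaluate the model on the validation loader."""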
self.model.eval()
progress_bar = ProgressBar(
self.trainer_setup.get("valid_loader"),
self.param.get("metrics"),
desc="| Validating",
)
        with torch.no_grad():  # no gradients are needed during validation
            for (fields, target) in progress_bar:
                fields, target = (
                    fields.to(self.param.get("device")),
                    target.to(self.param.get("device")),
                )
                y = self.model(fields)
                loss = self.trainer_setup.get("criterion")(y, target.float())
                progress_bar.eval(
                    target.tolist(), y.tolist(), {"loss": loss.item()}
                )
progress_bar.summarize()
self.trainer_setup.get("logger").log(trace="validation",
stats=progress_bar.summary)

    def save_model(self, file_fullpath):
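        """Save the full model to ``file_fullpath``, creating the directory if needed."""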
        path = os.path.dirname(file_fullpath)
        if path and not os.path.exists(path):
            print("| Directory not found, creating it")
            os.makedirs(path)  # also create missing parent directories
        torch.save(self.model, file_fullpath)
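

# A minimal usage sketch (hypothetical, not part of the library): it assumes the
# model ends in a Sigmoid, since BCELoss expects probabilities, and that the
# dataset yields (fields, target) pairs, as the loops above expect.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    toy_fields = torch.rand(256, 8)
    toy_targets = (torch.rand(256) > 0.5).long()
    toy_dataset = TensorDataset(toy_fields, toy_targets)
    toy_model = torch.nn.Sequential(
        torch.nn.Linear(8, 1),
        torch.nn.Flatten(start_dim=0),  # (batch, 1) -> (batch,) to match the targets
        torch.nn.Sigmoid(),
    )
    trainer = Trainer(toy_model, toy_dataset, param={"epochs": 2, "batch_size": 32})
    trainer.train()
    trainer.save_model("toy_model.pt")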