# Training
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with the given mean / std (presumably the MNIST training-set statistics --
# TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download=False: the dataset is expected to already be present under
# ../data; nothing is fetched here.
dataset1 = datasets.MNIST(
    '../data', train=True, download=False, transform=transform
)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

# Build the network on the target device and put it in training mode.
model = Net().to(device)
model.train()

optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
# step_size=1 means every scheduler.step() call scales the learning rate by
# args.gamma.
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

# NOTE(review): `train` is defined further down in this file. If this code
# runs as a flat top-level script, the definition has to be executed before
# this loop -- confirm the ordering in the real file.
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    # Decay the learning rate once per epoch.
    scheduler.step()
def train(args, model, device, train_loader, optimizer, epoch):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Args:
        args: Settings object; only ``args.log_interval`` is read here.
        model: Module to optimise. It is switched to training mode first.
        device: Device every batch is moved to before the forward pass.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` drive the
            parameter updates.
        epoch: Current epoch number. Kept for the caller's signature; it is
            not used in the body.

    Returns:
        None. The loss is printed for every batch whose index is a multiple
        of ``args.log_interval``.
    """
    # Crucial (consider behaviour of dropout)!!
    model.train()
    # Get batches of samples and ground truth
    for batch_idx, (data, target) in enumerate(train_loader):
        # Put on gpu if needed
        data, target = data.to(device), target.to(device)
        # Reset previous gradients
        optimizer.zero_grad()
        # Forward pass
        output = model(data)
        # Compute loss
        # NOTE(review): nll_loss expects log-probabilities as input; assumes
        # the model ends in a log_softmax -- confirm against the model.
        loss = F.nll_loss(output, target)
        # Propagate gradients
        loss.backward()
        optimizer.step()
        # Log metrics
        if batch_idx % args.log_interval == 0:
            print(f"Loss: {loss}")