Skip to content

Training

# Preprocessing: convert each image to a tensor, then normalise it with the
# mean/std pair given to Normalize (presumably MNIST training-set statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# download=False: nothing is fetched here; the data is expected under '../data'.
dataset1 = datasets.MNIST(root='../data', train=True, download=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=dataset1, **train_kwargs)

model = Net().to(device)
# Redundant with the model.train() call at the start of every train() epoch,
# but harmless.
model.train()

optimizer = optim.Adadelta(params=model.parameters(), lr=args.lr)
# step_size=1: the learning rate is multiplied by args.gamma after each epoch.
scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=args.gamma)

# NOTE(review): `train` is defined further down this file; if the file runs
# top to bottom, this loop executes before that `def` is reached — confirm.
for epoch in range(1, args.epochs + 1):
    train(
        args=args,
        model=model,
        device=device,
        train_loader=train_loader,
        optimizer=optimizer,
        epoch=epoch,
    )
    scheduler.step()


def train(args, model, device, train_loader, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        args: Namespace-like object. Only ``args.log_interval`` is read here:
            a log line is printed every ``log_interval`` batches.
        model: ``torch.nn.Module`` being trained. Its output is fed to
            ``F.nll_loss``, so it should return log-probabilities
            (assumption; the model definition is not visible here).
        device: Target passed to ``.to(device)`` for every batch.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: ``torch.optim`` optimizer over ``model``'s parameters.
        epoch: Current epoch number, shown in the log line.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.
    """
    # Crucial: layers such as dropout behave differently in train() and
    # eval() mode. An earlier evaluation pass may have left the model in eval
    # mode, so set training mode at the start of every epoch.
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        # Move the batch onto `device` (the model is expected to live there).
        data, target = data.to(device), target.to(device)

        # backward() accumulates into .grad, so clear the previous step's
        # gradients before computing new ones.
        optimizer.zero_grad()

        # Forward pass and negative log-likelihood loss.
        output = model(data)
        loss = F.nll_loss(output, target)

        # Propagate gradients and apply one optimizer update.
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # .item() yields a plain Python float. Formatting the tensor
            # itself would print its repr (device, grad_fn, ...).
            print(
                f"Train Epoch: {epoch} [batch {batch_idx}]\t"
                f"Loss: {loss.item():.6f}"
            )