Model
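Models are almost always subclasses of torch.nn.Module, the base class for PyTorch neural-network modules: layers are created in __init__ and the computation is defined in forward.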
Common layers
Let us inspect some common layers and the shapes of the tensors they consume and produce.
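Common config (shared by all snippets below; these values match the printed shapes): num_samples = 4, num_features = 32, num_classes = 10.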
Linear
# nn.Linear maps the last dimension from in_features to out_features
# (y = x @ weight.T + bias); all leading dimensions are kept as-is.
#
# Reference excerpt (torch.nn.Linear):
#     def forward(self, input: Tensor) -> Tensor:
#         return F.linear(input, self.weight, self.bias)
linear = nn.Linear(num_features, num_classes)

input_linear = torch.randn(num_samples, num_features)
output_linear = linear(input_linear)
print(f"{input_linear.shape=}", f"{output_linear.shape=}", sep="\n")
# input_linear.shape=torch.Size([4, 32])   -> [num_samples, num_features]
# output_linear.shape=torch.Size([4, 10])  -> [num_samples, num_classes]
Conv1d
# nn.Conv1d slides a kernel along the last axis of an (N, C_in, L) input.
# With the defaults used here (stride=1, padding=0, dilation=1) the output
# length is L - kernel_size + 1, so 10 -> 8 for kernel_size=3.
#
# Reference excerpt (torch.nn.Conv1d):
#     def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
#         if self.padding_mode != "zeros":
#             # Non-zero padding modes pad explicitly, then convolve with padding 0.
#             return F.conv1d(
#                 F.pad(
#                     input, self._reversed_padding_repeated_twice, mode=self.padding_mode
#                 ),
#                 weight, bias, self.stride, _single(0), self.dilation, self.groups,
#             )
#         return F.conv1d(
#             input, weight, bias, self.stride, self.padding, self.dilation, self.groups
#         )
#
#     def forward(self, input: Tensor) -> Tensor:
#         return self._conv_forward(input, self.weight, self.bias)
conv1d = nn.Conv1d(num_features, num_classes, kernel_size=3)

# (N, C_in, L) = (num_samples, num_features, num_classes); the length axis
# just reuses num_classes as a convenient size, it has no class meaning.
input_conv1d = torch.randn(num_samples, num_features, num_classes)
output_conv1d = conv1d(input_conv1d)
print(f"{input_conv1d.shape=}", f"{output_conv1d.shape=}", sep="\n")
# input_conv1d.shape=torch.Size([4, 32, 10])  -> [N, C_in, L]
# output_conv1d.shape=torch.Size([4, 10, 8])  -> [N, C_out, L - 3 + 1]
Conv2d
# nn.Conv2d slides a kernel over the last two axes of an (N, C_in, H, W)
# input. With the defaults used here (stride=1, padding=0, dilation=1) each
# spatial size shrinks by kernel_size - 1, so 10x10 -> 8x8 for kernel_size=3.
#
# Reference excerpt (torch.nn.Conv2d):
#     def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
#         if self.padding_mode != "zeros":
#             # Non-zero padding modes pad explicitly, then convolve with padding 0.
#             return F.conv2d(
#                 F.pad(
#                     input, self._reversed_padding_repeated_twice, mode=self.padding_mode
#                 ),
#                 weight, bias, self.stride, _pair(0), self.dilation, self.groups,
#             )
#         return F.conv2d(
#             input, weight, bias, self.stride, self.padding, self.dilation, self.groups
#         )
#
#     def forward(self, input: Tensor) -> Tensor:
#         return self._conv_forward(input, self.weight, self.bias)
conv2d = nn.Conv2d(num_features, num_classes, kernel_size=3)

# (N, C_in, H, W) with H = W = num_classes, reused only as a convenient size.
input_conv2d = torch.randn(num_samples, num_features, num_classes, num_classes)
output_conv2d = conv2d(input_conv2d)
print(f"{input_conv2d.shape=}", f"{output_conv2d.shape=}", sep="\n")
# input_conv2d.shape=torch.Size([4, 32, 10, 10])  -> [N, C_in, H, W]
# output_conv2d.shape=torch.Size([4, 10, 8, 8])   -> [N, C_out, H - 2, W - 2]
Dropout
# nn.Dropout forwards self.training to F.dropout. While training (the state
# of a freshly constructed module, as here), elements are zeroed with
# probability p and the survivors are rescaled by 1 / (1 - p). In eval mode
# the input passes through unchanged. The shape never changes.
#
# Reference excerpt (torch.nn.Dropout):
#     def forward(self, input: Tensor) -> Tensor:
#         return F.dropout(input, self.p, self.training, self.inplace)
dropout = nn.Dropout(0.5)

input_dropout = torch.randn(num_samples, num_features, num_classes)
output_dropout = dropout(input_dropout)
print(f"{input_dropout.shape=}", f"{output_dropout.shape=}", sep="\n")
# input_dropout.shape=torch.Size([4, 32, 10])
# output_dropout.shape=torch.Size([4, 32, 10])
BatchNorm
# nn.BatchNorm1d normalizes each of the num_features channels (axis 1 of an
# (N, C) or (N, C, L) input) and then applies the learnable weight and bias.
# The output shape equals the input shape.
#
# What _BatchNorm.forward (excerpt below) decides before calling F.batch_norm:
#   * Averaging factor for the running stats: momentum=None means a
#     cumulative moving average (1 / num_batches_tracked); otherwise it is
#     an exponential moving average with factor = momentum.
#   * Normalization statistics: mini-batch stats in training mode, and in
#     eval mode when running_mean/running_var are None; otherwise the
#     running buffers.
#   * Buffer updates: only in training mode with track_running_stats=True.
# The module below is freshly constructed and therefore in training mode, so
# this call also updates its running buffers.
#
# Reference excerpt (torch.nn.modules.batchnorm):
#     class BatchNorm1d(_BatchNorm):
#         ...
#
#     class _BatchNorm(_NormBase):
#
#         def forward(self, input: Tensor) -> Tensor:
#             self._check_input_dim(input)
#
#             # exponential_average_factor is set to self.momentum
#             # (when it is available) only so that it gets updated
#             # in ONNX graph when this node is exported to ONNX.
#             if self.momentum is None:
#                 exponential_average_factor = 0.0
#             else:
#                 exponential_average_factor = self.momentum
#
#             if self.training and self.track_running_stats:
#                 if self.num_batches_tracked is not None:
#                     self.num_batches_tracked.add_(1)
#                     if self.momentum is None:  # use cumulative moving average
#                         exponential_average_factor = 1.0 / float(self.num_batches_tracked)
#                     else:  # use exponential moving average
#                         exponential_average_factor = self.momentum
#
#             # Mini-batch stats are used in training mode, and in eval mode
#             # when buffers are None.
#             if self.training:
#                 bn_training = True
#             else:
#                 bn_training = (self.running_mean is None) and (self.running_var is None)
#
#             # Buffers are only passed when they should be updated (training
#             # mode with tracking) or used for normalization (eval mode).
#             return F.batch_norm(
#                 input,
#                 # If buffers are not to be tracked, ensure that they won't be updated
#                 (
#                     self.running_mean
#                     if not self.training or self.track_running_stats
#                     else None
#                 ),
#                 self.running_var if not self.training or self.track_running_stats else None,
#                 self.weight,
#                 self.bias,
#                 bn_training,
#                 exponential_average_factor,
#                 self.eps,
#             )
batchnorm = nn.BatchNorm1d(num_features)

input_batchnorm = torch.randn(num_samples, num_features, num_classes)
output_batchnorm = batchnorm(input_batchnorm)
print(f"{input_batchnorm.shape=}", f"{output_batchnorm.shape=}", sep="\n")
# input_batchnorm.shape=torch.Size([4, 32, 10])
# output_batchnorm.shape=torch.Size([4, 32, 10])
MNIST
The MNIST example (a small CNN classifier for 28x28 single-channel digit images):
class Net(nn.Module):
    """Small CNN classifier for 28x28 single-channel images (MNIST).

    Shape walk-through for an input of shape (N, 1, 28, 28):
        conv1 (3x3, stride 1) + ReLU  -> (N, 32, 26, 26)
        conv2 (3x3, stride 1) + ReLU  -> (N, 64, 24, 24)
        max_pool2d(2) + dropout1      -> (N, 64, 12, 12)
        flatten                       -> (N, 9216)   # 64 * 12 * 12
        fc1 + ReLU + dropout2         -> (N, 128)
        fc2 + log_softmax             -> (N, 10)

    The 9216 input size of fc1 ties the model to 28x28 inputs; other spatial
    sizes fail at fc1 with a shape mismatch.
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size, stride)
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Active only in training mode (see model.train() / model.eval()).
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 after two 3x3 convs and a 2x2 max-pool.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities of shape (N, 10) for input x."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # keep the batch axis, flatten C, H, W
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        # Log-probabilities over the 10 classes (rows exponentiate to sum 1).
        output = F.log_softmax(x, dim=1)
        return output
# Random initialization
model = Net().to(device)
# When training (for example, consider behaviour of dropout)
model.train()
# When testing
model.eval()
# Save checkpoint
torch.save(model.state_dict(), "mnist_cnn.pt")
# Load checkpoint TODO