Skip to content

Model

Models are almost always subclasses of torch.nn.Module, PyTorch's base class for neural-network modules.

Common layers

Let us inspect some common model layers.

Configuration shared by the examples below:

# Sizes shared by every layer example on this page:
#   num_samples  - batch size (first axis of every input tensor)
#   num_features - number of input features / channels fed to each layer
#   num_classes  - number of output features / channels produced
num_samples, num_features, num_classes = 4, 32, 10

Linear

# Fully connected layer: maps each sample's num_features values to num_classes values.
linear = nn.Linear(num_features, num_classes)

# Excerpt of nn.Linear's forward (PyTorch source), for reference:
# def forward(self, input: Tensor) -> Tensor:
#     return F.linear(input, self.weight, self.bias)

input_linear = torch.randn(num_samples, num_features)
output_linear = linear(input_linear)

print(f"{input_linear.shape=}")
print(f"{output_linear.shape=}")
# input_linear.shape=torch.Size([4, 32]) -> [num_samples, num_features]
# output_linear.shape=torch.Size([4, 10]) -> [num_samples, num_classes]

Conv1d

# 1-D convolution: num_features input channels -> num_classes output channels,
# sliding a kernel of width 3 along the last (length) axis.
conv1d = nn.Conv1d(
    in_channels=num_features,
    out_channels=num_classes,
    kernel_size=3,
)

# Excerpt of nn.Conv1d's forward path (PyTorch source), for reference:
# def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
#     if self.padding_mode != "zeros":
#         return F.conv1d(
#             F.pad(
#                 input, self._reversed_padding_repeated_twice, mode=self.padding_mode
#             ),
#             weight,
#             bias,
#             self.stride,
#             _single(0),
#             self.dilation,
#             self.groups,
#         )
#
#     return F.conv1d(
#         input, weight, bias, self.stride, self.padding, self.dilation, self.groups
#     )
#
# def forward(self, input: Tensor) -> Tensor:
#     return self._conv_forward(input, self.weight, self.bias)

# Conv1d input is [batch, channels, length]. The length axis is unrelated to
# num_classes, so it gets its own name rather than reusing num_classes.
seq_len = 10
input_conv1d = torch.randn(num_samples, num_features, seq_len)
output_conv1d = conv1d(input_conv1d)
print(f"{input_conv1d.shape=}")
print(f"{output_conv1d.shape=}")
# input_conv1d.shape=torch.Size([4, 32, 10])  -> [num_samples, num_features, seq_len]
# output_conv1d.shape=torch.Size([4, 10, 8])  -> [num_samples, num_classes, seq_len - 3 + 1]
# (default stride 1 and no padding: a width-3 kernel fits 10 - 3 + 1 = 8 positions)

Conv2d

# 2-D convolution: num_features input channels -> num_classes output channels,
# sliding a 3x3 kernel over the last two (height, width) axes.
conv2d = nn.Conv2d(
    in_channels=num_features,
    out_channels=num_classes,
    kernel_size=3,
)

# Excerpt of nn.Conv2d's forward path (PyTorch source), for reference:
# def _conv_forward(self, input: Tensor, weight: Tensor, bias: Tensor | None):
#     if self.padding_mode != "zeros":
#         return F.conv2d(
#             F.pad(
#                 input, self._reversed_padding_repeated_twice, mode=self.padding_mode
#             ),
#             weight,
#             bias,
#             self.stride,
#             _pair(0),
#             self.dilation,
#             self.groups,
#         )
#
#     return F.conv2d(
#         input, weight, bias, self.stride, self.padding, self.dilation, self.groups
#     )
#
# def forward(self, input: Tensor) -> Tensor:
#     return self._conv_forward(input, self.weight, self.bias)

# Conv2d input is [batch, channels, height, width]. The spatial size is
# unrelated to num_classes, so it gets its own names rather than reusing num_classes.
height = width = 10
input_conv2d = torch.randn(num_samples, num_features, height, width)
output_conv2d = conv2d(input_conv2d)
print(f"{input_conv2d.shape=}")
print(f"{output_conv2d.shape=}")
# input_conv2d.shape=torch.Size([4, 32, 10, 10])  -> [num_samples, num_features, height, width]
# output_conv2d.shape=torch.Size([4, 10, 8, 8])   -> [num_samples, num_classes, height - 2, width - 2]
# (default stride 1 and no padding: a 3x3 kernel fits 10 - 3 + 1 = 8 positions per axis)

Dropout

# Dropout with drop probability p = 0.5. It never changes the shape of its input.
dropout = nn.Dropout(p=0.5)

# Excerpt of nn.Dropout's forward (PyTorch source), for reference. Note that
# self.training is passed through, so the result depends on train/eval mode:
# def forward(self, input: Tensor) -> Tensor:
#     return F.dropout(input, self.p, self.training, self.inplace)

# Dropout accepts any shape; use an explicit length axis instead of reusing
# num_classes, which is unrelated to this tensor.
seq_len = 10
input_dropout = torch.randn(num_samples, num_features, seq_len)
# A freshly constructed module is in training mode, so this call randomly
# zeroes entries (the survivors are rescaled). After dropout.eval() it would
# return the input unchanged.
output_dropout = dropout(input_dropout)
print(f"{input_dropout.shape=}")
print(f"{output_dropout.shape=}")
# input_dropout.shape=torch.Size([4, 32, 10])
# output_dropout.shape=torch.Size([4, 32, 10])

BatchNorm

# Batch normalization over the channel axis (num_features channels).
batchnorm = nn.BatchNorm1d(num_features=num_features)

# Excerpt of the BatchNorm forward path (PyTorch source), for reference:
# class BatchNorm1d(_BatchNorm):
#     ...
#
# class _BatchNorm(_NormBase):
#
# def forward(self, input: Tensor) -> Tensor:
#     self._check_input_dim(input)
#
#     # exponential_average_factor is set to self.momentum
#     # (when it is available) only so that it gets updated
#     # in ONNX graph when this node is exported to ONNX.
#     if self.momentum is None:
#         exponential_average_factor = 0.0
#     else:
#         exponential_average_factor = self.momentum
#
#     if self.training and self.track_running_stats:
#         # TODO: if statement only here to tell the jit to skip emitting this when it is None
#         if self.num_batches_tracked is not None:  # type: ignore[has-type]
#             self.num_batches_tracked.add_(1)  # type: ignore[has-type]
#             if self.momentum is None:  # use cumulative moving average
#                 exponential_average_factor = 1.0 / float(self.num_batches_tracked)
#             else:  # use exponential moving average
#                 exponential_average_factor = self.momentum
#
#     r"""
#     Decide whether the mini-batch stats should be used for normalization rather than the buffers.
#     Mini-batch stats are used in training mode, and in eval mode when buffers are None.
#     """
#     if self.training:
#         bn_training = True
#     else:
#         bn_training = (self.running_mean is None) and (self.running_var is None)
#
#     r"""
#     Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
#     passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
#     used for normalization (i.e. in eval mode when buffers are not None).
#     """
#     return F.batch_norm(
#         input,
#         # If buffers are not to be tracked, ensure that they won't be updated
#         (
#             self.running_mean
#             if not self.training or self.track_running_stats
#             else None
#         ),
#         self.running_var if not self.training or self.track_running_stats else None,
#         self.weight,
#         self.bias,
#         bn_training,
#         exponential_average_factor,
#         self.eps,
#     )

# BatchNorm1d input here is [batch, channels, length]. The length axis is
# unrelated to num_classes, so it gets its own name rather than reusing num_classes.
seq_len = 10
input_batchnorm = torch.randn(num_samples, num_features, seq_len)
# The module starts in training mode, so (per the excerpt above) this call
# normalizes with the mini-batch statistics. When running stats are tracked,
# it also updates running_mean, running_var and num_batches_tracked.
output_batchnorm = batchnorm(input_batchnorm)
print(f"{input_batchnorm.shape=}")
print(f"{output_batchnorm.shape=}")
# input_batchnorm.shape=torch.Size([4, 32, 10])
# output_batchnorm.shape=torch.Size([4, 32, 10])

MNIST

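A small convolutional network for classifying MNIST digits (28×28 grayscale images), followed by the basic train/eval/save workflow: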

class Net(nn.Module):
    """Small CNN classifier for single-channel 28x28 images such as MNIST digits.

    Shape flow for an input batch of shape [N, 1, 28, 28]:
        conv1 (3x3, stride 1) -> [N, 32, 26, 26]
        conv2 (3x3, stride 1) -> [N, 64, 24, 24]
        max_pool2d(2)         -> [N, 64, 12, 12]
        flatten               -> [N, 9216]   (64 * 12 * 12, hence fc1's input size)
        fc1                   -> [N, 128]
        fc2                   -> [N, 10]

    forward() returns log-probabilities over the 10 classes (log_softmax on dim 1).
    """

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable so saved
        # checkpoints remain loadable.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 after the two convs and 2x2 max-pool on 28x28 input.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a [N, 1, 28, 28] batch to [N, 10] class log-probabilities."""
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        # Dropout layers are only active in training mode (model.train()).
        x = self.dropout1(x)
        # Keep the batch axis; collapse channels and spatial dims into one vector.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

# Use the GPU when one is available, otherwise fall back to the CPU.
# (Without this, the .to(device) call below raises NameError.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Random initialization
model = Net().to(device)

# Before training: enables training-mode behaviour (e.g. dropout is active)
model.train()

# Before testing/inference: disables dropout so outputs are deterministic
model.eval()

# Save checkpoint: only the parameters/buffers (state_dict), not the class itself
torch.save(model.state_dict(), "mnist_cnn.pt")

# Load checkpoint into a freshly constructed model of the same architecture.
# map_location places tensors on the current device even if they were saved
# from another one. weights_only=True restricts unpickling to tensors and
# plain containers, so a tampered checkpoint file cannot run arbitrary code.
state_dict = torch.load("mnist_cnn.pt", map_location=device, weights_only=True)
restored_model = Net().to(device)
restored_model.load_state_dict(state_dict)
restored_model.eval()