En el tutorial básico de PyTorch hay un comentario en el código de una red de ejemplo que me tiene confundido:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 6 * 6, 120) # 6*6 from image dimension
# ...
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
# 16 * 6 * 6 is used here (after flattening)
No estoy seguro de dónde está el $6$ viene de. Más adelante en el tutorial se supone que las imágenes son $32\times 32$ Así que, ¿es simplemente
$$\left \lfloor{(\left \lfloor{(32-2)/2}\right \rfloor-2)/2}\right \rfloor = 6$$
o ¿hay algo que se me escapa sobre cómo hace PyTorch el relleno (o algún otro malentendido de principiante)?