This repository was archived by the owner on Apr 1, 2021. It is now read-only.

Add modules for easily constructing residual networks #33

Open · wants to merge 1 commit into base: master

Conversation

crowsonkb
Contributor

Here are two modules that implement residual blocks in a clean, object-oriented way. They subclass the container modules and accept modules as positional arguments, similarly to nn.Sequential. Because they subclass the container modules, they also pretty-print nicely.

The ResidualBlock module takes a sequence of modules and adds a 'shortcut connection' between its input and its output. In other words, its final output is the sum of its original input and the output of its last module. I also provide a ResidualBlockWithShortcut module, which lets you customize the shortcut connection, for instance to ensure it has the same shape as the output of the main branch.
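A minimal sketch of how such modules might look (the names `ResidualBlock` and `ResidualBlockWithShortcut` come from the PR description; the actual implementation details here are an assumption, not the PR's code):

```python
import torch
from torch import nn


class ResidualBlock(nn.Sequential):
    """A stack of modules whose final output is the sum of the stack's
    output and the block's original input (an identity shortcut)."""

    def forward(self, input):
        return input + super().forward(input)


class ResidualBlockWithShortcut(nn.Sequential):
    """A residual block with a customizable shortcut branch, e.g. a 1x1
    convolution to match the main branch's output shape. (Sketch only:
    the shortcut is registered as a child, so the main-branch loop has
    to skip it explicitly.)"""

    def __init__(self, *args, shortcut=None):
        super().__init__(*args)
        self.shortcut = shortcut if shortcut is not None else nn.Identity()

    def forward(self, input):
        main = input
        for name, module in self.named_children():
            if name != 'shortcut':
                main = module(main)
        return main + self.shortcut(input)
```

With an identity shortcut, `ResidualBlock(nn.Identity())(x)` simply returns `2 * x`, which is an easy sanity check.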

No more convoluted, easily-gotten-wrong network topology hand-written across separately defined __init__() and forward() methods when all you want is a standard ResNet! The code looks like this:

model = nn.Sequential(
    nn.Conv2d(1, 10, 1),
    ResidualBlock(
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
    ),
    nn.MaxPool2d(2),
    ResidualBlock(
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3, padding=1),
    ),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(7*7*10, 10),
    nn.LogSoftmax(dim=-1),
)

and the model looks like this when printed:

Sequential(
  (0): Conv2d(1, 10, kernel_size=(1, 1), stride=(1, 1))
  (1): ResidualBlock(
    (0): ReLU()
    (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): ResidualBlock(
    (0): ReLU()
    (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU()
    (3): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Flatten()
  (6): Linear(in_features=490, out_features=10, bias=True)
  (7): LogSoftmax()
)

The references for residual blocks are: Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385) and "Identity Mappings in Deep Residual Networks" (https://arxiv.org/abs/1603.05027).
