ResNet modifications to output shape [batch, timestep, num_outputs]

Hi,

I’m looking to use ResNet for a multivariate regression problem. In itself this seems fairly straightforward, as the num_classes argument just needs to be set to the number of output variables to be predicted.

However, I am using audio data and want to make a prediction for each timestep/frame, so my output needs to be of shape [batch, timestep, num_outputs] rather than ResNet’s default output of [batch, num_outputs].

I’m using the ResNet example code provided in the PyTorch docs (the torchvision ResNet implementation).

I’ve adapted the code below; at the start of each code block I list the original defaults and what I’ve changed them to. Can anyone see a more sensible / less disruptive way of getting my desired output shape?

In conv3x3, I’ve changed the padding from dilation → ‘same’:

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution whose padding preserves the input size at stride 1.

    For a 3x3 kernel, padding ``dilation`` pixels on each side is exactly
    what ``padding='same'`` computes when ``stride == 1``. Unlike ``'same'``,
    it is also accepted for strided convolutions. PyTorch raises
    ``ValueError`` for ``padding='same'`` combined with ``stride > 1``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of blocked connections from input to output channels.
        dilation: spacing between kernel elements; also used as the padding.

    Returns:
        A bias-free ``nn.Conv2d`` with a 3x3 kernel.
    """
    # padding=dilation == 'same' for a 3x3 kernel at stride 1, and it stays
    # valid if a caller ever requests a strided conv.
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

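In the ResNet constructor I’ve changed the following:
all strides from 2 → 1 (and the conv1 padding from 3 → ‘same’)
nn.MaxPool2d kernel size from 3 → (1,3), stride from 2 → None (i.e. equal to the kernel size), and padding from 1 → 0
fc input size from 512*block.expansion → whatever I need it to be (here 21504)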

class ResNet(nn.Module):
    """ResNet backbone modified to keep the time axis for per-frame outputs.

    Compared with the stock torchvision constructor:

    - ``conv1`` and every ``_make_layer`` call use stride 1.
    - ``conv1`` uses ``padding='same'``.
    - The stem max-pool has kernel ``(1, 3)``, so it only pools along the
      last input axis.

    Args:
        block: residual block class passed through to ``_make_layer``.
        layers: number of blocks in each of the four stages.
        num_classes: size of the output vector produced by ``self.fc``.
        zero_init_residual: accepted for signature compatibility.
            NOTE(review): it is not used anywhere in this constructor.
        groups: stored as ``self.groups``, presumably consumed by
            ``_make_layer``.
        width_per_group: stored as ``self.base_width``.
        replace_stride_with_dilation: ``None`` or a 3-element sequence of
            flags forwarded as ``dilate`` to stages 2-4.
        norm_layer: normalization layer factory. Defaults to
            ``nn.BatchNorm2d``.
        in_channels: channel count of the input tensor. Defaults to 3, as
            before; e.g. use 1 for a single-channel spectrogram.
        fc_in_features: length of the per-frame feature vector fed to
            ``self.fc``. Defaults to the previously hard-coded 21504. Pass
            ``None`` to use ``nn.LazyLinear``, which infers the size on the
            first forward pass instead of requiring it to be computed by
            hand for each input resolution.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` is not ``None`` and
            does not have exactly 3 elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, in_channels=3, fc_in_features=21504):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            # NOTE(review): with every stage already at stride 1 these flags
            # presumably have no effect -- verify against _make_layer.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stride 1 + 'same' padding: the stem conv keeps both input axes' sizes.
        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=1,
                               padding='same', bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # stride=None means stride == kernel_size == (1, 3): the second-to-last
        # axis is untouched and the last axis of size F shrinks to F // 3.
        self.maxpool = nn.MaxPool2d(kernel_size=(1, 3), stride=None, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilate=replace_stride_with_dilation[2])
        # Not needed when pooling is skipped before fc; kept (it has no
        # parameters) so code referring to model.avgpool keeps working.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Per-frame projection. In torchvision the input size was
        # 512 * block.expansion (after global average pooling).
        if fc_in_features is None:
            self.fc = nn.LazyLinear(num_classes)
        else:
            self.fc = nn.Linear(fc_in_features, num_classes)

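Changes made in _forward_impl:
Permuted the tensor after self.layer4 so the time dimension directly follows the batch dimension in the output
Removed the average pooling before self.fc
Flattened starting from dim=2 to keep the timestep dimension intact, so self.fc is applied to each timestep’s features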

 def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x) #(batch, channel, time, freq)

        x = torch.permute(x,(0,2,1,3)) # (batch,time,filt_channel,freq)

        # x = self.avgpool(x)

        x = torch.flatten(x, start_dim=2)
        x = self.fc(x)

        return x