import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import pdb
class RGBMaskEncoderCNN(nn.Module):
def __init__(self):
super(RGBMaskEncoderCNN, self).__init__()
self.alexnet = models.alexnet(pretrained=True)
new_classifier = nn.Sequential(*list(self.alexnet.classifier.children())[:-1])
self.alexnet.classifier = new_classifier
# get the pre-trained weights of the first layer
pretrained_weights = self.alexnet.features[0].weight
new_features = nn.Sequential(*list(self.alexnet.features.children()))
new_features[0] = nn.Conv2d(4, 64, kernel_size=11, stride=4, padding=2)
# The weights for the M channel should be randomly initialized from a Gaussian distribution
new_features[0].weight.data.normal_(0, 0.001)
This file has been truncated. show original