import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class Generator(nn.Module):
    """Conditional DCGAN generator producing 32x32 images from noise and labels.

    The noise vector and a one-hot encoding of the class label are each
    projected to a 4x4 feature map by a transposed convolution. The two maps
    are concatenated along the channel axis and upsampled to 8x8, 16x16 and
    finally 32x32.

    Args:
        z: Length of the noise vector.
        channels: Number of output image channels.
        d: Base width; each first-stage branch uses ``d * 2`` channels.
        num_classes: Width of the one-hot label encoding. The default of 20
            matches the CIFAR-100 coarse labels (the previously hard-coded
            value); use 100 for the fine labels. Every label passed to
            ``forward`` must lie in ``[0, num_classes)``.
    """

    def __init__(self, z=100, channels=3, d=128, num_classes=20):
        super().__init__()
        self.z = z
        self.channels = channels
        self.num_classes = num_classes
        # Noise branch: (N, z, 1, 1) -> (N, 2d, 4, 4).
        self.deconv1_1 = nn.ConvTranspose2d(self.z, d * 2, 4, 1, 0)
        self.deconv1_1_bn = nn.BatchNorm2d(d * 2)
        # Label branch: (N, num_classes, 1, 1) -> (N, 2d, 4, 4).
        self.deconv1_2 = nn.ConvTranspose2d(num_classes, d * 2, 4, 1, 0)
        self.deconv1_2_bn = nn.BatchNorm2d(d * 2)
        # Both branches are concatenated, hence d * 4 input channels here.
        self.deconv2 = nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d * 2)
        self.deconv3 = nn.ConvTranspose2d(d * 2, d, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, self.channels, 4, 2, 1)

    def _one_hot(self, label):
        """Return a float (N, num_classes, 1, 1) one-hot encoding of ``label``.

        Raises:
            ValueError: if any label lies outside ``[0, num_classes)``.
                Without this check ``scatter_`` fails with an opaque
                "Invalid index in scatter" error. That is what happens when
                fine CIFAR-100 labels (0-99) reach a 20-class generator.
        """
        # scatter_ requires an int64 index tensor.
        label = label.reshape(-1).long()
        if label.numel() > 0:
            lo = int(label.min())
            hi = int(label.max())
            if lo < 0 or hi >= self.num_classes:
                raise ValueError(
                    "labels must lie in [0, {}), got values in [{}, {}]; "
                    "check that num_classes matches the label set "
                    "(20 for coarse CIFAR-100 labels, 100 for fine)".format(
                        self.num_classes, lo, hi
                    )
                )
        one_hot = torch.zeros(label.size(0), self.num_classes, device=label.device)
        one_hot.scatter_(1, label.unsqueeze(1), 1.0)
        return one_hot.view(-1, self.num_classes, 1, 1)

    def forward(self, input, label):
        """Generate images conditioned on ``label``.

        Args:
            input: Noise tensor with N rows of ``z`` values; it is reshaped
                to (N, z, 1, 1).
            label: Integer class indices, one per sample, each in
                ``[0, num_classes)``.

        Returns:
            ``(images, features)``: ``images`` is the tanh output with
            ``channels`` channels; ``features`` is the activation fed to the
            final transposed convolution.
        """
        label = self._one_hot(label)
        input = input.view(input.size(0), -1, 1, 1)
        x = F.relu(self.deconv1_1_bn(self.deconv1_1(input)))
        y = F.relu(self.deconv1_2_bn(self.deconv1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))
        x = F.relu(self.deconv3_bn(self.deconv3(x)))
        features = x
        x = torch.tanh(self.deconv4(x))
        return x, features
This is part of the generator for CIFAR-100. In the code above I set the number of classes to 20 because I am training only on the coarse labels (the 20 superclasses) instead of the fine labels. When using the fine labels, the number of classes in this piece of code is 100:
# Variant of the Generator layers for the 100 fine CIFAR-100 labels: the
# label branch input and the one-hot encoding are 100 wide instead of 20.
# NOTE(review): `d2` / `d4` presumably stand for `d*2` / `d*4` (the asterisks
# look lost in formatting); as written these names are undefined - confirm.
self.deconv1_2 = nn.ConvTranspose2d(100, d2, 4, 1, 0)
self.deconv1_2_bn = nn.BatchNorm2d(d2)
self.deconv2 = nn.ConvTranspose2d(d4, d2, 4, 2, 1)
self.deconv2_bn = nn.BatchNorm2d(d2)
self.deconv3 = nn.ConvTranspose2d(d2, d, 4, 2, 1)
self.deconv3_bn = nn.BatchNorm2d(d)
self.deconv4 = nn.ConvTranspose2d(d, self.channels, 4, 2, 1)
def forward(self, input, label):
# Zero-filled float tensor with label.size(0) rows and 100 columns,
# filled below as a one-hot encoding of the labels.
label_oh = Variable(label.data.new(label.size(0), 100).float().zero_())
label = label.unsqueeze(1)
# scatter_ fails with "Invalid index in scatter" when a label value is
# negative or >= the one-hot width (100 here, 20 in the coarse version).
# Check the label range the data loader actually yields.
label_oh.scatter_(1, label, 1)
Running it gives me the following error, and I can't figure out what I need to change:
label_oh.scatter_(1, label, 1)
RuntimeError: Invalid index in scatter at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:151