Help me understand how minibatching works.
import torch
import torch.nn as nn
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(50, 1, kernel_size=(1, 32), stride=(1, 1))
        self.bn1 = nn.BatchNorm2d(50, affine=True)
        # note: nn.Linear's signature is (in_features, out_features, bias),
        # so the third number in each call below ends up as the bias argument
        self.fc1 = nn.Linear(50, 64, 32)
        self.bn2 = nn.BatchNorm2d(50, affine=True)
        self.fc2 = nn.Linear(50, 32, 16)
        self.bn3 = nn.BatchNorm2d(50, affine=True)
        self.fc3 = nn.Linear(50, 16, 1)
        self.tan = nn.Hardtanh()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.fc1(out)   # nn.Linear acts on the last dimension of the 4-D tensor
        out = self.bn2(out)
        out = self.fc2(out)
        out = self.bn3(out)
        out = self.fc3(out)
        return self.tan(out)

net = Net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)  # learning_rate is defined earlier (not shown)
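As far as I understand, nn.Conv2d expects its input shaped (N, C_in, H, W), where N is the minibatch size and C_in must match the layer's in_channels (50 for conv1 above), while nn.Linear only acts on the last dimension. A minimal standalone shape check, using sizes I made up just to show the convention:

# Standalone shape check (hypothetical sizes, separate from my script).
import torch
import torch.nn as nn

conv = nn.Conv2d(50, 1, kernel_size=(1, 32))   # in_channels=50, out_channels=1
x = torch.randn(8, 50, 64, 32)                 # (N=8 samples, C=50 channels, H=64, W=32)
print(conv(x).shape)                           # torch.Size([8, 1, 64, 1]) -- N is carried through

fc = nn.Linear(32, 16)
y = torch.randn(8, 32)                         # (N=8 samples, 32 features)
print(fc(y).shape)                             # torch.Size([8, 16])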
st_1 = np.interp(np.loadtxt('1d.txt', delimiter=';'), [0, 10], [-1, 1])
st_2 = np.loadtxt('target_1d.txt', delimiter=';')
st_3 = np.loadtxt('My1d.txt', delimiter=';')

len_st = len(st_2)
batch = 50                       # minibatch size
b = (len_st - wn1) // batch      # number of full minibatches (wn1 = window length, defined earlier, not shown)
len_batch = b * batch            # number of window start positions that fit into full minibatches
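To make that arithmetic concrete (with a made-up data length, since the real files are not shown here): if len_st were 10000 and wn1 were 32, then b = (10000 - 32) // 50 = 199 full minibatches, len_batch = 199 * 50 = 9950, and the last few window positions that do not fill a whole batch are simply dropped.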
I create each minibatch on the fly inside the training loop (a sketch of the more usual DataLoader approach follows after the loop).
for epoch in range(epochs):                        # epochs is defined earlier (not shown)
    for wn_start in range(0, len_batch, batch):    # step over the data one batch at a time
        wn_tick = wn_start + wn1
        wn_all = []
        los_l = []
        for b_iter in range(batch):                # collect `batch` windows and their targets
            wn_all = wn_all + [st_1[wn_start + b_iter:wn_tick + b_iter, :]]
            los_l = los_l + [st_2[wn_tick - 1]]
        wn_all = torch.as_tensor(wn_all, dtype=torch.float32)
        wn_all = wn_all.unsqueeze(0)
        wn_all = torch.transpose(wn_all, 2, 3)     # [1, 50, 32, 64] -> [1, 50, 64, 32]
        wn_all = torch.transpose(wn_all, 0, 1)     # [1, 50, 64, 32] -> [50, 1, 64, 32]
        los_l = torch.Tensor([los_l]).unsqueeze(0).unsqueeze(0)
        los_l = torch.transpose(los_l, 0, 3)       # [1, 1, 1, 50] -> [50, 1, 1, 1]
        outputs = net(wn_all)
        loss1 = criterion(los_l, outputs[0, 0, 0, 0])
        optimizer.zero_grad()                      # zero the gradients
        loss1.backward()
        optimizer.step()
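For comparison, here is a rough sketch of how I believe the same minibatching is usually done with TensorDataset and DataLoader instead of slicing by hand. The names windows and targets are placeholders I introduce only for the example, the target indexing is simplified, and whether my network's forward pass accepts these shapes is a separate question (that is exactly the part that is not working):

# Rough DataLoader-based sketch (illustrative names and shapes, not my exact variables).
from torch.utils.data import TensorDataset, DataLoader

windows = np.stack([st_1[i:i + wn1, :] for i in range(len_batch)])   # (len_batch, wn1, features)
targets = st_2[wn1 - 1:wn1 - 1 + len_batch]                          # one target per window

dataset = TensorDataset(torch.as_tensor(windows, dtype=torch.float32),
                        torch.as_tensor(targets, dtype=torch.float32))
loader = DataLoader(dataset, batch_size=batch, shuffle=False)

for epoch in range(epochs):
    for xb, yb in loader:              # xb: (batch, wn1, features), yb: (batch,)
        xb = xb.unsqueeze(1)           # add a channel dimension -> (batch, 1, wn1, features)
        outputs = net(xb)
        loss = criterion(outputs.view(-1), yb)
        optimizer.zero_grad()          # zero the gradients
        loss.backward()
        optimizer.step()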
Look at my code; it's probably not perfect)), but it still doesn't work.
The code worked when I did not use batching, but training took a very long time because there is a lot of data. I decided to use minibatches to feed the data into the network and process the samples in parallel.
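My understanding of the "parallel" part, sketched with a tiny standalone example (a hypothetical nn.Linear, not my actual network): applying a layer to a whole minibatch at once gives the same per-sample results as looping over the samples one by one, but in a single vectorized call, which is where the speed-up comes from.

# Standalone illustration of batched vs. per-sample processing (hypothetical layer).
import torch
import torch.nn as nn

layer = nn.Linear(32, 1)
batch_x = torch.randn(50, 32)                            # 50 samples, 32 features each

out_batched = layer(batch_x)                             # one call over the whole batch -> (50, 1)
out_looped = torch.stack([layer(x) for x in batch_x])    # 50 separate calls -> (50, 1)

print(torch.allclose(out_batched, out_looped))           # True (up to floating-point rounding)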