Hi there! I’m trying to understand how nn.RNN(input_size, hidden_size, num_layers, batch_first=True) works for classification problems (not prediction). Let me explain:
I have a dataset composed of 1500 samples, each a vector of size (1, 800). Each vector has one target associated with it, corresponding to the position where the sample was taken.
Using this I created the dataset with a custom class like this:
```python
class WifiDataset(Dataset):
    def __init__(self, data, labels=None, transforms=None):
        self.X = data
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        data = self.X[i]
        if self.transforms:
            data = self.transforms(data)
        if self.y is not None:
            return data, self.y[i]
        return data
```
And then I create the dataloader as:
```python
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
```
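In case it helps, here is a self-contained sketch of that pipeline with random placeholder data standing in for my real samples (the 10 classes are made up; I use 800 features as described above):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Same dataset class as above
class WifiDataset(Dataset):
    def __init__(self, data, labels=None, transforms=None):
        self.X = data
        self.y = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        data = self.X[i]
        if self.transforms:
            data = self.transforms(data)
        if self.y is not None:
            return data, self.y[i]
        return data

# Random stand-ins for the real data: 1500 samples, 800 features, 10 classes
X = torch.randn(1500, 800)
y = torch.randint(0, 10, (1500,))

train_set = WifiDataset(X, y)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, drop_last=True)

data, target = next(iter(train_loader))
print(data.shape, target.shape)  # torch.Size([32, 800]) torch.Size([32])
```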
So for each `data, target` in `train_loader`, `data.shape` is `(batch_size, 806)` and `target.shape` is `()`. Then I feed an RNN declared as:
```python
class WifiRNN(nn.Module):
    def __init__(self, i_size, h_size, n_layers, num_classes):
        super(WifiRNN, self).__init__()
        self.input_size = i_size
        self.hidden_size = h_size
        self.num_layers = n_layers
        self.num_classes = num_classes
        self.wifi_rnn = nn.RNN(input_size=self.input_size,
                               hidden_size=self.hidden_size,
                               num_layers=self.num_layers,
                               batch_first=True)
        # sequence_length is a global defined elsewhere in my script
        self.out = nn.Linear(in_features=self.hidden_size * sequence_length,
                             out_features=self.num_classes)

    def forward(self, x_in, h_state):
        r_out, _ = self.wifi_rnn(x_in, h_state)
        # flatten the outputs of all timesteps before the classifier
        r_out = r_out.reshape(r_out.shape[0], -1)
        out = self.out(r_out)
        return out

    def init_hidden_state(self, b_size):
        h0 = torch.zeros(self.num_layers, b_size, self.hidden_size).to(device)
        return h0
```
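To make the shapes concrete, this is roughly how I call the model: I reshape the flat `(batch, 800)` vector into `(batch, seq_len, input_size)` so that `nn.RNN` actually sees timesteps. The split into 10 steps of 80 features, and all the sizes here, are arbitrary placeholder choices on my part (the class is a compact version of the one above, with the reshape written out):

```python
import torch
import torch.nn as nn

device = torch.device("cpu")
sequence_length = 10   # arbitrary: split 800 features into 10 steps of 80
input_size = 80
hidden_size = 64
num_layers = 2
num_classes = 10       # made-up number of positions
batch_size = 32

class WifiRNN(nn.Module):
    def __init__(self, i_size, h_size, n_layers, n_classes):
        super().__init__()
        self.num_layers = n_layers
        self.hidden_size = h_size
        self.wifi_rnn = nn.RNN(input_size=i_size, hidden_size=h_size,
                               num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(h_size * sequence_length, n_classes)

    def forward(self, x_in, h_state):
        r_out, _ = self.wifi_rnn(x_in, h_state)           # (batch, seq, hidden)
        r_out = r_out.reshape(r_out.shape[0], -1)         # (batch, seq * hidden)
        return self.out(r_out)                            # (batch, n_classes)

    def init_hidden_state(self, b_size):
        return torch.zeros(self.num_layers, b_size, self.hidden_size).to(device)

model = WifiRNN(input_size, hidden_size, num_layers, num_classes)
x = torch.randn(batch_size, 800).reshape(batch_size, sequence_length, input_size)
h0 = model.init_hidden_state(batch_size)
out = model(x, h0)
print(out.shape)  # torch.Size([32, 10])
```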
Am I taking advantage of the RNN's hidden state? If so, am I doing it between batches, or how does it work?
Do I need to feed the network with different timesteps to exploit the hidden state information?
(The same questions apply to LSTMs, since I know RNNs suffer from the vanishing gradient problem.)
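For reference, my understanding of the LSTM variant of the same setup (all sizes here are placeholders; the main difference is that the hidden state becomes a `(h0, c0)` tuple):

```python
import torch
import torch.nn as nn

batch_size, seq_len, input_size, hidden_size, num_layers = 32, 10, 80, 64, 2

lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
               num_layers=num_layers, batch_first=True)

x = torch.randn(batch_size, seq_len, input_size)
# Unlike nn.RNN, nn.LSTM takes (and returns) a tuple of hidden and cell state
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape, hn.shape, cn.shape)
```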