I am new to deep learning, so I am sorry if my question does not make sense. I have a dataset and I would like to train a model on it, but how do I know the structure of the data? For example, for the MNIST dataset the input data structure will normally be [batch_size, channel, image_width, image_height] and the output [batch_size, 10], since we have 10 classes. How can I find this information for another dataset?
The shape of your data is usually defined by the layers, loss functions, etc. you are using.
For CNNs you will most likely use nn.Conv2d layers, which expect the data to have the shape [batch_size, channels, height, width].
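A minimal sketch of what this means in practice. The batch size, image size and channel counts below are arbitrary values chosen for illustration. The second part shows how a single image stored as [height, width] can be given the missing batch and channel dimensions.

```python
import torch
import torch.nn as nn

# Dummy batch of 8 RGB images of size 32x32 (shapes chosen for illustration only)
x = torch.randn(8, 3, 32, 32)  # [batch_size, channels, height, width]

# kernel_size=3 with padding=1 keeps height and width unchanged
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
out = conv(x)
print(out.shape)  # torch.Size([8, 16, 32, 32])

# A single grayscale image stored as [height, width] lacks the batch and channel dims
img = torch.randn(28, 28)
img = img.unsqueeze(0).unsqueeze(0)  # -> [1, 1, 28, 28]
print(img.shape)  # torch.Size([1, 1, 28, 28])
```

Only the channel dimension has to match in_channels of the first conv layer; the batch size and spatial size can vary, but they affect the shapes of the later layers.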
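The output of your model should match the input expected by the loss function. For a multi-class classification use case you will most likely use nn.CrossEntropyLoss or nn.NLLLoss, which both expect the prediction to have the shape [batch_size, nb_classes] and the target to have the shape [batch_size], containing the class indices.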
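A small sketch of the expected shapes, assuming 10 classes as in MNIST and a batch of 4 made-up targets. nn.CrossEntropyLoss takes raw logits, while nn.NLLLoss takes log-probabilities, so applying log_softmax first makes the two losses agree.

```python
import torch
import torch.nn as nn

batch_size, num_classes = 4, 10  # e.g. MNIST has 10 classes

logits = torch.randn(batch_size, num_classes)  # model output: [batch_size, num_classes]
target = torch.tensor([3, 0, 9, 1])            # class indices: [batch_size], dtype int64

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)
print(loss.shape)  # torch.Size([]), a scalar with the default reduction='mean'

# nn.NLLLoss expects log-probabilities with the same shapes
log_probs = torch.log_softmax(logits, dim=1)
loss_nll = nn.NLLLoss()(log_probs, target)
print(torch.allclose(loss, loss_nll))  # True
```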
Thank you @ptrblck, but my question is actually about the dataset itself, before it goes through any model: how can I find out the structure of a dataset if its official website does not provide any information about it?
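One way to answer this yourself is to load the data and look at a sample and a batch. The sketch below uses a TensorDataset with made-up tensors as a stand-in for the undocumented dataset (the 64x64 RGB images and the 5 classes are hypothetical); you would replace it with your own Dataset object.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for an undocumented dataset: 100 RGB 64x64 images, labels 0..4 (hypothetical)
images = torch.randn(100, 3, 64, 64)
labels = torch.arange(100) % 5
dataset = TensorDataset(images, labels)

# 1) Inspect a single sample: shape, dtype and label
sample, label = dataset[0]
print(len(dataset), sample.shape, sample.dtype, label)

# 2) Inspect one batch, i.e. what the model will actually receive
loader = DataLoader(dataset, batch_size=16, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([16, 3, 64, 64]) torch.Size([16])

# 3) Infer the number of classes (the size of the output layer) from the labels
num_classes = len(torch.unique(labels))
print(num_classes)  # 5
```

The batch shape tells you what the first layer must accept, and the number of distinct labels tells you the size of the last layer.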
Let me explain with a simple example.
```python
import torch.nn as nn


class CNN(nn.Module):
    def __init__(self, out_1=13, out_2=32):
        super(CNN, self).__init__()
        self.cnn1 = nn.Conv2d(in_channels=3, out_channels=out_1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.cnn2 = nn.Conv2d(in_channels=out_1, out_channels=out_2, kernel_size=5, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        # out_2 * 23 * 23 matches 3-channel inputs of e.g. 100x100 pixels
        self.fc1 = nn.Linear(out_2 * 23 * 23, 2)

    def forward(self, x):
        out = self.cnn1(x)
        out = self.relu1(out)
        out = self.maxpool1(out)
        out = self.cnn2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)
        print(out.shape)  # ** print the shape of the data at this point
        out = out.view(out.size(0), -1)  # flatten to [batch_size, features]
        out = self.fc1(out)
        return out

    def activations(self, x):
        # Return the intermediate pre-activations, activations and flattened output
        z1 = self.cnn1(x)
        a1 = self.relu1(z1)
        out = self.maxpool1(a1)
        z2 = self.cnn2(out)
        a2 = self.relu2(z2)
        out = self.maxpool2(a2)
        out = out.view(out.size(0), -1)
        return z1, a1, z2, a2, out
```
Suppose this is my model.
Now, to find out what data structure (shape) I am getting at a certain point of the model, I simply add print(out.shape) there. In the snippet above it is the line marked with two asterisks, right after the second max-pooling layer.
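The same shape can also be worked out by hand with the output-size formula from the nn.Conv2d and nn.MaxPool2d documentation (dilation 1, no ceil_mode). The post does not state the input size, so the sketch below assumes 100x100 images, one size for which the formula gives the 23 x 23 used in fc1. The helper names are made up for this example.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    # Output height/width of nn.Conv2d with dilation=1
    return (size + 2 * padding - kernel_size) // stride + 1


def maxpool2d_out(size, kernel_size):
    # nn.MaxPool2d defaults: stride=kernel_size, padding=0, ceil_mode=False
    return (size - kernel_size) // kernel_size + 1


def cnn_feature_size(size):
    # Follow one spatial dimension through the layers of the example CNN
    size = conv2d_out(size, kernel_size=3, padding=1)  # cnn1: size unchanged
    size = maxpool2d_out(size, kernel_size=2)          # maxpool1: halved
    size = conv2d_out(size, kernel_size=5)             # cnn2: shrinks by 4
    size = maxpool2d_out(size, kernel_size=2)          # maxpool2: halved
    return size


out_2 = 32
side = cnn_feature_size(100)
print(side, out_2 * side * side)  # 23 16928
```

The printed out.shape would then be [batch_size, 32, 23, 23], and flattening it gives the 16928 input features of fc1.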
To show you an example of the output, see the image below.
I hope this answers your question.