Hi, I use keras(backend is tensorflow) before, and I want to learn pytorch.
I load data set from keras, and run forward one time.
Then the memory is not enough, but I successful if I use keras.
How to solve this problem, and why? Here is my code.
################################################################################
##
##
## Package
import PIL.Image
import os
import pandas
import numpy
import torch
import torch.utils.data
import sklearn
import shutil
import timeit
from sklearn.model_selection import ParameterGrid
from PIL.Image import open as ReadImage
from sklearn.metrics import accuracy_score
from keras.datasets import cifar10
import torch
from torch.nn import *
from torch.optim import *
class I2C1FO(Module):
def __init__(self):
super(I2C1FO, self).__init__()
self.The1stConv = Conv2d(in_channels= 3, out_channels=32, kernel_size=(3,3), stride=(1,1), padding=1)
self.The2ndConv = Conv2d(in_channels=32, out_channels=96, kernel_size=(3,3), stride=(1,1), padding=1)
self.The1stFully = Linear(in_features=32*32*96, out_features=1000, bias=True)
self.The2ndFully = Linear(in_features= 1000, out_features= 10, bias=True)
def forward(self, ImageTerm):
N = ImageTerm.size()[0]
Pip = self.The1stConv(ImageTerm)
Pip = self.The2ndConv(Pip)
Pip = Pip.view(N, -1)
Pip = self.The1stFully(Pip)
Output = self.The2ndFully(Pip)
return(Output)
Train = {}
Valid = {}
(Train["Image"], Train["Label"]), (Valid["Image"], Valid["Label"]) = cifar10.load_data()
##
##
## Train
Reshape = (Train["Image"].shape[0], Train["Image"].shape[3]) + Train["Image"].shape[1:3]
Train["ImageTerm"] = torch.from_numpy(Train["Image"].reshape(Reshape) / 255).type(torch.FloatTensor)
Train["LabelCode"] = torch.from_numpy(Train["Label"]).type("torch.LongTensor").view(-1)
##
##
## Valid
Reshape = (Valid["Image"].shape[0], Valid["Image"].shape[3]) + Valid["Image"].shape[1:3]
Valid["ImageTerm"] = torch.from_numpy(Valid["Image"].reshape(Reshape) / 255).type(torch.FloatTensor)
Valid["LabelCode"] = torch.from_numpy(Valid["Label"]).type("torch.LongTensor").view(-1)
##
##
## Forward
with torch.no_grad():
Model = I2C1FO()
Model(Train["ImageTerm"])