from torch import nn
# Build the model
class CIFAR_10(nn.Module):
    """Two-block convolutional image classifier.

    Architecture::

        Conv_Baock_1: Conv3x3 -> ReLU -> Conv3x3 -> ReLU -> MaxPool2x2
        Conv_Baock_2: Conv3x3 -> ReLU -> Conv3x3 -> ReLU
        classLiner:   Flatten -> Linear

    Every convolution uses stride 1 and no padding, so each one trims 2 pixels
    from the height and the width of its input.

    Args:
        int_shape: Number of input channels (C of an (N, C, H, W) batch).
        heddin_shape: Number of channels produced by every conv layer.
        out_shape: Number of output features (logits) of the final layer.
        image_size: Height/width of the square input images; used to size the
            final Linear layer. Defaults to 32. NOTE(review): 32 is assumed from
            the class name (CIFAR-10 is 32x32); confirm against the data loader.

    Raises:
        ValueError: If ``image_size`` is too small to survive both conv blocks.
    """

    def __init__(
        self,
        int_shape: int,
        heddin_shape: int,
        out_shape: int,
        image_size: int = 32,
    ):
        super().__init__()

        # Spatial size after the conv stack. Two unpadded 3x3 convs remove 4 px.
        # The 2x2 pool (stride 2, floor) halves the result. Two more convs
        # remove another 4 px.
        feature_size = (image_size - 4) // 2 - 4
        if feature_size < 1:
            raise ValueError(
                f"image_size={image_size} is too small for this network; "
                "it must be at least 14"
            )

        # Attribute names are kept unchanged: they are the state_dict key prefixes.
        self.Conv_Baock_1 = nn.Sequential(
            nn.Conv2d(in_channels=int_shape,
                      out_channels=heddin_shape,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=heddin_shape,
                      out_channels=heddin_shape,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
            # kernel_size=1 with padding=1 is rejected by PyTorch (padding must be
            # at most half the kernel size). Use a 2x2 pool that halves H and W.
            nn.MaxPool2d(kernel_size=2),
        )
        self.Conv_Baock_2 = nn.Sequential(
            nn.Conv2d(in_channels=heddin_shape,
                      out_channels=heddin_shape,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=heddin_shape,
                      out_channels=heddin_shape,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.ReLU(),
        )
        self.classLiner = nn.Sequential(
            nn.Flatten(),
            # Flatten yields heddin_shape * H * W features per sample, not just
            # heddin_shape, so the Linear input must account for the spatial size.
            nn.Linear(in_features=heddin_shape * feature_size * feature_size,
                      out_features=out_shape),
        )

    def forward(self, x):
        """Map an (N, int_shape, image_size, image_size) batch to (N, out_shape) logits."""
        # No per-call shape printing: it would spam output on every training step.
        x = self.Conv_Baock_1(x)
        x = self.Conv_Baock_2(x)
        return self.classLiner(x)
# Smoke-test the model on one real batch.
# NOTE(review): `train_data_lodael` is defined elsewhere (not visible here);
# presumably a DataLoader yielding (images, labels) batches. Confirm.
image_bath, label_bath = next(iter(train_data_lodael))
print(image_bath.shape, label_bath.shape)

# nn.Conv2d consumes (N, C, H, W), so shape[1] is the input channel count.
# Take it from the data instead of hard-coding it.
mode_l = CIFAR_10(int_shape=image_bath.shape[1], heddin_shape=10, out_shape=10)

# state_dict is a method: it must be called. Print parameter shapes instead of
# the full tensors to keep the output readable.
print({name: tuple(tensor.shape) for name, tensor in mode_l.state_dict().items()})

logits = mode_l(image_bath)
print(logits.shape)