Please help! Training a model on the GPU is much slower than on the CPU because of low GPU usage

[Screenshot: low GPU usage during training]

Please help, I am new to PyTorch. I have read all the topics about this, but none of the solutions worked for me. When I train the model on the GPU, GPU utilization only reaches about 10%, so training is much slower than on the CPU. The project and the data are on an SSD, and I am running it with PyTorch. Here is my code:

from torchvision import datasets
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

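# Converts an integer class label into a one-hot vector (defined here but not used in the pipeline below).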
class One_hot(object):
    def __init__(self, output_shape):
        assert isinstance(output_shape, (int, tuple))
        self.output_shape = output_shape

    def __call__(self, label):
        a = np.zeros(shape=self.output_shape)
        a[label] = 1
        target = a
        return target


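# LeNet-5-style CNN for 32x32 single-channel input, with elementwise 'memory' masks applied before each convolution.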
class Module(nn.Module):
    def __init__(self):
        super(Module, self).__init__()
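        # Elementwise 'memory' masks: plain CUDA tensors captured in lambdas; they are not
        # registered as module parameters, so the optimizer never updates them.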
        memory1 = Variable(torch.ones(32, 32), requires_grad=True)
        memory1 = memory1.to('cuda')
        self.memory1 = lambda x: x.mul(memory1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        memory2 = Variable(torch.ones(14, 14), requires_grad=True)
        memory2 = memory2.to('cuda')
        self.memory2 = lambda x: x.mul(memory2)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
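        # Defined but never applied in forward(); nn.CrossEntropyLoss expects raw logits anyway.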
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.memory1(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.memory2(x)
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


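# Un-normalizes and displays an image with matplotlib (not called in this script).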
def img_show(img):
    img = img.numpy()
    img = img / 2 + 0.5
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()


if __name__ == '__main__':

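    # MNIST images, padded from 28x28 to 32x32 and normalized to [-1, 1].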
    transform = transforms.Compose(
        [transforms.Pad(2, fill=0, padding_mode='constant'), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    train_set = datasets.MNIST(root='data', train=True, download=True, transform=transform)
    test_set = datasets.MNIST(root='data', train=False, download=True, transform=transform)
    train_load = DataLoader(dataset=train_set, batch_size=5, shuffle=True, num_workers=4, pin_memory=True)
    test_load = DataLoader(dataset=test_set, batch_size=5, shuffle=False, num_workers=4, pin_memory=True)
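    # Peek at one batch (unused afterwards; the training loop iterates over train_load directly).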
    dataiter = iter(train_load)
    data = next(dataiter)

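    # Move the model to the GPU and set up the loss and optimizer.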
    device = torch.device("cuda:0")
    module = Module()
    module = module.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(module.parameters(), lr=0.001, momentum=0.9)
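    # Train for a single pass (one epoch) over the training set.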
    running_loss = 0.0
    for i, data in enumerate(train_load, 0):
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        optimizer.zero_grad()
        outputs = module(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 2000 == 1999:
            print('[%5d] loss: %.3f' %
                  (i + 1, running_loss / 2000))
            running_loss = 0.0
    print("finish")
    correct = 0
    total = 0
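    # Evaluate accuracy on the test set without tracking gradients.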
    with torch.no_grad():
        for data in test_load:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = module(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))
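
In case it matters, this is roughly how I compare the CPU and GPU runs. It is not part of the script above, just a sketch that times one pass over the training data with time.perf_counter and calls torch.cuda.synchronize() before stopping the clock so that queued GPU work is included; the function name time_one_epoch is only for illustration:

import time

def time_one_epoch(module, train_load, criterion, optimizer, device):
    # Rough wall-clock timing of a single pass over the training data on the given device.
    module.to(device)
    start = time.perf_counter()
    for inputs, labels in train_load:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(module(inputs), labels)
        loss.backward()
        optimizer.step()
    if device.type == 'cuda':
        torch.cuda.synchronize()  # wait for queued GPU kernels before reading the clock
    return time.perf_counter() - start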

Please help. Thanks!