Converting code from TensorFlow/Keras to the PyTorch framework

I want to port the code at this link to the PyTorch framework.
I have started working on this, but I don't know how to convert the LMDB dataset into PyTorch tensors without errors, so that I can create a new model and train it on the dataset. I would also like to use the pre-trained model provided at that link.

This is my code:

from __future__ import print_function

import json
import os

import cv2
import lmdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from google.colab import drive
from matplotlib import pyplot
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid

%matplotlib inline

class USCISI_CMD_API( object ) :
    """ Simple API for reading the USCISI CMD dataset

    This API loads and parses CMD samples from LMDB.  Every sample is a tuple
    ``(image, cmd_mask, trans_mat)``; images and masks are resized to
    ``IMAGE_SIZE``.  Because the class implements ``__len__`` and
    ``__getitem__`` it can be handed to ``torch.utils.data.random_split`` and
    ``DataLoader``.  The default collate function stacks the HxWxC numpy
    arrays into (batch, H, W, C) tensors, so permute them to (batch, C, H, W)
    before passing them to ``make_grid`` or a PyTorch model (or convert
    samples with ``pytorch_tensor_samples``).

    # Example:
        # get the LMDB file path
        lmdb_dir = os.path.dirname( os.path.realpath(__file__) )
        # create dataset instance
        dataset = USCISI_CMD_API( lmdb_dir=lmdb_dir,
                                  sample_file=os.path.join( lmdb_dir, 'samples.keys'),
                                  differentiate_target=True )
        # retrieve the first 24 samples in the dataset
        samples = dataset( range(24) )
        # visualize these samples
        dataset.visualize_samples( samples )
        # retrieve 24 random samples in the dataset
        samples = dataset( [None]*24 )
        # visualize these samples
        dataset.visualize_samples( samples )

        # get the sample at index 50
        sample = dataset[50]
        # visualize this sample
        dataset.visualize_samples( [sample] )

    # Arguments:
        lmdb_dir = str, path to the dataset LMDB directory
        sample_file = str, path to the sample key list, e.g. samples.keys
        train_file = str or None, path to the training key list, e.g. train.keys
        valid_file = str or None, path to the validation key list, e.g. valid.keys
        test_file = str or None, path to the test key list, e.g. test.keys
        differentiate_target = bool, whether or not to generate a 3-class target map

    # Raises:
        FileNotFoundError, if lmdb_dir or any given key file does not exist

    # Note:
        1. samples, i.e. the output of "get_samples" or "__call__", is a list of
           samples; all images and masks are resized to IMAGE_SIZE.
        2. CMD samples are generated upon
           - MIT SUN2012 dataset
           - MS COCO dataset
        3. the detailed synthesis process can be found in the paper below.

    # Citation:
        Yue Wu "BusterNet: Detecting Copy-Move Image Forgery with Source/Target Localization".
        In: European Conference on Computer Vision (ECCV). Springer. 2018.
    # Contact:
        Dr. Yue Wu
    """

    # (width, height) passed to cv2.resize for both images and masks
    IMAGE_SIZE = (128, 128)

    def __init__( self, lmdb_dir, sample_file, train_file=None, valid_file=None,
                  test_file=None, differentiate_target = True ) :
        if not os.path.isdir( lmdb_dir ) :
            raise FileNotFoundError( "LMDB directory not found: {}".format( lmdb_dir ) )
        self.lmdb_dir = lmdb_dir
        self.sample_keys = self._load_sample_keys( self._check_file( sample_file ) )
        # optional train / valid / test key lists; an omitted file gives an empty list
        self.train_keys = [] if train_file is None else self._load_train_keys( self._check_file( train_file ) )
        self.valid_keys = [] if valid_file is None else self._load_valid_keys( self._check_file( valid_file ) )
        self.test_keys = [] if test_file is None else self._load_test_keys( self._check_file( test_file ) )
        self.differentiate_target = differentiate_target
        print("INFO: successfully load USC-ISI CMD LMDB with {} keys".format( self.nb_samples ) )
        print("INFO: It have {} train keys".format( len( self.train_keys ) ) )
        print("INFO: It have {} validation keys".format( len( self.valid_keys ) ) )
        print("INFO: It have {} test keys".format( len( self.test_keys ) ) )

    @staticmethod
    def _check_file( path ) :
        '''Return path unchanged if it names an existing file, else raise FileNotFoundError.'''
        if not os.path.isfile( path ) :
            raise FileNotFoundError( "key file not found: {}".format( path ) )
        return path

    def __len__( self ) :
        return self.nb_samples

    @property
    def nb_samples( self ) :
        '''Number of keys listed in the sample file (used as an attribute, not called).'''
        return len( self.sample_keys )

    def _load_sample_keys( self, sample_file ) :
        '''Load LMDB keys from a key file
        INPUT:
            sample_file = str, path to a key file with one LMDB key per line
        OUTPUT:
            keys = list of str, each element is a valid key in LMDB
                   (blank lines are skipped, since they are not valid keys)
        '''
        with open( sample_file, 'r' ) as IN :
            stripped = ( line.strip() for line in IN )
            keys = [ key for key in stripped if key ]
        return keys

    def _load_train_keys( self, train_file ) :
        '''Load training keys; same file format as _load_sample_keys.'''
        return self._load_sample_keys( train_file )

    def _load_valid_keys( self, valid_file ) :
        '''Load validation keys; same file format as _load_sample_keys.'''
        return self._load_sample_keys( valid_file )

    def _load_test_keys( self, test_file ) :
        '''Load test keys; same file format as _load_sample_keys.'''
        return self._load_sample_keys( test_file )

    def _get_image_from_lut( self, lut ) :
        '''Decode image array from LMDB lut
        INPUT:
            lut = dict, raw decoded lut retrieved from LMDB
        OUTPUT:
            image = np.ndarray, dtype='uint8', HxWx3 resized to IMAGE_SIZE
        RAISES:
            ValueError, if the stored JPEG buffer cannot be decoded
        '''
        image_jpeg_buffer = lut['image_jpeg_buffer']
        # NOTE(review): cv2.imdecode returns channels in BGR order; confirm whether
        # the stored JPEGs need a BGR->RGB flip before display / training.
        image = cv2.imdecode( np.array(image_jpeg_buffer).astype('uint8').reshape([-1,1]), cv2.IMREAD_COLOR )
        if image is None :
            # imdecode signals failure by returning None instead of raising
            raise ValueError( "failed to decode the JPEG buffer stored in the LMDB record" )
        return cv2.resize( image, self.IMAGE_SIZE, interpolation = cv2.INTER_AREA )

    def _get_mask_from_lut( self, lut ) :
        '''Decode copy-move mask from LMDB lut
        INPUT:
            lut = dict, raw decoded lut retrieved from LMDB
        OUTPUT:
            cmd_mask = np.ndarray, dtype='float32', resized to IMAGE_SIZE
                       shape of HxW,   if differentiate_target=False
                       shape of HxWx3, if differentiate_target=True
            cmd_mask is one-hot encoded if differentiate_target=True: channels
            0, 1 and 2 (shown as R, G, B by imshow) stand for the TARGET,
            SOURCE and BACKGROUND classes. INTER_AREA resizing can produce
            fractional values along region borders.
        '''
        def reconstruct( cnts, h, w, val=1 ) :
            # rasterize the polygon contours into an h x w uint8 mask
            rst = np.zeros([h,w], dtype='uint8')
            cv2.fillPoly( rst, cnts, val )
            return rst

        def to_contours( raw_contours ) :
            # cv2.fillPoly expects int32 point arrays of shape Nx1x2
            return [ np.array(cnts).reshape([-1,1,2]).astype(np.int32) for cnts in raw_contours ]

        h, w = lut['image_height'], lut['image_width']
        src_mask = reconstruct( to_contours( lut['source_contour'] ), h, w, val = 1 )
        tgt_mask = reconstruct( to_contours( lut['target_contour'] ), h, w, val = 1 )
        if ( self.differentiate_target ) :
            # 3-class target
            background = np.ones([h,w]).astype('uint8') - np.maximum(src_mask, tgt_mask)
            cmd_mask = np.dstack( [tgt_mask, src_mask, background ] ).astype(np.float32)
        else :
            # 2-class target
            cmd_mask = np.maximum(src_mask, tgt_mask).astype(np.float32)
        return cv2.resize( cmd_mask, self.IMAGE_SIZE, interpolation = cv2.INTER_AREA )

    def _get_transmat_from_lut( self, lut ) :
        '''Decode transform matrix between SOURCE and TARGET
        INPUT:
            lut = dict, raw decoded lut retrieved from LMDB
        OUTPUT:
            trans_mat = np.ndarray, dtype='float32', size of 3x3
        '''
        return np.asarray( lut['transform_matrix'], dtype=np.float32 ).reshape([3,3])

    def _decode_lut_str( self, lut_str ) :
        '''Decode a raw LMDB lut
        INPUT:
            lut_str = bytes/str, raw JSON value retrieved from LMDB
        OUTPUT:
            image = np.ndarray, dtype='uint8', cmd image
            cmd_mask = np.ndarray, dtype='float32', cmd mask
            trans_mat = np.ndarray, dtype='float32', cmd transform matrix
        '''
        # 1. get raw lut
        lut = json.loads(lut_str)
        # 2. reconstruct image
        image = self._get_image_from_lut(lut)
        # 3. reconstruct copy-move masks
        cmd_mask = self._get_mask_from_lut(lut)
        # 4. get transform matrix
        trans_mat = self._get_transmat_from_lut(lut)
        return ( image, cmd_mask, trans_mat )

    def get_one_sample( self, key = None ) :
        '''Get a (random) sample from given key
        INPUT:
            key = str key, int index, or None; if None then use a random key
        OUTPUT:
            sample = tuple of (image, cmd_mask, trans_mat)
        '''
        return self.get_samples([key])[0]

    def _resolve_key( self, key ) :
        '''Map a str key, an integer index (Python or numpy), or None to an LMDB key.'''
        if isinstance( key, str ) :
            return key
        if isinstance( key, (int, np.integer) ) :
            # indices wrap around, so every integer maps to a valid sample
            return self.sample_keys[ int(key) % self.nb_samples ]
        key = np.random.choice(self.sample_keys, 1)[0]
        print("INFO: use random key", key)
        return key

    def get_samples( self, key_list ) :
        '''Get samples according to a given key list
        INPUT:
            key_list = iterable, each element is a LMDB key, an index, or None (random)
        OUTPUT:
            sample_list = list, each element is a tuple of (image, cmd_mask, trans_mat)
        RAISES:
            KeyError, if a key is not present in the LMDB
        '''
        # the database is only read here, so open it read-only and without a lock file
        env = lmdb.open( self.lmdb_dir, readonly=True, lock=False )
        sample_list = []
        try :
            with env.begin( write=False ) as txn :
                for key in key_list :
                    key = self._resolve_key( key )
                    lut_str = txn.get( key.encode(encoding='utf-8', errors='strict') )
                    if lut_str is None :
                        raise KeyError( "key {!r} not found in LMDB {}".format( key, self.lmdb_dir ) )
                    sample_list.append( self._decode_lut_str( lut_str ) )
        finally :
            # release the environment's file handles / memory map
            env.close()
        return sample_list

    def visualize_samples( self, sample_list ) :
        '''Visualize a list of samples, one figure per sample (image left, mask right).'''
        for image, cmd_mask, trans_mat in sample_list :
            pyplot.figure( figsize=(10, 5) )
            pyplot.subplot(121)
            pyplot.imshow( image )
            pyplot.subplot(122)
            pyplot.imshow( cmd_mask )

    @staticmethod
    def _to_chw_tensor( array ) :
        '''Convert an HxW or HxWxC numpy array to a contiguous CxHxW tensor (dtype kept).'''
        array = np.ascontiguousarray( array )
        if array.ndim == 2 :
            array = array[:, :, None]
        return torch.from_numpy( array ).permute( 2, 0, 1 ).contiguous()

    def pytorch_tensor_samples( self, sample_list ) :
        '''Convert a list of samples to channels-first PyTorch tensors
        INPUT:
            sample_list = list of (image, cmd_mask, trans_mat) tuples
        OUTPUT:
            all_samples = dict, maps the position of each sample in sample_list to
                [image, cmd_mask, trans_mat] where
                image     = tensor, dtype of the input image (uint8 here), CxHxW
                cmd_mask  = torch.float32 tensor, CxHxW (C=1 for a 2-D mask)
                trans_mat = torch.float32 tensor, 3x3
        '''
        all_samples = {}
        for idx, (image, cmd_mask, trans_mat) in enumerate( sample_list ) :
            all_samples[idx] = [ self._to_chw_tensor( image ),
                                 self._to_chw_tensor( np.asarray( cmd_mask, dtype=np.float32 ) ),
                                 torch.as_tensor( np.asarray( trans_mat, dtype=np.float32 ) ) ]
        return all_samples

    def __call__( self, key_list ) :
        return self.get_samples( key_list )

    def __getitem__( self, key_idx ) :
        return self.get_one_sample( key=key_idx )

import sys

# Build the dataset from the LMDB copy on Google Drive.
# Requires the Drive to be mounted first, e.g. drive.mount('/content/drive').
lmdb_dir = '/content/drive/MyDrive/Full dataset/USCISI-CMFD'
dataset = USCISI_CMD_API( lmdb_dir=lmdb_dir,
                          sample_file=os.path.join( lmdb_dir, 'samples.keys'),
                          train_file=os.path.join( lmdb_dir, 'train.keys'),
                          valid_file=os.path.join( lmdb_dir, 'valid.keys'),
                          test_file=os.path.join( lmdb_dir, 'test.keys' ),
                          differentiate_target=True )

def simple_cmfd_decoder( busterNetModel, rgb ) :
    """A simple BusterNet CMFD decoder: run the model on a single image.

    Args:
        busterNetModel: either a model exposing a Keras-style ``predict``
            method, or a ``torch.nn.Module``. A PyTorch model is called as-is,
            so put it in ``eval()`` mode beforehand for inference.
        rgb: np.ndarray of shape HxWxC, the input image.

    Returns:
        The model's output for this image (first element of the batch).
        For a PyTorch model, a 3-D CxHxW output is converted back to an
        HxWxC numpy array, so both code paths return channels-last data.
    """
    # 1. expand an image to a single sample batch
    single_sample_batch = np.expand_dims( rgb, axis=0 )
    if isinstance( busterNetModel, torch.nn.Module ) :
        # PyTorch layers expect a float (batch, C, H, W) tensor on the model's device
        first_param = next( iter( busterNetModel.parameters() ), None )
        device = first_param.device if first_param is not None else torch.device('cpu')
        batch = torch.from_numpy( np.ascontiguousarray( single_sample_batch, dtype=np.float32 ) )
        batch = batch.permute( 0, 3, 1, 2 ).to( device )
        # 2. perform CMFD without tracking gradients
        with torch.no_grad() :
            pred = busterNetModel( batch )[0]
        if pred.dim() == 3 :
            pred = pred.permute( 1, 2, 0 )
        return pred.cpu().numpy()
    # 2. perform busterNet CMFD (Keras-style API)
    pred = busterNetModel.predict(single_sample_batch)[0]
    return pred

def visualize_result( rgb, gt=None, pred=None, figsize=(12,4), title=None ) :
    """Visualize raw input, ground truth, and BusterNet result side by side.

    Each given image gets its own subplot in one row, so the input, the
    ground-truth mask and the prediction no longer overwrite each other.

    Args:
        rgb: image to show in the first panel.
        gt: optional ground-truth mask, shown in its own panel.
        pred: optional BusterNet prediction, shown in its own panel.
        figsize: matplotlib figure size.
        title: optional figure-level title.
    """
    panels = [ ( rgb, 'input image' ) ]
    if gt is not None :
        panels.append( ( gt, 'ground truth' ) )
    if pred is not None :
        panels.append( ( pred, 'busterNet pred' ) )
    pyplot.figure( figsize=figsize )
    for col, ( img, label ) in enumerate( panels, start=1 ) :
        pyplot.subplot( 1, len( panels ), col )
        pyplot.imshow( img )
        pyplot.title( label )
    if title is not None :
        pyplot.suptitle( title )

# Preview ten random samples: without a key, get_one_sample draws one at random.
for k in range(10) :
    rgb, gt, trans_mat = dataset.get_one_sample( key=None )
    # pred = simple_cmfd_decoder( busterNetModel, rgb )
    visualize_result( rgb, gt )

# get the sample at index 100
sample = dataset[100]

# visualize these samples
dataset.visualize_samples( [sample] )

sample = dataset[0]
sample1 = dataset[1]

# convert samples to pytorch tensor
pTensor = dataset.pytorch_tensor_samples([sample, sample1])


from torch.utils.data import random_split

# random_split requires the lengths to add up to len(dataset) exactly, so the
# training size is derived from the dataset instead of being hard-coded.
# NOTE: the LMDB directory also ships train/valid/test key files (loaded as
# dataset.train_keys / valid_keys / test_keys); building Subsets from those
# would reproduce the official split instead of a random one.
val_size = 10000
test_size = 10000
train_size = len(dataset) - val_size - test_size

# fixed seed so the split is identical across runs
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(dataset, [train_size, val_size, test_size],
                                         generator=split_generator)
print(len(train_ds), len(val_ds), len(test_ds))

from torch.utils.data import DataLoader

# batch_size was never defined in the original; adjust to fit available memory
batch_size = 128
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
# evaluation keeps no gradients, so larger batches usually fit
test_loader = DataLoader(test_ds, batch_size * 2)
val_loader = DataLoader(val_ds, batch_size * 2)

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

for images, cmd_mask, trans_mat in train_loader:
    # The dataset returns HxWxC numpy images, so the default collate produces
    # (batch, H, W, C). make_grid expects (batch, C, H, W): move channels first.
    # permute returns a new tensor, so the result must be reassigned.
    images = images.permute((0, 3, 1, 2))
    print('images.shape:', images.shape)
    plt.figure(figsize=(16, 8))
    plt.imshow(make_grid(images, nrow=16).permute((1, 2, 0)))
    break  # preview only the first batch instead of iterating the whole dataset

I don't know whether my starting point is correct. I also get the following error:

PyTorch uses the channels-first memory layout by default, so make sure your data is loaded in the shape [batch_size, channels, height, width]. Based on the printed shape, the channel dimension seems to be the last one.

Thank you for your reply.
I tried to solve this problem using permute((0, 3, 1, 2)), but I still get the error.

My problem is that the data is stored in LMDB format, so I don't know how to handle data in this format or how to modify the code that loads it.

I still think that make_grid raises the error, since images has the wrong shape.
Call permute on the images tensor and make sure a tensor of the aforementioned shape is passed to make_grid.

I printed the shape after calling permute.

You need to reassign the variable, because permute returns a new tensor and does not modify images in place:
# permute returns a new (view) tensor and leaves `images` itself unchanged, so
# the result must be assigned back; this reorders (batch, H, W, C) into the
# channels-first (batch, C, H, W) layout that make_grid expects.
images = images.permute((0,3,1,2))