Handling non-image inputs of different shapes in a batch, without zero-filling/zero-padding, where the inputs are intermediate representations of 2D 3-channel images created offline

I have (for the most part) gigapixel images that I have divided into 512x512 patches. I feed each 512x512 patch (a 2D image with 3 channels) into a frozen ResNet18 network for feature extraction and end up with a 1D tensor of length 512. I then concatenate these per-patch tensors into an Nx512 intermediate representation, where N is the number of patches in the gigapixel image.
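
Roughly, the offline feature-extraction step looks like the simplified sketch below (it assumes torchvision's resnet18 with the classification head replaced by an identity; extract_features is just an illustrative name and the real extraction script is omitted):

import torch
import torchvision

# Frozen ResNet18 backbone: replacing the final fc layer with Identity makes
# the network return the 512-dimensional pooled feature for each patch.
resnet = torchvision.models.resnet18(pretrained=True)
resnet.fc = torch.nn.Identity()
resnet.eval()
for p in resnet.parameters():
    p.requires_grad = False

@torch.no_grad()
def extract_features(patches):
    # patches: (N, 3, 512, 512) tensor holding all patches of one gigapixel image
    return resnet(patches)  # -> (N, 512) intermediate representation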

Since my original gigapixel images are not all the same size, the intermediate representations range from 17x512 to 6000x512. I am using the following strategy to feed them to my model, but my preference would be a more standardized method in PyTorch (for 2D images with 3 channels we could easily use a torchvision transform, but not here).

feature_path = 'features.pt'
features = torch.load(feature_path, map_location=lambda storage, loc: storage)
if features.shape[0] <= median_num_patches:
    # zero-pad the (N, 512) representation up to length median_num_patches
    padding = torch.zeros((median_num_patches - features.shape[0], 512))
    embeddings = torch.cat((features, padding), dim=0)
    sample['image'] = embeddings
else:
    # randomly sample median_num_patches rows (max size: 6000 patches in an image)
    random_indices = torch.randint(features.shape[0], (median_num_patches,))
    sample['image'] = features[random_indices, :]

^ As mentioned earlier, the 2D intermediate representations (Nx512) are created in an offline process and saved as features.pt files.

The above solution first finds the median size of the 2D intermediate representations, measured in number of patches per gigapixel image. For each 2D intermediate representation in the batch, if its number of patches is smaller than (or equal to) the median, it is zero-filled up to the median size; if it is larger than the median, the median number of patches is randomly sampled from it.

Also, if I don’t use the median-based zero-filling or sampling approach above and instead feed the inputs with their different shapes directly, this is the error I get:


=>Epoches 1, learning rate = 0.0010000, previous best = 0.0000
training...

/home/jalal/research/venv/dpcc/lib/python3.8/site-packages/torch/optim/lr_scheduler.py:129: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "

<class 'dict'>
dict_keys(['image', 'label', 'id'])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [16], in <module>
     32 print(sample_batched.keys())
     33 # sample_batched['image'] is an array of tensors with len of batch_size
---> 34 feats = torch.stack(sample_batched['image']) 
     35 print("feature size shape: ", feats.shape)
     36 labels = torch.as_tensor(sample_batched['label']).cuda() 

RuntimeError: stack expects each tensor to be equal size, but got [411, 512] at entry 0 and [236, 512] at entry 1

The training code is:

# training, validation, and test phases

for epoch in range(num_epochs):

    train_loss = 0.
    total = 0.

    current_lr = optimizer.param_groups[0]['lr']
    print('\n=>Epoches %i, learning rate = %.7f, previous best = %.4f' %
          (epoch+1, current_lr, best_val_acc))
    
    train_epoch_loss = 0
    train_epoch_labels = []
    train_epoch_preds = []
    
    val_epoch_loss = 0
    val_epoch_labels = []
    val_epoch_preds = []
    val_epoch_labels_arr = []
    val_epoch_preds_arr = []
    

    epoch_loss = 0
    epoch_accuracy = 0
    
    if train:
        exp_lr_scheduler.step()
        print('training...')
        torch.autograd.set_detect_anomaly(True)
        for i_batch, sample_batched in enumerate(dataloader_train):  
            print(type(sample_batched))
            print(sample_batched.keys())
            # sample_batched['image'] is an array of tensors with len of batch_size
            feats = torch.stack(sample_batched['image']) 
            print("feature size shape: ", feats.shape)
            labels = torch.as_tensor(sample_batched['label']).cuda() 
            output = model(feats)
            loss = kornia.losses.focal_loss(output, labels, **kwargs)
            print('train loss is: ', loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc = (output.argmax(dim=1) == labels).float().mean()
            train_preds = output.argmax(dim=1)
            print('train preds are: ', train_preds)
            train_epoch_preds.extend(train_preds.cpu().numpy())
            train_epoch_labels.extend(labels.cpu().numpy())
            epoch_accuracy += acc / len(dataloader_train)
            epoch_loss += loss / len(dataloader_train)
            print('epoch accuracy: ', epoch_accuracy)
            
        train_epoch_accuracy = accuracy_score(train_epoch_labels, train_epoch_preds)
        print('train_epoch_accuracy: ', train_epoch_accuracy)

I am looking for a better solution than the current one. Perhaps something without sampling or zero-filling and without loss of data. Thanks for any possible lead.

Hi Mona!

As you’ve recognized, you pass a batch of samples into your model as
a single tensor (with a batch dimension), but you have to make all of the
samples have the same shape, because, as you know, and as the error
message confirms, pytorch does not support “ragged” tensors (that is,
tensors whose slices have differing shapes).

One approach that could reduce (and possibly eliminate, depending on the
details of your data) the amount you have to zero-pad and / or downsample
would be to sort your samples into groups that have similar, or even
identical, sizes, and make batches from those groups (see the sketch
just below).
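
Here is a rough sketch of what I mean. It assumes that you can read off
the number of patches in each saved features.pt file ahead of time, and
the names patch_counts, batch_size, dataset_train, and pad_collate are
just made up for illustration:

def make_size_grouped_batches(patch_counts, batch_size):
    # patch_counts[i] is the number of patches N of dataset sample i.
    # Sorting by N puts samples of similar size next to one another, so
    # each chunk of batch_size indices becomes a batch of similar-size
    # samples.
    order = sorted(range(len(patch_counts)), key=lambda i: patch_counts[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

# usage (pad_collate is sketched further below):
# batches = make_size_grouped_batches(patch_counts, batch_size=8)
# dataloader_train = torch.utils.data.DataLoader(
#     dataset_train, batch_sampler=batches, collate_fn=pad_collate)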

If the samples in a group all have the same size, then you can stack() them
into a batch tensor without any processing (or loss of data). If the sizes
differ, but are all pretty close to one another, you could zero-pad the
samples up to the largest size in the group, but you won’t have to zero-pad
by very much. (Or if you chose to downsample or crop, you wouldn’t have
to do so very much, so any loss of information would be reduced.)
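
For the zero-padding part, a custom collate_fn along these lines would
do it. This is only a sketch; it assumes that each dataset item is a
dict with 'image', 'label', and 'id' keys, as in your training loop,
and it uses pytorch's pad_sequence() to pad and stack in one call:

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # Each sample['image'] has shape (N_i, 512).  pad_sequence() zero-pads
    # them to the largest N_i in this batch and stacks the result into a
    # single (batch_size, max_N, 512) tensor.
    feats = [sample['image'] for sample in batch]
    return {
        'image': pad_sequence(feats, batch_first=True),
        'label': torch.as_tensor([sample['label'] for sample in batch]),
        'id':    [sample['id'] for sample in batch],
    }

Because the padding only goes up to the largest sample within the
(similar-size) batch, the amount of zero-filling stays small.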

Note, the shape of the samples in one batch doesn’t have to be the
same as the shape in a different batch, and the batches don’t have to
have the same number of elements in them. If it’s very important not
to zero-pad or downsample your samples, you could choose to build
batches only out of samples that start out with the same shape. If some
of your batches contain only one sample (because there is only one
sample with that particular shape), there is nothing logically wrong about
doing this, although small or single-sample batches could reduce the
efficiency with which you use your gpu (or cpu).
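
If you want to avoid padding entirely, the same bookkeeping with exact
sizes might look like the sketch below (group_batches_by_exact_size and
patch_counts are, again, just made-up names):

from collections import defaultdict

def group_batches_by_exact_size(patch_counts, batch_size):
    # Collect the indices of all samples that have exactly the same number
    # of patches, then split each such group into batches.  A size that
    # occurs only once simply becomes a single-sample batch.
    groups = defaultdict(list)
    for idx, n in enumerate(patch_counts):
        groups[n].append(idx)
    batches = []
    for indices in groups.values():
        batches.extend(indices[i:i + batch_size]
                       for i in range(0, len(indices), batch_size))
    return batches

With batches built this way, torch.stack(sample_batched['image']) in
your training loop will always succeed, because every sample in a given
batch has the same (N, 512) shape; the trade-off is that some batches
may be quite small.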

Best.

K. Frank