Torch multiprocessing: computation gets stalled on the thread

Hi everyone,

I am preparing some assessments for my deep learning lectures. Now I want to show students the gradient vanishing problem and how neural network activations such as ReLU solve this, plus also displaying how sparse representations are learned.

To do so what I do is implement a neural network and, after some training epochs, I launch a thread that starts printing and displaying some of these things, while the network keeps training on the main thread, since otherwise it will take ages to finish.

What I do is using the torch’s multiprocessing wrapper. However, I have realized that the thread gets exactly stalled when calling the torch.matmul forward computation that takes place into the nn.Linear layer. From what I have read it seems that since the main loop is updating the parameters of nn.Linear layer, the thread is blocked to use this resource.

However, as you’ll see in my code the thread uses a different network instance. When launching the thread, the model’s state dict is passed to the thread to instance the model.

Here is the code:

import torch
import torch.nn as nn
import torch.multiprocessing as multiprocessing
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from IPython.display import Markdown, display, Video
from io import BytesIO
import os
import time
import gzip
import copy


## ==================================
## MODEL IMPLEMENTATION 
## ==================================
def linear_link(x):
    return x
    
class FCLayer(nn.Module):
    def __init__(self, dim_in, dim_out, act):
        super().__init__()
        ## create parameters
        self.linear = nn.Linear(dim_in, dim_out)
        self.act = act
        self.traced_grads = {}
        
    def forward(self, x):
        return self.act(self.linear(x))

    def trace_gradients(self):
        for k, v in self.named_parameters():
            if k in self.traced_grads.keys():
                self.traced_grads[k].append( v.grad.data.clone() )
            else:
                self.traced_grads[k] = [v.grad.data.clone()]

    def reset_traced_grads(self):
        self.traced_grads = {}

    def get_traced_grads(self):
        return self.traced_grads

class FCDNN(nn.Module):
    def __init__(self, dim_in, dim_out, neurons_hidden : list, hidden_activations:list, link_function, loss_function):
        super().__init__()

        assert len(neurons_hidden) == len(hidden_activations), "List specifying hidden activations and number of hidden layers must coincide"

        module_list = nn.ModuleList([])

        # input layer hidden layers
        for num_neur, act in zip(neurons_hidden, hidden_activations):
            module_list.append(FCLayer(dim_in, num_neur, act))
            dim_in = num_neur

        # output layer
        o_layer = FCLayer(dim_in, dim_out, act = linear_link)
        module_list.append(o_layer)

        self.layers = module_list
       
        ## Loss and link function
        self.link = link_function
        self.loss = loss_function
        
    def forward(self,x, apply_link):
        for l in self.layers:
            x = l(x)
        y = x
        if apply_link:
            y = self.link(y)
        
        return y

    def compute_loss(self,t,y):
        return self.loss(y,t)

    def get_internal_representations(self, x, apply_link):
        internals = {}
        # exclude output layer projection since this method gets internal representations
        for i, l in enumerate(self.layers[:-1]):
            x = l(x)
            internals[f'layer {i+1}']= x
            
        return internals
    
    def trace_gradients(self):
        for l in self.layers:
            l.trace_gradients()

    def reset_traced_grads(self):
        for l in self.layers:
            l.reset_traced_grads()

    def get_traced_grads(self):
        traced_grads = {}
        for i,l in enumerate(self.layers):
            traced_grads[f'layer {i+1}'] = l.get_traced_grads()
        
        return traced_grads



## ==================================
## VISUALIZATION FUNCITONS 
## ==================================

def find_factors(N):
    sqrt_N = int(np.sqrt(N))  # Comenzar desde la raíz cuadrada
    for A in range(sqrt_N, 0, -1):
        if N % A == 0:  # Si A es factor de N
            B = N // A
            return A, B  # Retornar el par (A, B)

def compute_metric(dataloader, model):
    acc = 0.0
    for x,t in dataloader:
        ## move data to device and reshape
        x, t = x.to(device), t.to(device)
        x = x.view(-1,dim_in)

        ## forward no need to apply link since softmax is monotically increasing
        y = model(x, apply_link = False)

        ## compute test accuracy
        acc += (t == torch.argmax(y, dim = 1)).sum()
        
    return acc

def get_internal_representations(dataloader, model, N_samples):

    ## Get learnt internal representations
    internal_representations = {}
    counter_dict = {}
    for x,t in dataloader:
        ## move data to device and reshape
        x, t = x.to(device), t.to(device)
        x = x.view(-1,dim_in)

        _internal_representations = model.get_internal_representations(x = x, apply_link = True)
        
        for k, v in _internal_representations.items():
            if k in internal_representations.keys():
                internal_representations[k][counter_dict[k] :  counter_dict[k] + v.shape[0]] = v
                counter_dict[k] += v.shape[0]
            else:
                internal_mat = torch.zeros(N_samples, v.shape[-1], dtype = torch.float32)
                internal_mat[0:v.shape[0]] = v
                internal_representations[k] = internal_mat
                counter_dict[k] = v.shape[0]
                
    return internal_representations

def get_internal_gradients(model):
    internal_grads = {}
    for layer, parameter_grads in model.get_traced_grads().items():
        ## create space for new layer
        internal_grads[layer] = {}
        for k, v in parameter_grads.items():
            
            ## get total number of rows for memory allocation
            tot_rows = 0
            for _v in v:
                tot_rows += 1
            
            ## reset counter
            counter = 0
            for _v in v:
                _v = _v.view(-1)
                
                if k in internal_grads[layer].keys():
                    internal_grads[layer][k][counter] = _v 
                else:
                    internal_mat = torch.zeros(tot_rows, _v.shape[-1], dtype = torch.float32)
                    internal_mat[counter] = _v
                    internal_grads[layer][k] = internal_mat

                counter += 1
                
    return internal_grads

class VisualizeInternals:
    def __init__(self, model, args_instances):
        fig, ((ax11,ax12), (ax21,ax22)) = plt.subplots(2,2, figsize = (20,10))
        self.fig = fig
        self.axis = ((ax11,ax12), (ax21,ax22))
        self.model = FCDNN
        self.args_instances = args_instances

    def draw(self, dataloader, model_state_dict, N_maps_to_show: int, visualize_sparsity:bool, N_samples):
        ## load parameters into model
        model = self.model(**self.args_instances)
        model.load_state_dict(model_state_dict)

        ## grab figure
        fig = self.fig
        ((ax11,ax12), (ax21,ax22)) = self.axis

        ax11.cla()
        ax12.cla()
        ax21.cla()
        ax22.cla()

        internal_representations = get_internal_representations(dataloader, model, N_samples)
        internal_grads = get_internal_gradients(model)

        for i, ((layer_name,layer_activations_mat), (layer_name, layer_gradients_params)) in enumerate(zip(internal_representations.items(), internal_grads.items())):
            layer_activations_vec = layer_activations_mat.view(-1).numpy()
            layer_activations_mat = layer_activations_mat.clone().numpy()

            if visualize_sparsity:
                layer_activations_mat[layer_activations_mat > 0] = 1 

            ## Plot histogram
            a = ax11.hist(layer_activations_vec, bins = 40, density = True, edgecolor='black', color = f'C{i}',alpha=0.4, label = f'{layer_name}')

            ## Plot quantiles
            q5,q50,q95 = np.quantile(layer_activations_vec,[0.025,0.5,0.975]) 
            ax12.errorbar(0, q50, xerr = 0, yerr = [[q50-q5], [q95-q50]],fmt='o', color = f'C{i}', label = f'{layer_name}' )

            ## visualizing gradients at layer
            for k, layer_gradients_mat in layer_gradients_params.items():
                layer_gradients_vec = layer_gradients_mat.view(-1).numpy()
                a = ax22.hist(layer_gradients_vec, bins = 40, density = True, edgecolor='black', color = f'C{i}',alpha=0.4, label = f'{layer_name} {k}')
                
            ax11.set_title(f"Histogram of activations per layer at epoch {e}")
            ax11.set_xlabel("Activation value")
            ax11.legend()

            ax12.set_title("Median and quantiles over learning")
            ax12.set_xlabel("Epochs")
            ax12.legend()

            ax22.set_title(f"Histogram of gradients per layer at epoch {e}")
            ax22.set_xlabel("Gradient value")
            ax22.legend()

            ## visualizing feature map
            for j in range(N_maps_to_show):
                ax21.cla()
                feat_map = layer_activations_mat[j]
                A,B = find_factors(len(feat_map))
                feat_map = np.reshape(feat_map ,(A,B))

                ax21.imshow(feat_map, cmap = 'gray')
                ax21.set_title(f"Feature map: {layer_name} map:{j}")

                fig.canvas.draw()
                fig.canvas.flush_events()
                time.sleep(0.5)

            fig.canvas.draw()
            fig.canvas.flush_events()
            time.sleep(1)


## ================================
## MODEL TRAINING
## ===============================
device = 'cpu'

## ======================== ##
## Parallelize computations ##
## ======================== ##
thread_finished = None

## =========================== ##
## Data Pipeline configuration ##
## =========================== ##

## Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Train dataset
train_dataset = torchvision.datasets.MNIST(root='/tmp/data', train=True, transform=transform, download=True)

# Test dataset
test_dataset = torchvision.datasets.MNIST(root='/tmp/data', train=False, transform=transform, download=True)

# Train and test loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
train_loader_eval = torch.utils.data.DataLoader(train_dataset, batch_size=1000, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

## Data stats
N_training = len(train_dataset)
N_test = len(test_dataset)

## ============== ##
## Model Creation ##
## ============== ##
dim_in = 784
dim_out = 10
model = FCDNN(
              dim_in = dim_in, 
              dim_out = dim_out, 
              neurons_hidden = [32, 32], 
              hidden_activations = [torch.relu, torch.relu],
              link_function = torch.softmax, 
              loss_function = nn.CrossEntropyLoss(),
             )
model.to(device)

## ====================== ##
## Visualization Purposes ##
## ====================== ##
visualize_sparsity = True
visualizer = VisualizeInternals(FCDNN_PLOT, {
              'dim_in' : dim_in, 
              'dim_out' : dim_out, 
              'neurons_hidden' : [32, 32], 
              'hidden_activations' : [torch.relu, torch.relu],
              'link_function' : torch.softmax, 
              'loss_function' : nn.CrossEntropyLoss(),
}
)

## ================= ##
## Training Pipeline ##
## ================= ##
eval_each = 10
epochs = 100
lr = 0.1
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1, momentum = 0.9)

loss_epochs = []
train_acc_epochs = []
test_acc_epochs = []

for e in range(epochs):
    loss_acc = 0.0
    for batch_idx, (x,t) in enumerate(train_loader):
        
        ## move data to device and reshape
        x, t = x.to(device), t.to(device)
        x = x.view(-1,dim_in)

        ## forward
        y = model(x, apply_link = False)
        L = model.compute_loss(t,y)
        loss_acc += len(x)*L.item()
        
        ## backward
        L.backward()
        
        ## Optimizer step
        optimizer.step()

        ## Trace gradients before zero grad
        model.trace_gradients()
        
        ## Optimizer zero grad
        optimizer.zero_grad()
        
        ## loss per batch
        #print(f"On epoch {e} batch_idx {batch_idx + 1} got loss {L.item():.5f}", end = "\r")
        
    ## save loss per epoch for later display
    loss_epochs.append(loss_acc)
    
    ## Evaluate model and draw plots using threads to speed up computations.
    if e % eval_each == 0:
        with torch.no_grad():
            
            ## compute accuracies per sample
            train_acc = compute_metric(train_loader_eval, model)
            test_acc = compute_metric(test_loader, model)

            ## wait for thread to finish.
            if thread_finished is not None:
                try:
                    print("esperando al hilo que termine")
                    thread_finished.join()  
                except Exception as ex:
                    print(f"Error al obtener el resultado de la visualización: {ex}")

            model_state_dict = copy.deepcopy(model.state_dict())    
            thread_finished = multiprocessing.Process(
                                target=visualizer.draw, args=(train_loader_eval, model_state_dict, 5,  visualize_sparsity, N_training)
            )
            thread_finished.start()

            ## save train and test accuracies
            train_acc_epochs.append(train_acc)
            test_acc_epochs.append(test_acc)

        ## total loss on training set
        print(" "*200, end="\r")
        print(f"On epoch {e} got loss {loss_acc/N_training:.5f} with train accuracy {train_acc / N_training:.5f} and test accuracy {test_acc / N_test:.5f}")
     
    else:
        ## total loss on training set
        print(" "*200, end="\r")
        print(f"On epoch {e} got loss {loss_acc/N_training:.5f}")

Code gets stalled when launching the draw method from the Visualize class in a thread, more precisely in the draw method when calling function get_internal_representaitons. This function loops over the layer and saves the projections of the dataset into the hidden layers of the network

I have realized the code gets stalled in other parts. So for example if I comment out code regarding computations through the model, the code gets stalled when performing operation:

internal_mat = torch.zeros(N_samples, v.shape[-1], dtype = torch.float32)
      internal_mat[0:v.shape[0]] = v
internal_representations[k] = internal_mat
counter_dict[k] = v.shape[0]