DDP: Only one rank finishing while the rest hang

Hi, I am attempting to do distributed training on a multi-GPU instance using PyTorch DDP. My model is a PyG GNN trained on a heterogeneous graph. For some reason only one of the ranks (rank 3) completes the script while the rest hang and eventually time out. Because I am using a PyG hetero data object I am not able to use the sampler, so instead I chunk the data across ranks before building the dataloaders. I launch training with the following command:

torchrun --nproc_per_node=4 gnn_model.py

Here is the final output:

rank 2 epoch 1
train rank 2
batch 0/1052
batch 1000/1052
val rank 3
test rank 3
epoch 1 complete on rank 3
Memory Usage On Rank  3
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 79% |
|  1 | 100% | 63% |
|  2 | 100% | 67% |
|  3 |  50% | 52% |
RAM Used (GB): 182.41599488
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806397 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806256 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806176 milliseconds before timing out.
Clean up run on rank: 3
Clean up completed on rank 3
ip-172-31-4-202:2515:2626 [3] NCCL INFO [Service thread] Connection closed by localRank 3
ip-172-31-4-202:2515:2515 [0] NCCL INFO comm 0x55986c074da0 rank 3 nranks 4 cudaDev 3 busId 1e0 - Abort COMPLETE
STOPPING INSTANCE

Below is my code:

#!/usr/bin/env python
# coding: utf-8

import resource
import boto3
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))

import wandb

wandb.login()

import torch
import time
import numpy as np
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from sklearn.metrics import roc_auc_score
#from torch.utils.tensorboard import SummaryWriter # Writer will output to ./runs/ directory by default.
from datetime import datetime
import gc
from GPUtil import showUtilization as gpu_usage
import torch.nn as nn
import random, requests, glob, time
from torch.utils.checkpoint import checkpoint

import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero, norm
from torch_geometric.nn.conv import GATv2Conv
from torch_geometric.transforms import ToUndirected, RandomLinkSplit, NormalizeFeatures, Compose
from torch.distributed.elastic.multiprocessing.errors import record 
import torch.distributed as dist

import os
import random
import psutil
import torch.multiprocessing as mp

os.environ['NCCL_DEBUG']='INFO'
# os.environ['NCCL_LL_THRESHOLD']=0

def stop_instance(instance_id):
    ec2 = boto3.resource('ec2', region_name='us-west-2') 
    instance = ec2.Instance(instance_id)
    response = instance.stop()

    return response

def get_instance_id():
    response = requests.get('http://169.254.169.254/latest/meta-data/instance-id')
    return response.text

def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def free_gpu_cache(empty_cache=True, print_gpu=False, rank=0):

    if empty_cache:
        gc.collect()
        torch.cuda.empty_cache()

    if print_gpu:
        print("Memory Usage On Rank ", rank)
        gpu_usage()
        print('RAM Used (GB):', psutil.virtual_memory()[3]/1000000000)

aws_dict = {
'heads': [1],
'hidden_channels': [128],
'output_channels': [32],
'decode_channels': [32],
'num_conv': [2],
'num_lin': [1]
}

num_neighbors = [10, 3]
disjoint_train_ratio = 0.5
epochs = 2 
dropout = 0.7
lr = 0.001
aggr = 'sum'
batch_size = 32

new_data = False
instance_stop = True

class CriterionParallel(nn.Module):
    def __init__(self, module):
        super(CriterionParallel, self).__init__()
        self.module = module
        if not isinstance(module, nn.Module):
            raise ValueError("module must be a nn.Module")
        if torch.cuda.device_count() == 0:
            raise ValueError("No GPUs available")
        self.device_ids = list(range(torch.cuda.device_count()))
        
    def forward(self, inputs, *targets):
        targets = tuple(t.to(inputs.device) for t in targets)
        return self.module(inputs, *targets)

def can_convert_to_int(s):
    try:
        int(s)
        return True
    except ValueError:
        return False

def load_checkpoint_model(checkpoint_path, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, data, device, model_filename=None):
    if model_filename:
        checkpoint_epoch_path = os.path.join(checkpoint_path, model_filename)
        if not os.path.isfile(checkpoint_epoch_path):
            print(f"Error: File {model_filename} not found.")
            return None, None  # or handle this in another way suitable for your application
    else:
        list_of_checkpoint_files = glob.glob(os.path.join(checkpoint_path, '*'))
        # Make sure the file is a PyTorch file (.pt) and can extract an integer from filename.
        list_of_checkpoint_files = [f for f in list_of_checkpoint_files if f.endswith('.pt') and can_convert_to_int(f.split("_")[-1].split('.')[0])]
        if not list_of_checkpoint_files:
            print("Error: No valid checkpoint files found.")
            return None, None  # or handle this in another way suitable for your application

        checkpoint_epoch_number = max([int(file.split("_")[-1].split('.')[0]) for file in list_of_checkpoint_files])
        checkpoint_epoch_path = os.path.join(checkpoint_path, f'checkpoint_AWS_RUN_0_epoch_{checkpoint_epoch_number}.pt')
    
    # Instantiate the model with the correct parameters and load onto the device.
    resume_model = Model(hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, data).to(device)
    
    # Load the state dict onto the device.
    resume_model.load_state_dict(torch.load(checkpoint_epoch_path, map_location=device))
    
    return resume_model, checkpoint_epoch_number


def data_process(path, rank):

    # Load the hetero data object created in Data_formatting.ipynb
    print('loading data')
    data = torch.load(path, map_location=torch.device('cpu'))
    print('data loaded')
    free_gpu_cache(print_gpu=True, rank=rank)
    
    data = ToUndirected(merge=False)(data)
    print('data to undirected')

    #Perform a link-level split into training, validation, and test edges as well as node normalization.
    transform = Compose([NormalizeFeatures(['x','edge_attr']), RandomLinkSplit(
        num_val=0.2,
        num_test=0.1,
        split_labels=True,
        is_undirected=True,
        edge_types=[('node_a', 'edge_ab', 'node_b'), ('node_c', 'edge_cd', 'node_d')], 
        rev_edge_types=[('node_b', 'rev_edge_ab', 'node_a'), ('node_d', 'rev_edge_cd', 'node_c')],
        disjoint_train_ratio=disjoint_train_ratio
    )])

    train_data, val_data, test_data = transform(data)
    print('data transformed')

    del data

    free_gpu_cache(print_gpu=True, rank=rank)

    return train_data, val_data, test_data


def data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size):

    train_idx = torch.chunk(torch.arange(train_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]
    val_idx = torch.chunk(torch.arange(val_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]
    test_idx = torch.chunk(torch.arange(test_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]

    dist_batch_size = batch_size // world_size
    num_workers = 0 
    persistent_workers = False

    print('indices created')

    train_loader = NeighborLoader(
                                train_data, 
                                batch_size=dist_batch_size, 
                                shuffle=True,
                                directed=False,
                                num_neighbors=num_neighbors,
                                input_nodes=('node_a', train_idx), 
                                num_workers=num_workers, 
                                persistent_workers=persistent_workers,
                                drop_last=True) 

    val_loader = NeighborLoader(
                                val_data, 
                                batch_size=dist_batch_size,
                                directed=False,
                                num_neighbors=num_neighbors,
                                input_nodes=('node_a', val_idx),
                                num_workers=num_workers, 
                                persistent_workers=persistent_workers,
                                drop_last=True)

    test_loader = NeighborLoader(
                                test_data, 
                                batch_size=dist_batch_size,
                                directed=False,
                                num_neighbors=num_neighbors,
                                input_nodes=('node_a', test_idx), 
                                num_workers=num_workers, 
                                persistent_workers=persistent_workers,
                                drop_last=True)

    del train_data, val_data, test_data, train_idx, val_idx, test_idx, dist_batch_size, num_workers, persistent_workers
    
    free_gpu_cache(print_gpu=True, rank=rank)

    print('loaders created')

    return train_loader, val_loader, test_loader

def process_edges(batch, device, rank=0):

    for edge in batch.edge_types:
        num_edges = batch[edge].edge_index.shape[1]
        shuffle_target = torch.randperm(num_edges).to(device)
        neg_targets = batch[edge].edge_index[1, shuffle_target]
        neg_edge_index = torch.clone(batch[edge].edge_index).to(device)
        neg_edge_index[1, :] = neg_targets

        batch[edge].neg_edge_label_index = neg_edge_index
        batch[edge].neg_edge_label = torch.zeros(num_edges, dtype=torch.float, device=device)
        batch[edge].pos_edge_label = torch.ones(num_edges, dtype=torch.float, device=device)
        batch[edge].pos_edge_label_index = batch[edge].edge_index

    edge_label_index_dict, edge_label_dict = {}, {}

    for edge in batch.edge_index_dict.keys():
        edge_label_index_dict[edge] = torch.cat([batch[edge].pos_edge_label_index, batch[edge].neg_edge_label_index], dim=1).to(device)
        edge_label_dict[edge] = torch.cat([batch[edge].pos_edge_label, batch[edge].neg_edge_label], dim=0).to(device)

    return edge_label_index_dict, edge_label_dict

class GNNEncoder(nn.Module):
    def __init__(self, hidden_channels, output_channels, dropout, heads, num_layers):
        super().__init__()
        self.layers = []
        self.batch_norms = []
        for i in range(num_layers):
            if i == 0:
                channels = hidden_channels
            elif num_layers == 3 and i == 1:
                channels = hidden_channels
            else:
                channels = output_channels
            layer = GATv2Conv((-1, -1), channels, add_self_loops=False, edge_dim=-1, heads=heads)
            self.add_module('layer_{}'.format(i), layer)
            self.layers.append(layer)
            batch_norm = norm.BatchNorm(heads * channels, allow_single_element=True)
            self.add_module('batch_norm_{}'.format(i), batch_norm)
            self.batch_norms.append(batch_norm)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.relu = torch.nn.ReLU()

    def forward(self, x, edge_index, edge_attr):
        for i in range(len(self.layers)):
            x = checkpoint(self.layers[i], x, edge_index, edge_attr)  # Use gradient checkpointing
            x = self.batch_norms[i](x)
            if i != len(self.layers)-1:
                x = self.relu(x)
            x = self.dropout(x)
        return x

class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels, output_channels, decode_channels, dropout, heads, num_lin, num_conv):
        super().__init__()
        self.layers = []
        self.batch_norms = []
        if num_conv == 1:
            output_channels = hidden_channels
            if num_lin == 1:
                decode_channels = 2 * heads * hidden_channels
        if num_lin == 1:
            decode_channels = 2 * heads * output_channels
        for i in range(num_lin):
            if i == 0 and num_lin != 1:
                layer = Linear(2 * heads * output_channels, decode_channels)
                self.add_module('layer_{}'.format(i), layer)
                self.layers.append(layer)
            elif num_lin == 3 and i == 1:
                layer = Linear(decode_channels, decode_channels)
                self.add_module('layer_{}'.format(i), layer)
                self.layers.append(layer)
            else:
                layer = Linear(decode_channels, 1)
                self.add_module('layer_{}'.format(i), layer)
                self.layers.append(layer)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, z_dict, edge_label_index):

        forward_dict, forward_dict_rest = {}, {}

        for i, edge in enumerate(edge_label_index.keys()):
            if edge[1] == 'edge_ab' or edge[1] == 'rev_edge_ab':
                row, col = edge_label_index[edge]
                forward_dict[f'z_forward{i}'] = torch.cat([z_dict[edge[0]][row], z_dict[edge[2]][col]], dim=-1)
            else:
                row, col = edge_label_index[edge]
                forward_dict_rest[f'z_forward{i}'] = torch.cat([z_dict[edge[0]][row], z_dict[edge[2]][col]], dim=-1)

        z = torch.cat([x for x in forward_dict.values()], dim=0)
        z_rest = torch.cat([x for x in forward_dict_rest.values()], dim=0)

        for i in range(len(self.layers)):
            z = self.layers[i](z)
            if i != len(self.layers)-1:
                z = self.relu(z)
            z = self.dropout(z)

        for i in range(len(self.layers)):
            z_rest = self.layers[i](z_rest)
            if i != len(self.layers)-1:
                z_rest = self.relu(z_rest)
            z_rest = self.dropout(z_rest)

        return z.view(-1), z_rest.view(-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, metadata):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, output_channels, dropout, heads, num_conv)
        self.encoder = to_hetero(self.encoder, metadata, aggr=aggr)
        self.decoder = EdgeDecoder(hidden_channels, output_channels, decode_channels, dropout, heads, num_lin, num_conv)
        del metadata

    def forward(self, x_dict, edge_index_dict, edge_label_index, edge_attr_dict):
        z_dict = self.encoder(x_dict, edge_index_dict, edge_attr_dict)
        return self.decoder(z_dict, edge_label_index)

def train(data_loader, model, optimizer, criterion, batch_size, rank, device, batch_run=False):
    
    model.train()

    total_examples = total_loss = 0
    roc_auc_sum = 0.0
    roc_auc_count = 0
    skip = False
    
    for i, batch in enumerate(data_loader):

        # Check if the edge type exists in the batch; reset the skip flag for each new batch
        skip = False
        edge_index = batch[('node_a', 'edge_ab', 'node_b')].get('edge_index')

        if edge_index is not None and edge_index.size(1) == 0:
            skip = True
        
        if i % 1000 == 0 and rank == 0:
            print(f'batch {i}/{len(data_loader)}')
        
        batch = batch.to(device=device)

        edge_label_index_dict, edge_label_dict = process_edges(batch, device, rank)

        raw_pred, raw_pred_rest = model(batch.x_dict, batch.edge_index_dict, edge_label_index_dict, batch.edge_attr_dict)

        tar = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' in k or 'rev_edge_ab' in k], dim=0)
        tar_rest = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' not in k and 'rev_edge_ab' not in k], dim=0)

        tar = tar.to(device=device)
        tar_rest = tar_rest.to(device=device)
       
        loss_gt = criterion(torch.nn.functional.relu(raw_pred), tar)
        loss_rest = criterion(torch.nn.functional.relu(raw_pred_rest), tar_rest)

        lmda = 0.99
        
        loss = lmda * loss_gt + (1 - lmda) * loss_rest
        loss_val = loss.item()

        loss.backward()

        if (i + 1) % 4 == 0:     #gradient accumulation
            optimizer.step()
            optimizer.zero_grad()

        # Empty edge 
        if skip:
            loss_val = 0
        
        total_examples += batch_size
        total_loss += float(loss_val) * batch_size

        # Compute ROC AUC for this batch and add to sum
        preds = torch.sigmoid(raw_pred.detach()).cpu().numpy()
        tars = tar.cpu().numpy().astype(int)

        # if len(preds) > 0 and len(tars) > 0 and not np.isnan(preds).any() and not np.isnan(tars).any():
        if len(preds) > 0 and len(tars) > 0 and skip is False:
            roc_auc_sum += roc_auc_score(tars, preds)
            roc_auc_count += 1

        if batch_run == True:
            break

    # Compute the average ROC AUC
    roc_auc_avg = roc_auc_sum / roc_auc_count if roc_auc_count > 0 else 0.0
    
    loss_val= total_loss / total_examples
        
    return loss_val, roc_auc_avg


@torch.no_grad()
def test(data_loader, model, criterion, batch_size, rank, device, batch_run=False):

    model.eval()
    
    total_examples = total_loss = 0
    roc_auc_sum = 0.0
    roc_auc_count = 0
    skip = False

    for i, batch in enumerate(data_loader):

        # Check if the edge type exists in the batch; reset the skip flag for each new batch
        skip = False
        edge_index = batch[('node_a', 'edge_ab', 'node_b')].get('edge_index')

        if edge_index is not None and edge_index.size(1) == 0:
            skip = True
 
        if i % 1000 == 0 and rank == 0:
            print(f'batch {i}/{len(data_loader)}')
                    
        batch = batch.to(device=device)

        edge_label_index_dict, edge_label_dict = process_edges(batch, device, rank)

        raw_pred, raw_pred_rest = model(batch.x_dict, batch.edge_index_dict, edge_label_index_dict, batch.edge_attr_dict)

        tar = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' in k or 'rev_edge_ab' in k], dim=0)
        tar_rest = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' not in k and 'rev_edge_ab' not in k], dim=0)

        tar = tar.to(device=device)
        tar_rest = tar_rest.to(device=device)

        loss_gt = criterion(torch.nn.functional.relu(raw_pred), tar)
        loss_rest = criterion(torch.nn.functional.relu(raw_pred_rest), tar_rest)

        lmda = 0.99
        
        loss = lmda * loss_gt + (1 - lmda) * loss_rest
        loss_val = loss.item()

        # Empty edge 
        if skip:
            loss_val = 0

        total_examples += batch_size
        total_loss += float(loss_val) * batch_size

        # Compute ROC AUC for this batch and add to sum
        preds = torch.sigmoid(raw_pred.detach()).cpu().numpy()
        tars = tar.cpu().numpy().astype(int)

        if len(preds) > 0 and len(tars) > 0 and skip is False: 
            roc_auc_sum += roc_auc_score(tars, preds)
            roc_auc_count += 1

        if batch_run == True:
            break

    # Compute the average ROC AUC
    roc_auc_avg = roc_auc_sum / roc_auc_count if roc_auc_count > 0 else 0.0
    
    loss_val = total_loss / total_examples
    
    return loss_val, roc_auc_avg

def setup():

    # initialize the process group
    dist.init_process_group("nccl")

def cleanup():
    dist.destroy_process_group()

@record
def run(new_data, epochs):

    rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    set_seed()

    free_gpu_cache(print_gpu=True, rank=rank)

    print(f"Running basic DDP example on rank {rank}.")
    setup()

    volume_mount_dir = ''
    dataset_path = os.path.join(volume_mount_dir, 'datasets/graph_data_v2.pt')
    checkpoint_path = os.path.join(volume_mount_dir, 'checkpoints/')

    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')

    print("Device is ", device)
        
    if rank == 0 and new_data==True:
        # Only process the data in the first process.
        train_data, val_data, test_data = data_process(dataset_path, rank)

        torch.save((train_data, val_data, test_data), 'datasets/processed_data.pt')

        print("data sets saved")
        
        # Create data loaders on rank 0 for testing or other purposes, but they won't be saved or shared.
        train_loader, val_loader, test_loader = data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size)
        train_batch = next(iter(train_loader)) 
        train_batch = train_batch.to(device=device) 

        torch.save((train_batch), 'datasets/batch_data.pt')

        print("data saved")

    dist.barrier()

    if rank != 0 or new_data==False:

        train_data, val_data, test_data = torch.load('datasets/processed_data.pt')

        free_gpu_cache(print_gpu=True, rank=rank)

        train_batch = torch.load('datasets/batch_data.pt')

        print("batch data loaded")

        free_gpu_cache(print_gpu=True, rank=rank)

    dist.barrier()

    if rank == 0:
        print("all data loaded")

    # Recreate the data loaders in each process.
    train_loader, val_loader, test_loader = data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size)

    if rank == 0:
        print("dataloaders created")
    
    free_gpu_cache(print_gpu=True, rank=rank)

    train_batch = train_batch.to(rank)

    edge_label_index_dict, _ = process_edges(train_batch, device, rank)
    
    if rank == 0:
        print('edge dictionaries created')

    free_gpu_cache(print_gpu=True, rank=rank)

    dist.barrier()

    print(f"# Train Batches Rank {rank}: {len(train_loader)}\n")
    print(f"# Val Batches: Rank {rank}: {len(val_loader)}\n")
    print(f"# Test Batches: Rank {rank}: {len(test_loader)}\n")

    for i in range(len(aws_dict['heads'])):

        run_name = f'AWS_RUN_{i}'
        print(f'{run_name} on rank {rank}\n')

        if rank == 0:

            wandb.init(name=run_name,
            notes='aws dist',
            project="gene_disease_dist_aws_manual_v2", 
            entity="gurugecl")

        num_lin = aws_dict['num_lin'][i]
        num_conv = aws_dict['num_conv'][i]
        heads = aws_dict['heads'][i]
        decode_channels = aws_dict['decode_channels'][i]
        hidden_channels = aws_dict['hidden_channels'][i]
        output_channels = aws_dict['output_channels'][i]

        if rank==0:
            print("num_lin: ", num_lin)
            print("num_conv: ", num_conv)
            print("heads: ", heads)
            print("hidden_channels: ", hidden_channels)
            print("decode_channels: ", decode_channels)
            print("output_channels: ", output_channels)


        model_filename = f"completed_model_AWS_RUN_{i}.pt"  # change this to the name of the model you want to load

        model_file_path = os.path.join(checkpoint_path, model_filename)

        if os.path.isfile(model_file_path):
            print("EXISTING MODEL")
            model, start_epoch = load_checkpoint_model(model_file_path, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, train_batch.metadata(), device)
            model = model.to(device)
        else:
            if rank == 0:
                print("NEW MODEL")
            start_epoch = 0
            model = Model(hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, train_batch.metadata()).to(device) 
                
        if rank == 0:
            print('model loaded')

        # Perform dummy forward pass for correct initialization
        with torch.no_grad():
            model.forward(train_batch.x_dict, train_batch.edge_index_dict, edge_label_index_dict, train_batch.edge_attr_dict)
        if rank == 0:
            print("dummy forward")

        dist.barrier()

        # Apply DDP
        model = DDP(model, device_ids=[rank], output_device=rank) #, gradient_as_bucket_view=True)
        if rank == 0:
            print("DDP")

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = torch.nn.BCEWithLogitsLoss()
        # compute loss inside of DataParallel due to imbalance
        if torch.cuda.device_count() > 1:
            criterion = CriterionParallel(criterion)

        torch.autograd.set_detect_anomaly(True)

        start_time = time.time()

        dist.barrier()

        for epoch in np.arange(start_epoch, epochs): #10
            print(f'rank {rank} epoch {epoch}\n')
            print(f'train rank {rank}\n')
            if rank==0:
                epoch_start_time = time.time()
            train_loss, train_roc = train(train_loader, model, optimizer, criterion, batch_size, rank, device, batch_run=False)
            print(f'val rank {rank}\n')
            val_loss, val_roc = test(val_loader, model, criterion, batch_size, rank, device, batch_run=False)
            print(f'test rank {rank}\n')
            test_loss, test_roc = test(test_loader, model, criterion, batch_size, rank, device, batch_run=True)
            
            comb_roc = train_roc * (1/3) + val_roc * (2/3)  

            print(f'epoch {epoch} complete on rank {rank}\n')      
        
            if rank == 0:
                epoch_end_time = time.time()

                print(f'Epoch: {epoch:03d}, LOSS: (Train: {train_loss:.4f}, Val: {val_loss:.4f}, Test: {test_loss:.4f}), ROC: (Train: {train_roc:.4f}, '
                f'Val: {val_roc:.4f}, Test: {test_roc:.4f}, Comb: {comb_roc:.4f})')

                epoch_time = epoch_end_time - epoch_start_time

                wandb.log({"Epoch": epoch,        
                "Train Loss":train_loss,        
                "Train ROC": train_roc,        
                "Val Loss": val_loss,        
                "Val ROC": val_roc,
                "Test Loss": test_loss,        
                "Test ROC": test_roc,
                "Comb_ROC": comb_roc,
                "Epoch Time": epoch_time})

                torch.save(model.state_dict(), f'checkpoints/checkpoint_{run_name}_epoch_{epoch}.pt')
                print("model saved")

                print(f'Run: {run_name} Epoch: {epoch} Execution time: {epoch_time} seconds')

                with open(f'checkpoints/time_{run_name}.txt', 'a') as f:  # Open file in append mode
                    f.write(f'Run: {run_name} Epoch: {epoch} Time: {epoch_time} seconds\n')

            dist.barrier() # for each epoch
            
            free_gpu_cache(print_gpu=True, rank=rank)
                        
        if rank == 0:

            end_time = time.time()

            total_time = end_time - start_time  # Calculate execution time

            print(f'Total execution time: {total_time} seconds')

            with open(f'checkpoints/time_{run_name}.txt', 'a') as f:  # Open file in append mode
                f.write(f'Run: {run_name} Total Time: {total_time} seconds\n') 

            torch.save(model, f'checkpoints/completed_model_{run_name}.pt')

            wandb.finish()

        dist.barrier() # for each aws config

    print(f"Clean up run on rank: {rank}", flush=True)

    cleanup()

    print(f"Clean up completed on rank {rank}", flush=True)


if __name__ == '__main__':

    world_size = torch.cuda.device_count()
    print('Using', world_size, 'GPUs!')

    run(new_data, epochs)

    if instance_stop:
        print("STOPPING INSTANCE")
        instance_id = get_instance_id()
        stop_instance(instance_id)

Have you tried setting NCCL_DESYNC_DEBUG to 1 and NCCL_DEBUG_SUBSYS=COLL?
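For reference, these can go right next to the NCCL_DEBUG line you already set, so they are in the environment before init_process_group is called (a minimal sketch):

import os

# assumes the same placement as the existing NCCL_DEBUG line, before dist.init_process_group()
os.environ['NCCL_DEBUG'] = 'INFO'
os.environ['NCCL_DEBUG_SUBSYS'] = 'COLL'  # log each collective call
os.environ['NCCL_DESYNC_DEBUG'] = '1'     # report which ranks desynchronized when a timeout fires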

Also are you only using 4 GPUs?

Yes, in this case I am using 4, but I have tried with 8 as well. I added the settings you suggested; the output is below. Unfortunately it still just reports the timeout and I don't see the specific root cause, but please let me know what you think.

172-46-5-209:2564:2543 [1] NCCL INFO AllReduce: opCount 106cf sendbuff 0x7f847bd46600 recvbuff 0x7f847bd46600 count 9475072 datatype 7 op 0 root 0 comm 0x560187d88af0 [nranks=4] stream 0x560187b2b130
172-46-5-209:2564:2543 [1] NCCL INFO AllReduce: opCount 106d0 sendbuff 0x7f847e16b600 recvbuff 0x7f847e16b600 count 9650688 datatype 7 op 0 root 0 comm 0x560187d88af0 [nranks=4] stream 0x560187b2b130
172-46-5-209:2564:2543 [1] NCCL INFO AllReduce: opCount 106d1 sendbuff 0x7f848063be00 recvbuff 0x7f848063be00 count 10780800 datatype 7 op 0 root 0 comm 0x560187d88af0 [nranks=4] stream 0x560187b2b130
172-46-5-209:2564:2543 [1] NCCL INFO AllReduce: opCount 106d2 sendbuff 0x7f8482f5c000 recvbuff 0x7f8482f5c000 count 9481600 datatype 7 op 0 root 0 comm 0x560187d88af0 [nranks=4] stream 0x560187b2b130
test rank 3

epoch 1 complete on rank 3

ip-172-31-4-202:2634:2634 [3] NCCL INFO AllReduce: opCount 106b5 sendbuff 0x7f8ca78a2c00 recvbuff 0x7f8ca78a2c00 count 1 datatype 1 op 0 root 0 comm 0x5616d5917f40 [nranks=4] stream 0x5616d4639a90
Memory Usage On Rank  3
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 80% |
|  1 | 100% | 63% |
|  2 | 100% | 67% |
|  3 |  46% | 52% |
RAM Used (GB): 182.490628096
ip-172-31-4-202:2634:2634 [3] NCCL INFO AllReduce: opCount 106b6 sendbuff 0x7f8ca78a2c00 recvbuff 0x7f8ca78a2c00 count 1 datatype 1 op 0 root 0 comm 0x5616d5917f40 [nranks=4] stream 0x5616d4639a90
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403363 milliseconds before timing out.
	 - [2] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403465 milliseconds before timing out.
	 - [0] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403280 milliseconds before timing out.
	 - [1] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
Clean up run on rank: 3
Clean up completed on rank 3
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
terminate called after throwing an instance of 'std::runtime_error'
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403363 milliseconds before timing out.
	 - [2] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
  what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403280 milliseconds before timing out.
	 - [1] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
  what():  [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5403465 milliseconds before timing out.
	 - [0] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE

Fatal Python error: Aborted

(The 'Fatal Python error: Aborted' output and per-thread tracebacks from the three timed-out ranks are interleaved at this point; a readable copy of the thread tracebacks from one rank follows.)

Thread 0x00007fec0f7fe700 (most recent call first):
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/lib/sock_client.py", line 255 in _read_packet_bytes
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/lib/sock_client.py", line 285 in read_server_response
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/interface/router_sock.py", line 27 in _read_message
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/interface/router.py", line 70 in message_loop
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 975 in run
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 1038 in _bootstrap_inner
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 995 in _bootstrap

Thread 0x00007fed25788740 (most recent call first):
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 45 in index_select
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 67 in filter_node_store_
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 136 in filter_hetero_data
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/node_loader.py", line 154 in filter_fn
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/base.py", line 36 in __next__
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py", line 427 in train
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py"
Extension modules: yaml._yaml, line 855 in run
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py"
Extension modules: , line yaml._yaml346 in wrapper
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py", line 939 in <module>
, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.core._multiarray_umath, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.core._multiarray_tests, numpy.random._common, numpy.random.bit_generator, numpy.linalg._umath_linalg, numpy.random._bounded_integers, numpy.fft._pocketfft_internal, numpy.random._mt19937, numpy.random._common, numpy.random.mtrand, numpy.random.bit_generator, numpy.random._philox, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._mt19937, numpy.random._sfc64, numpy.random.mtrand, numpy.random._generator, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, , google._upb._messagenumpy.random._generator, , psutil._psutil_linuxgoogle._upb._message, psutil._psutil_posix, psutil._psutil_linux, torch._C, psutil._psutil_posix, torch._C._fft, torch._C._linalg, torch._C, torch._C._nested, , 
Extension modules: torch._C._ffttorch._C._nnyaml._yaml, , torch._C._linalgtorch._C._sparse, , torch._C._nestedtorch._C._special, torch._C._nn, torch._C._sparse, torch._C._special, numpy.core._multiarray_umath, , numpy.core._multiarray_testsgmpy2.gmpy2, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, gmpy2.gmpy2, numpy.random._common, numpy.random.bit_generator, scipy._lib._ccallback_c, numpy.random._bounded_integers, scipy.sparse._sparsetools, numpy.random._mt19937, , _csparsetoolsnumpy.random.mtrand, scipy._lib._ccallback_c, , scipy.sparse._csparsetoolsnumpy.random._philox, scipy.sparse._sparsetools, , scipy.sparse.linalg._isolve._iterativenumpy.random._pcg64, _csparsetools, , scipy.linalg._fblasnumpy.random._sfc64, scipy.sparse._csparsetools, , scipy.linalg._flapacknumpy.random._generator, scipy.sparse.linalg._isolve._iterative, scipy.linalg._cythonized_array_utils, scipy.linalg._fblas, scipy.linalg._flinalg, scipy.linalg._flapack, google._upb._message, scipy.linalg._solve_toeplitz, scipy.linalg._cythonized_array_utils, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._flinalg, scipy.linalg.cython_lapack, scipy.linalg._solve_toeplitz, scipy.linalg.cython_blas, , scipy.linalg._matfuncs_sqrtm_triupsutil._psutil_linux, scipy.linalg._matfuncs_expm, scipy.linalg.cython_lapack, psutil._psutil_posix, scipy.linalg._decomp_update, scipy.linalg.cython_blas, scipy.sparse.linalg._dsolve._superlu, scipy.linalg._matfuncs_expm, , scipy.sparse.linalg._eigen.arpack._arpack, torch._Cscipy.linalg._decomp_update, scipy.sparse.csgraph._tools, , torch._C._fftscipy.sparse.linalg._dsolve._superlu, scipy.sparse.csgraph._shortest_path, , torch._C._linalgscipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.csgraph._traversal, torch._C._nested, , scipy.sparse.csgraph._toolsscipy.sparse.csgraph._min_spanning_tree, torch._C._nn, , scipy.sparse.csgraph._shortest_pathscipy.sparse.csgraph._flow, torch._C._sparse, , scipy.sparse.csgraph._traversalscipy.sparse.csgraph._matching, torch._C._special, , scipy.sparse.csgraph._min_spanning_treescipy.sparse.csgraph._reordering, scipy.sparse.csgraph._flow, scipy.spatial._ckdtree, scipy.sparse.csgraph._matching, scipy._lib.messagestream, scipy.sparse.csgraph._reordering, scipy.spatial._qhull, scipy.spatial._ckdtree, scipy.spatial._voronoi, scipy._lib.messagestream, , scipy.spatial._distance_wrapgmpy2.gmpy2, scipy.spatial._qhull, scipy.spatial._hausdorff, scipy.spatial._voronoi, scipy.special._ufuncs_cxx, scipy.spatial._distance_wrap, scipy.special._ufuncs, scipy.spatial._hausdorff, scipy.special._specfun, scipy.special._ufuncs_cxx, scipy.special._comb, , scipy._lib._ccallback_cscipy.special._ufuncs, scipy.special._ellip_harm_2, , , scipy.sparse._sparsetoolsscipy.special._specfunscipy.spatial.transform._rotation, , , scipy.special._combscipy.cluster._vq_csparsetools, , scipy.special._ellip_harm_2scipy.cluster._hierarchy, scipy.sparse._csparsetools, , scipy.spatial.transform._rotationscipy.cluster._optimal_leaf_ordering, scipy.sparse.linalg._isolve._iterative, scipy.cluster._vq, scipy.linalg._fblas, scipy.cluster._hierarchy, scipy.linalg._flapack, scipy.cluster._optimal_leaf_ordering, scipy.linalg._cythonized_array_utils, scipy.linalg._flinalg, scipy.linalg._solve_toeplitz, sklearn.__check_build._check_build, scipy.linalg._matfuncs_sqrtm_triu, sklearn.utils.murmurhash, scipy.linalg.cython_lapack, sklearn.__check_build._check_build, , scipy.linalg.cython_blasnumpy.linalg.lapack_lite, , , sklearn.utils.murmurhashscipy.linalg._matfuncs_expmscipy.ndimage._nd_image, 
scipy.linalg._decomp_update, , numpy.linalg.lapack_lite_ni_label, scipy.sparse.linalg._dsolve._superlu, , , scipy.ndimage._ni_labelscipy.sparse.linalg._eigen.arpack._arpackscipy.ndimage._nd_image, , scipy.optimize._minpack2scipy.sparse.csgraph._tools, _ni_label, , scipy.optimize._group_columnsscipy.sparse.csgraph._shortest_path, scipy.ndimage._ni_label, , scipy.optimize._trlib._trlibscipy.sparse.csgraph._traversal, scipy.optimize._minpack2, , scipy.optimize._lbfgsbscipy.sparse.csgraph._min_spanning_tree, scipy.optimize._group_columns, , scipy.sparse.csgraph._flow, _moduleTNCscipy.optimize._trlib._trlib, scipy.sparse.csgraph._matching, , scipy.optimize._moduleTNCscipy.optimize._lbfgsb, scipy.sparse.csgraph._reordering, scipy.optimize._cobyla, _moduleTNC, , scipy.spatial._ckdtreescipy.optimize._slsqp, scipy.optimize._moduleTNC, , scipy._lib.messagestream, scipy.optimize._minpackscipy.optimize._cobyla, scipy.spatial._qhull, , scipy.optimize._slsqpscipy.optimize._lsq.givens_elimination, scipy.spatial._voronoi, scipy.optimize._minpack, , scipy.optimize._zerosscipy.spatial._distance_wrap, scipy.optimize._lsq.givens_elimination, , scipy.optimize.__nnlsscipy.spatial._hausdorff, scipy.optimize._zeros, , scipy.optimize._highs.cython.src._highs_wrapperscipy.special._ufuncs_cxx, scipy.optimize.__nnls, , scipy.optimize._highs._highs_wrapperscipy.special._ufuncs, scipy.optimize._highs.cython.src._highs_wrapper, , scipy.optimize._highs.cython.src._highs_constants, scipy.special._specfunscipy.optimize._highs._highs_wrapper, scipy.optimize._highs._highs_constants, , scipy.special._combscipy.optimize._highs.cython.src._highs_constants, scipy.linalg._interpolative, , scipy.special._ellip_harm_2scipy.optimize._highs._highs_constants, scipy.optimize._bglu_dense, , scipy.spatial.transform._rotationscipy.linalg._interpolative, scipy.optimize._lsap, scipy.cluster._vq, scipy.optimize._bglu_dense, , scipy.cluster._hierarchyscipy.optimize._direct, scipy.optimize._lsap, scipy.cluster._optimal_leaf_ordering, scipy.integrate._odepack, scipy.optimize._direct, scipy.integrate._quadpack, scipy.integrate._odepack, scipy.integrate._vode, scipy.integrate._quadpack, scipy.integrate._dop, scipy.integrate._vode, scipy.integrate._lsoda, scipy.integrate._dop, sklearn.__check_build._check_build, , scipy.special.cython_specialscipy.integrate._lsoda, sklearn.utils.murmurhash, scipy.stats._stats, scipy.special.cython_special, , scipy.stats.beta_ufunc, numpy.linalg.lapack_litescipy.stats._stats, scipy.stats._boost.beta_ufunc, , scipy.ndimage._nd_imagescipy.stats.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.beta_ufunc, , _ni_labelscipy.stats._boost.binom_ufunc, scipy.stats.binom_ufunc, , scipy.ndimage._ni_labelscipy.stats.nbinom_ufunc, scipy.stats._boost.binom_ufunc, , scipy.stats._boost.nbinom_ufunc, scipy.optimize._minpack2scipy.stats.nbinom_ufunc, scipy.stats.hypergeom_ufunc, , scipy.optimize._group_columnsscipy.stats._boost.nbinom_ufunc, scipy.stats._boost.hypergeom_ufunc, , scipy.optimize._trlib._trlibscipy.stats.hypergeom_ufunc, scipy.stats.ncf_ufunc, , scipy.optimize._lbfgsbscipy.stats._boost.hypergeom_ufunc, scipy.stats._boost.ncf_ufunc, , scipy.stats.ncf_ufunc, _moduleTNCscipy.stats.ncx2_ufunc, scipy.stats._boost.ncf_ufunc, , scipy.optimize._moduleTNCscipy.stats._boost.ncx2_ufunc, scipy.stats.ncx2_ufunc, scipy.optimize._cobyla, scipy.stats.nct_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.optimize._slsqp, scipy.stats._boost.nct_ufunc, scipy.stats.nct_ufunc, scipy.optimize._minpack, scipy.stats.skewnorm_ufunc, 
scipy.stats._boost.nct_ufunc, scipy.optimize._lsq.givens_elimination, scipy.stats._boost.skewnorm_ufunc, scipy.stats.skewnorm_ufunc, scipy.optimize._zeros, , scipy.stats.invgauss_ufuncscipy.stats._boost.skewnorm_ufunc, scipy.optimize.__nnls, scipy.stats._boost.invgauss_ufunc, scipy.stats.invgauss_ufunc, scipy.optimize._highs.cython.src._highs_wrapper, , scipy.interpolate._fitpackscipy.stats._boost.invgauss_ufunc, scipy.optimize._highs._highs_wrapper, scipy.interpolate.dfitpack, , scipy.interpolate._fitpackscipy.optimize._highs.cython.src._highs_constants, scipy.interpolate._bspl, , scipy.interpolate.dfitpackscipy.optimize._highs._highs_constants, scipy.interpolate._ppoly, , scipy.interpolate._bsplscipy.linalg._interpolative, scipy.interpolate.interpnd, scipy.interpolate._ppoly, scipy.optimize._bglu_dense, scipy.interpolate._rbfinterp_pythran, scipy.interpolate.interpnd, scipy.optimize._lsap, scipy.interpolate._rgi_cython, scipy.interpolate._rbfinterp_pythran, scipy.optimize._direct, , scipy.stats._biasedurnscipy.interpolate._rgi_cython, scipy.integrate._odepack, , scipy.stats._levy_stable.levystscipy.stats._biasedurn, scipy.integrate._quadpack, , scipy.stats._stats_pythran, scipy.stats._levy_stable.levystscipy.integrate._vode, scipy._lib._uarray._uarray, , scipy.stats._stats_pythranscipy.integrate._dop, , , scipy.stats._statlibscipy.integrate._lsodascipy._lib._uarray._uarray, scipy.stats._mvn, , scipy.special.cython_specialscipy.stats._statlib, scipy.stats._sobol, scipy.stats._stats, , scipy.stats._mvnscipy.stats._qmc_cy, scipy.stats.beta_ufunc, , scipy.stats._sobolscipy.stats._rcont.rcont, scipy.stats._boost.beta_ufunc, scipy.stats._qmc_cy, sklearn.utils._isfinite, scipy.stats.binom_ufunc, scipy.stats._rcont.rcont, , sklearn.utils._openmp_helpersscipy.stats._boost.binom_ufunc, sklearn.utils._isfinite, , sklearn.utils._logistic_sigmoidscipy.stats.nbinom_ufunc, sklearn.utils._openmp_helpers, , sklearn.utils.sparsefuncs_fastscipy.stats._boost.nbinom_ufunc, sklearn.utils._logistic_sigmoid, , sklearn.preprocessing._csr_polynomial_expansion, scipy.stats.hypergeom_ufuncsklearn.utils.sparsefuncs_fast, sklearn.utils._typedefs, , scipy.stats._boost.hypergeom_ufuncsklearn.preprocessing._csr_polynomial_expansion, sklearn.utils._readonly_array_wrapper, , scipy.stats.ncf_ufuncsklearn.utils._typedefs, sklearn.metrics._dist_metrics, , scipy.stats._boost.ncf_ufuncsklearn.utils._readonly_array_wrapper, sklearn.metrics.cluster._expected_mutual_info_fast, , scipy.stats.ncx2_ufunc, sklearn.metrics._dist_metricssklearn.metrics._pairwise_distances_reduction._datasets_pair, scipy.stats._boost.ncx2_ufunc, , sklearn.metrics.cluster._expected_mutual_info_fastsklearn.utils._cython_blas, scipy.stats.nct_ufunc, , sklearn.metrics._pairwise_distances_reduction._datasets_pairsklearn.metrics._pairwise_distances_reduction._base, scipy.stats._boost.nct_ufunc, , sklearn.utils._cython_blassklearn.metrics._pairwise_distances_reduction._middle_term_computer, scipy.stats.skewnorm_ufunc, sklearn.metrics._pairwise_distances_reduction._base, , sklearn.utils._heapscipy.stats._boost.skewnorm_ufunc, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, , sklearn.utils._sortingscipy.stats.invgauss_ufunc, sklearn.utils._heap, , sklearn.metrics._pairwise_distances_reduction._argkminscipy.stats._boost.invgauss_ufunc, sklearn.utils._sorting, , sklearn.utils._vector_sentinelscipy.interpolate._fitpack, sklearn.metrics._pairwise_distances_reduction._argkmin, , sklearn.metrics._pairwise_distances_reduction._radius_neighbors, 
scipy.interpolate.dfitpacksklearn.utils._vector_sentinel, sklearn.metrics._pairwise_fast, , scipy.interpolate._bsplsklearn.metrics._pairwise_distances_reduction._radius_neighbors, , , scipy.interpolate._ppolysklearn.metrics._pairwise_fastnumba.core.typeconv._typeconv, scipy.interpolate.interpnd, numba.core.typeconv._typeconv, numba._helperlib, scipy.interpolate._rbfinterp_pythran, numba._helperlib, numba._dynfunc, scipy.interpolate._rgi_cython, numba._dynfunc, numba._dispatcher, scipy.stats._biasedurn, numba._dispatcher, , scipy.stats._levy_stable.levystnumba.core.runtime._nrt_python, numba.core.runtime._nrt_python, scipy.stats._stats_pythran, , numba.np.ufunc._internalnumba.np.ufunc._internal, scipy._lib._uarray._uarray, , numba.experimental.jitclass._boxnumba.experimental.jitclass._box, scipy.stats._statlib, , numba.mviewbufnumba.mviewbuf, scipy.stats._mvn, , numba.types.itertools, numba.types.itertoolsscipy.stats._sobol, scipy.stats._qmc_cy (total: 159) (total: 
159, )scipy.stats._rcont.rcont
, sklearn.utils._isfinite, sklearn.utils._openmp_helpers, sklearn.utils._logistic_sigmoid, sklearn.utils.sparsefuncs_fast, sklearn.preprocessing._csr_polynomial_expansion, sklearn.utils._typedefs, sklearn.utils._readonly_array_wrapper, sklearn.metrics._dist_metrics, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_fast, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, numba.mviewbuf, numba.types.itertools (total: 159)
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 2631) of binary: /home/ubuntu/miniconda/envs/aws/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/miniconda/envs/aws/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
=====================================================
Neo4j_Gene_Trait_LP_V2_awsII.py FAILED
-----------------------------------------------------
Failures:
[1]:
  time      : 2023-07-31_23:31:14
  host      : ip-172-31-4-202
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 2632)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2632
[2]:
  time      : 2023-07-31_23:31:14
  host      : ip-172-31-4-202
  rank      : 2 (local_rank: 2)
  exitcode  : -6 (pid: 2633)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2633
-----------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-07-31_23:31:14
  host      : ip-172-31-4-202
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 2631)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2631
=====================================================


Are you using only two processes here? Because usually you need to create one process for each GPU.

No, I'm using 4, one for each GPU, but ranks 0, 1, and 2 are failing because they time out while only rank 3 finishes.

Got it. Your model is a little complicated, so I have some questions:

  1. Is the graph on each rank the same, or is it somehow different?
  2. Can you also try with export TORCH_DISTRIBUTED_DEBUG=INFO?

For DDP, the all-reduce happens when we synchronize the gradients in the backward pass. From the log it looks like ranks 0-2 all joined the all-reduce but somehow timed out; maybe the input shapes are not equal? That's why I am asking whether the graph is the same or not.
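
To make that concrete, here is a minimal sketch of the failure mode (hypothetical names, not your actual script): every loss.backward() on a DDP model launches a gradient all-reduce, so if one rank runs fewer iterations than the others, the remaining ranks block on an all-reduce that is never joined until the NCCL watchdog times out.

for step, batch in enumerate(train_loader):
    optimizer.zero_grad()
    out = ddp_model(batch.x_dict, batch.edge_index_dict)  # forward on this rank's batch
    loss = compute_loss(out, batch)                        # hypothetical loss computation
    loss.backward()                                        # <-- gradient all-reduce across ranks happens here
    optimizer.step()
# If another rank exits this loop one iteration earlier, the backward() of the
# extra iteration on the remaining ranks waits forever for its contribution.

DistributedDataParallel also has a join() context manager (with ddp_model.join(): ...) intended for uneven inputs, but making the per-rank iteration counts equal is usually the simpler fix.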

  1. No, they are all copies of the same graph.
  2. Sure, some of the log output is at the end.

Well, the number of batches in the dataloaders is slightly different on each rank, as you can see from the logs:

# Train Batches Rank 0: 1052
# Val Batches: Rank 0: 1052
# Test Batches: Rank 0: 1052

# Train Batches Rank 1: 1052
# Val Batches: Rank 1: 1052
# Test Batches: Rank 1: 1052

# Train Batches Rank 2: 1052
# Val Batches: Rank 2: 1052
# Test Batches: Rank 2: 1052

# Train Batches Rank 3: 1051
# Val Batches: Rank 3: 1051
# Test Batches: Rank 3: 1051

AWS_RUN_0 on rank 1
AWS_RUN_0 on rank 2
AWS_RUN_0 on rank 3

We use this chunking method before the dataloader because the PyG sampler was not working for a heterogeneous graph, so some loaders end up with an extra batch:

# Split the 'gene' node indices across ranks (the last chunk may be shorter).
train_idx = torch.chunk(torch.arange(train_data['gene'].num_nodes, dtype=torch.long), world_size)[rank]
val_idx = torch.chunk(torch.arange(val_data['gene'].num_nodes, dtype=torch.long), world_size)[rank]
test_idx = torch.chunk(torch.arange(test_data['gene'].num_nodes, dtype=torch.long), world_size)[rank]

dist_batch_size = batch_size // world_size
num_workers = 0  # 4 * world_size
persistent_workers = False

print('indices created')

train_loader = NeighborLoader(
    train_data,
    batch_size=dist_batch_size,
    shuffle=True,
    directed=False,
    num_neighbors=num_neighbors,
    input_nodes=('gene', train_idx),
    num_workers=num_workers,
    persistent_workers=persistent_workers,
    drop_last=True,
)
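
The off-by-one comes from torch.chunk itself: when the number of input nodes is not divisible by world_size, the last chunk is shorter, so that rank's loader yields one fewer batch. A toy illustration (made-up sizes, not our real node counts):

import torch

idx = torch.arange(10, dtype=torch.long)      # pretend these are the 'gene' node indices
print([len(c) for c in torch.chunk(idx, 4)])  # [3, 3, 3, 1] -> the last rank gets a shorter chunk
# With the same dist_batch_size and drop_last=True on every rank, the shorter
# chunk produces fewer batches on that rank (1051 vs 1052 in the logs above).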

Here are some of the logs

/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/sampler/neighbor_sampler.py:50: UserWarning: Using '{self.__class__.__name__}' without a 'pyg-lib' installation is deprecated and will be removed soon. Please install 'pyg-lib' for accelerated neighborhood sampling
  warnings.warn("Using '{self.__class__.__name__}' without a "

rank 1 epoch 1

train rank 1

| ID | GPU | MEM |
------------------
|  0 |  7% | 52% |
|  1 | 29% | 42% |
|  2 | 80% | 70% |
|  3 | 69% | 52% |
RAM Used (GB): 182.646919168
rank 0 epoch 1

train rank 0

training model
Memory Usage On Rank  2
batch 0/1052
| ID | GPU | MEM |
------------------
|  0 |  7% | 52% |
|  1 | 29% | 42% |
|  2 | 51% | 41% |
|  3 |  0% | 52% |
RAM Used (GB): 183.017709568
rank 2 epoch 1

train rank 2

batch 1000/1052
val rank 3

test rank 3

epoch 1 complete on rank 3

Memory Usage On Rank  3
| ID | GPU  | MEM |
-------------------
|  0 | 100% | 80% |
|  1 | 100% | 63% |
|  2 | 100% | 67% |
|  3 |  49% | 52% |
RAM Used (GB): 182.39911936
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400537 milliseconds before timing out.
	 - [0] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400292 milliseconds before timing out.
	 - [1] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400393 milliseconds before timing out.
	 - [2] Timeout at collective: ALLREDUCE, #67254
	 - To our best knowledge, the lagging/dead/mismatched ranks that caused the desync are:
	   - [0, 1, 2] joined but didn't finish collective #67254 (count from 1)
	 - Snapshot of ranks' latest states:
	   #67254 started ranks:
	     [0, 1, 2] started ALLREDUCE
	   #67255 started ranks:
	     [3] started ALLREDUCE
Clean up run on rank: 3
Clean up completed on rank 3
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400537 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400292 milliseconds before timing out.
terminate called after throwing an instance of 'std::runtime_error'
  what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=5400000) ran for 5400393 milliseconds before timing out.
Fatal Python error: Aborted

(per-thread tracebacks from the three aborting ranks were interleaved here; a readable copy from one rank follows)

Thread 0x00007f7ea3fff700 (most recent call first):
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/lib/sock_client.py", line 255 in _read_packet_bytes
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/lib/sock_client.py", line 285 in read_server_response
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/interface/router_sock.py", line 27 in _read_message
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/wandb/sdk/interface/router.py", line 70 in message_loop
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 975 in run
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 1038 in _bootstrap_inner
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/threading.py", line 995 in _bootstrap

Thread 0x00007f7fb9b5c740 (most recent call first):
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 45 in index_select
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 67 in filter_node_store_
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/utils.py", line 136 in filter_hetero_data
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/node_loader.py", line 154 in filter_fn
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch_geometric/loader/base.py", line 36 in __next__
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py", line 428 in train
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py", line 856 in run
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346 in wrapper
  File "/dltraining/Neo4j_Gene_Trait_LP_V2_awsII.py", line 940 in <module>

(the 'Extension modules' lists from the aborting ranks were interleaved here; a readable copy follows)
, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.special.cython_special, scipy.stats._stats, scipy.stats.beta_ufunc, scipy.stats._boost.beta_ufunc, scipy.stats.binom_ufunc, scipy.stats._boost.binom_ufunc, scipy.stats.nbinom_ufunc, scipy.stats._boost.nbinom_ufunc, scipy.stats.hypergeom_ufunc, scipy.stats._boost.hypergeom_ufunc, scipy.stats.ncf_ufunc, scipy.stats._boost.ncf_ufunc, scipy.stats.ncx2_ufunc, scipy.stats._boost.ncx2_ufunc, scipy.stats.nct_ufunc, scipy.stats._boost.nct_ufunc, scipy.stats.skewnorm_ufunc, scipy.stats._boost.skewnorm_ufunc, scipy.stats.invgauss_ufunc, scipy.stats._boost.invgauss_ufunc, scipy.interpolate._fitpack, scipy.interpolate.dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._statlib, scipy.stats._mvn, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, sklearn.utils._isfinite, sklearn.utils._openmp_helpers, sklearn.utils._logistic_sigmoid, sklearn.utils.sparsefuncs_fast, sklearn.preprocessing._csr_polynomial_expansion, sklearn.utils._typedefs, sklearn.utils._readonly_array_wrapper, sklearn.metrics._dist_metrics, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_fast, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, numba.mviewbuf, numba.types.itertools (total: 159)
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 2566) of binary: /home/ubuntu/miniconda/envs/aws/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/miniconda/envs/aws/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch==2.0.1', 'console_scripts', 'torchrun')())
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/miniconda/envs/aws/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
=====================================================
Neo4j_Gene_Trait_LP_V2_awsII.py FAILED
-----------------------------------------------------
Failures:
[1]:
  time      : 2023-08-01_23:19:23
  host      : ip-172-31-4-202
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 2567)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2567
[2]:
  time      : 2023-08-01_23:19:23
  host      : ip-172-31-4-202
  rank      : 2 (local_rank: 2)
  exitcode  : -6 (pid: 2568)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2568
-----------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-08-01_23:19:23
  host      : ip-172-31-4-202
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 2566)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 2566
=====================================================

so some loaders have an extra batch

When you say extra batch, does this mean you will have two FWD passes instead of one?

Basically that means some of the ranks have to run an extra iteration, which seems to be what caused the timeouts and synchronization problems. I had read that a minor discrepancy in batch sizes shouldn't be a problem, but that doesn't appear to be the case. Once I ensured that all ranks had the same number of batches, they all completed. Thank you very much for your help with this!
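
For anyone hitting the same thing, this is roughly what I did (a simplified sketch with a hypothetical helper name, not my exact script): truncate every rank's index chunk to the length of the shortest chunk before building the loaders, so that with the same batch size and drop_last=True every rank sees the same number of batches.

import torch

def equal_length_chunk(num_nodes, world_size, rank):
    # Hypothetical helper: chunk the node indices as before, then truncate every
    # chunk to the shortest one so all ranks iterate the same number of batches.
    chunks = torch.chunk(torch.arange(num_nodes, dtype=torch.long), world_size)
    min_len = min(len(c) for c in chunks)
    return chunks[rank][:min_len]

train_idx = equal_length_chunk(train_data['gene'].num_nodes, world_size, rank)
val_idx = equal_length_chunk(val_data['gene'].num_nodes, world_size, rank)
test_idx = equal_length_chunk(test_data['gene'].num_nodes, world_size, rank)
# These then go into NeighborLoader via input_nodes=('gene', train_idx) as above.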


I seem to be running into a similar problem, where different batch counts between ranks are causing some ranks to time out. How do you ensure that the number of batches is the same across ranks?