Hi, I am attempting distributed training on a multi-GPU instance using PyTorch DDP. My model is a PyG GNN trained on a heterogeneous graph. For some reason only one of the ranks (rank 3) completes the script, while the rest hang and eventually time out. Because I am using a PyG HeteroData object I cannot use a DistributedSampler, so instead I chunk the data across ranks for the dataloaders. I launch training with the following command:
torchrun --nproc_per_node=4 gnn_model.py
Here is the final output:
rank 2 epoch 1
train rank 2
batch 0/1052
batch 1000/1052
val rank 3
test rank 3
epoch 1 complete on rank 3
Memory Usage On Rank 3
| ID | GPU | MEM |
-------------------
| 0 | 100% | 79% |
| 1 | 100% | 63% |
| 2 | 100% | 67% |
| 3 | 50% | 52% |
RAM Used (GB): 182.41599488
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806397 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806256 milliseconds before timing out.
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=67254, OpType=ALLREDUCE, Timeout(ms)=1800000) ran for 1806176 milliseconds before timing out.
Clean up run on rank: 3
Clean up completed on rank 3
ip-172-31-4-202:2515:2626 [3] NCCL INFO [Service thread] Connection closed by localRank 3
ip-172-31-4-202:2515:2515 [0] NCCL INFO comm 0x55986c074da0 rank 3 nranks 4 cudaDev 3 busId 1e0 - Abort COMPLETE
STOPPING INSTANCE
Below is my code:
#!/usr/bin/env python
# coding: utf-8
import resource
import boto3
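# Raise the soft limit on open file descriptors (DataLoader workers can exhaust the default)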
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
import wandb
wandb.login()
import torch
import time
import numpy as np
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.loader import NeighborLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from sklearn.metrics import roc_auc_score
#from torch.utils.tensorboard import SummaryWriter # Writer will output to ./runs/ directory by default.
from datetime import datetime
import gc
from GPUtil import showUtilization as gpu_usage
import torch.nn as nn
import random, requests, glob
from torch.utils.checkpoint import checkpoint
import torch_geometric.transforms as T
from torch_geometric.nn import to_hetero, norm
from torch_geometric.nn.conv import GATv2Conv
from torch_geometric.transforms import ToUndirected, RandomLinkSplit, NormalizeFeatures, Compose
from torch.distributed.elastic.multiprocessing.errors import record
import torch.distributed as dist
import os
import psutil
import torch.multiprocessing as mp
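# Enable verbose NCCL logging to help debug collective hangs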
os.environ['NCCL_DEBUG']='INFO'
# os.environ['NCCL_LL_THRESHOLD']=0
def stop_instance(instance_id):
ec2 = boto3.resource('ec2', region_name='us-west-2')
instance = ec2.Instance(instance_id)
response = instance.stop()
return response
def get_instance_id():
response = requests.get('http://169.254.169.254/latest/meta-data/instance-id')
return response.text
def set_seed(seed: int = 42) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# When running on the CuDNN backend, two further options must be set
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set a fixed value for the hash seed
os.environ["PYTHONHASHSEED"] = str(seed)
print(f"Random seed set as {seed}")
def free_gpu_cache(empty_cache=True, print_gpu=False, rank=0):
if empty_cache:
gc.collect()
torch.cuda.empty_cache()
if print_gpu:
print("Memory Usage On Rank ", rank)
gpu_usage()
        print('RAM Used (GB):', psutil.virtual_memory().used / 1e9)
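# Training configuration; aws_dict values are lists so multiple model configs can be swept (one here)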
aws_dict = {
'heads': [1],
'hidden_channels': [128],
'output_channels': [32],
'decode_channels': [32],
'num_conv': [2],
'num_lin': [1]
}
num_neighbors = [10, 3]
disjoint_train_ratio = 0.5
epochs = 2
dropout = 0.7
lr = 0.001
aggr = 'sum'
batch_size = 32
new_data = False
instance_stop = True
class CriterionParallel(nn.Module):
def __init__(self, module):
super(CriterionParallel, self).__init__()
self.module = module
if not isinstance(module, nn.Module):
raise ValueError("module must be a nn.Module")
if torch.cuda.device_count() == 0:
raise ValueError("No GPUs available")
self.device_ids = list(range(torch.cuda.device_count()))
def forward(self, inputs, *targets):
targets = tuple(t.to(inputs.device) for t in targets)
return self.module(inputs, *targets)
def can_convert_to_int(s):
try:
int(s)
return True
except ValueError:
return False
def load_checkpoint_model(checkpoint_path, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, data, device, model_filename=None):
if model_filename:
checkpoint_epoch_path = os.path.join(checkpoint_path, model_filename)
if not os.path.isfile(checkpoint_epoch_path):
print(f"Error: File {model_filename} not found.")
return None, None # or handle this in another way suitable for your application
else:
list_of_checkpoint_files = glob.glob(os.path.join(checkpoint_path, '*'))
# Make sure the file is a PyTorch file (.pt) and can extract an integer from filename.
list_of_checkpoint_files = [f for f in list_of_checkpoint_files if f.endswith('.pt') and can_convert_to_int(f.split("_")[-1].split('.')[0])]
if not list_of_checkpoint_files:
print("Error: No valid checkpoint files found.")
return None, None # or handle this in another way suitable for your application
checkpoint_epoch_number = max([int(file.split("_")[-1].split('.')[0]) for file in list_of_checkpoint_files])
checkpoint_epoch_path = os.path.join(checkpoint_path, f'checkpoint_AWS_RUN_0_epoch_{checkpoint_epoch_number}.pt')
# Instantiate the model with the correct parameters and load onto the device.
resume_model = Model(hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, data).to(device)
# Load the state dict onto the device.
resume_model.load_state_dict(torch.load(checkpoint_epoch_path, map_location=device))
return resume_model, checkpoint_epoch_number
def data_process(path, rank):
# Load the hetero data object created in Data_formatting.ipynb
print('loading data')
data = torch.load(path, map_location=torch.device('cpu'))
print('data loaded')
free_gpu_cache(print_gpu=True, rank=rank)
data = ToUndirected(merge=False)(data)
print('data to undirected')
#Perform a link-level split into training, validation, and test edges as well as node normalization.
transform = Compose([NormalizeFeatures(['x','edge_attr']), RandomLinkSplit(
num_val=0.2,
num_test=0.1,
split_labels=True,
is_undirected=True,
edge_types=[('node_a', 'edge_ab', 'node_b'), ('node_c', 'edge_cd', 'node_d')],
rev_edge_types=[('node_b', 'rev_edge_ab', 'node_a'), ('node_d', 'rev_edge_cd', 'node_c')],
disjoint_train_ratio=disjoint_train_ratio
)])
train_data, val_data, test_data = transform(data)
print('data transformed')
del data
free_gpu_cache(print_gpu=True, rank=rank)
return train_data, val_data, test_data
def data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size):
train_idx = torch.chunk(torch.arange(train_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]
val_idx = torch.chunk(torch.arange(val_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]
test_idx = torch.chunk(torch.arange(test_data['node_a'].num_nodes, dtype=torch.long), world_size)[rank]
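    # NOTE: torch.chunk can return chunks of unequal size, so each rank may end up
    # with a different number of batches per epoch.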
dist_batch_size = batch_size // world_size
num_workers = 0
persistent_workers = False
print('indices created')
train_loader = NeighborLoader(
train_data,
batch_size=dist_batch_size,
shuffle=True,
directed=False,
num_neighbors=num_neighbors,
input_nodes=('node_a', train_idx),
num_workers=num_workers,
persistent_workers=persistent_workers,
drop_last=True)
val_loader = NeighborLoader(
val_data,
batch_size=dist_batch_size,
directed=False,
num_neighbors=num_neighbors,
input_nodes=('node_a', val_idx),
num_workers=num_workers,
persistent_workers=persistent_workers,
drop_last=True)
test_loader = NeighborLoader(
test_data,
batch_size=dist_batch_size,
directed=False,
num_neighbors=num_neighbors,
input_nodes=('node_a', test_idx),
num_workers=num_workers,
persistent_workers=persistent_workers,
drop_last=True)
del train_data, val_data, test_data, train_idx, val_idx, test_idx, dist_batch_size, num_workers, persistent_workers
free_gpu_cache(print_gpu=True, rank=rank)
print('loaders created')
return train_loader, val_loader, test_loader
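# Build per-batch edge labels: positives are the sampled edges; negatives are
# created by shuffling the target nodes of those edges (corruption-style sampling).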
def process_edges(batch, device, rank=0):
for edge in batch.edge_types:
num_edges = batch[edge].edge_index.shape[1]
shuffle_target = torch.randperm(num_edges).to(device)
neg_targets = batch[edge].edge_index[1, shuffle_target]
neg_edge_index = torch.clone(batch[edge].edge_index).to(device)
neg_edge_index[1, :] = neg_targets
batch[edge].neg_edge_label_index = neg_edge_index
batch[edge].neg_edge_label = torch.zeros(num_edges, dtype=torch.float, device=device)
batch[edge].pos_edge_label = torch.ones(num_edges, dtype=torch.float, device=device)
batch[edge].pos_edge_label_index = batch[edge].edge_index
edge_label_index_dict, edge_label_dict = {}, {}
for edge in batch.edge_index_dict.keys():
edge_label_index_dict[edge] = torch.cat([batch[edge].pos_edge_label_index, batch[edge].neg_edge_label_index], dim=1).to(device)
edge_label_dict[edge] = torch.cat([batch[edge].pos_edge_label, batch[edge].neg_edge_label], dim=0).to(device)
return edge_label_index_dict, edge_label_dict
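# Encoder: stacked GATv2Conv layers (lazy input sizes) with batch norm,
# converted to a heterogeneous model via to_hetero in Model.__init__.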
class GNNEncoder(nn.Module):
def __init__(self, hidden_channels, output_channels, dropout, heads, num_layers):
super().__init__()
self.layers = []
self.batch_norms = []
for i in range(num_layers):
if i == 0:
channels = hidden_channels
elif num_layers == 3 and i == 1:
channels = hidden_channels
else:
channels = output_channels
layer = GATv2Conv((-1, -1), channels, add_self_loops=False, edge_dim=-1, heads=heads)
self.add_module('layer_{}'.format(i), layer)
self.layers.append(layer)
batch_norm = norm.BatchNorm(heads * channels, allow_single_element=True)
self.add_module('batch_norm_{}'.format(i), batch_norm)
self.batch_norms.append(batch_norm)
self.dropout = torch.nn.Dropout(p=dropout)
self.relu = torch.nn.ReLU()
def forward(self, x, edge_index, edge_attr):
for i in range(len(self.layers)):
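            # NOTE: reentrant activation checkpointing can conflict with DDP's gradient
            # hooks; use_reentrant=False (or DDP(static_graph=True)) is the usual
            # recommendation when combining the two.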
x = checkpoint(self.layers[i], x, edge_index, edge_attr) # Use gradient checkpointing
x = self.batch_norms[i](x)
if i != len(self.layers)-1:
x = self.relu(x)
x = self.dropout(x)
return x
class EdgeDecoder(torch.nn.Module):
def __init__(self, hidden_channels, output_channels, decode_channels, dropout, heads, num_lin, num_conv):
super().__init__()
self.layers = []
self.batch_norms = []
        if num_conv == 1:
            output_channels = hidden_channels
        if num_lin == 1:
            decode_channels = 2 * heads * output_channels
for i in range(num_lin):
if i == 0 and num_lin != 1:
layer = Linear(2 * heads * output_channels, decode_channels)
self.add_module('layer_{}'.format(i), layer)
self.layers.append(layer)
elif num_lin == 3 and i == 1:
layer = Linear(decode_channels, decode_channels)
self.add_module('layer_{}'.format(i), layer)
self.layers.append(layer)
else:
layer = Linear(decode_channels, 1)
self.add_module('layer_{}'.format(i), layer)
self.layers.append(layer)
self.relu = torch.nn.ReLU()
self.dropout = torch.nn.Dropout(p=dropout)
def forward(self, z_dict, edge_label_index):
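        # Score each edge by concatenating its source/target embeddings and passing
        # them through the shared MLP; 'edge_ab'/'rev_edge_ab' edges are scored
        # separately from the remaining edge types so the two loss terms can be
        # weighted differently.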
forward_dict, forward_dict_rest = {}, {}
for i, edge in enumerate(edge_label_index.keys()):
if edge[1] == 'edge_ab' or edge[1] == 'rev_edge_ab':
row, col = edge_label_index[edge]
forward_dict[f'z_forward{i}'] = torch.cat([z_dict[edge[0]][row], z_dict[edge[2]][col]], dim=-1)
else:
row, col = edge_label_index[edge]
forward_dict_rest[f'z_forward{i}'] = torch.cat([z_dict[edge[0]][row], z_dict[edge[2]][col]], dim=-1)
z = torch.cat([x for x in forward_dict.values()], dim=0)
z_rest = torch.cat([x for x in forward_dict_rest.values()], dim=0)
for i in range(len(self.layers)):
z = self.layers[i](z)
if i != len(self.layers)-1:
z = self.relu(z)
z = self.dropout(z)
for i in range(len(self.layers)):
z_rest = self.layers[i](z_rest)
if i != len(self.layers)-1:
z_rest = self.relu(z_rest)
z_rest = self.dropout(z_rest)
return z.view(-1), z_rest.view(-1)
class Model(torch.nn.Module):
def __init__(self, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, metadata):
super().__init__()
self.encoder = GNNEncoder(hidden_channels, output_channels, dropout, heads, num_conv)
self.encoder = to_hetero(self.encoder, metadata, aggr=aggr)
self.decoder = EdgeDecoder(hidden_channels, output_channels, decode_channels, dropout, heads, num_lin, num_conv)
del metadata
def forward(self, x_dict, edge_index_dict, edge_label_index, edge_attr_dict):
z_dict = self.encoder(x_dict, edge_index_dict, edge_attr_dict)
return self.decoder(z_dict, edge_label_index)
def train(data_loader, model, optimizer, criterion, batch_size, rank, device, batch_run=False):
model.train()
total_examples = total_loss = 0
roc_auc_sum = 0.0
roc_auc_count = 0
    for i, batch in enumerate(data_loader):
        # Check whether the 'edge_ab' edge type is empty in this batch; reset each
        # iteration so one empty batch does not zero out the rest of the epoch.
        edge_index = batch[('node_a', 'edge_ab', 'node_b')].get('edge_index')
        skip = edge_index is not None and edge_index.size(1) == 0
if i % 1000 == 0 and rank == 0:
print(f'batch {i}/{len(data_loader)}')
batch = batch.to(device=device)
edge_label_index_dict, edge_label_dict = process_edges(batch, device, rank)
raw_pred, raw_pred_rest = model(batch.x_dict, batch.edge_index_dict, edge_label_index_dict, batch.edge_attr_dict)
tar = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' in k or 'rev_edge_ab' in k], dim=0)
tar_rest = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' not in k and 'rev_edge_ab' not in k], dim=0)
tar = tar.to(device=device)
tar_rest = tar_rest.to(device=device)
loss_gt = criterion(torch.nn.functional.relu(raw_pred), tar)
loss_rest = criterion(torch.nn.functional.relu(raw_pred_rest), tar_rest)
lmda = 0.99
loss = lmda * loss_gt + (1 - lmda) * loss_rest
loss_val = loss.item()
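        # NOTE: DDP all-reduces gradients on every backward(), so every rank must run
        # the same number of batches per epoch or the collectives desynchronize and
        # the NCCL watchdog eventually times out.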
loss.backward()
        if (i + 1) % 4 == 0:  # gradient accumulation: step every 4 batches
optimizer.step()
optimizer.zero_grad()
# Empty edge
if skip:
loss_val = 0
total_examples += batch_size
total_loss += float(loss_val) * batch_size
# Compute ROC AUC for this batch and add to sum
preds = torch.sigmoid(raw_pred.detach()).cpu().numpy()
tars = tar.cpu().numpy().astype(int)
# if len(preds) > 0 and len(tars) > 0 and not np.isnan(preds).any() and not np.isnan(tars).any():
if len(preds) > 0 and len(tars) > 0 and skip is False:
roc_auc_sum += roc_auc_score(tars, preds)
roc_auc_count += 1
        if batch_run:  # smoke test: stop after the first batch
break
# Compute the average ROC AUC
roc_auc_avg = roc_auc_sum / roc_auc_count if roc_auc_count > 0 else 0.0
loss_val= total_loss / total_examples
return loss_val, roc_auc_avg
@torch.no_grad()
def test(data_loader, model, criterion, batch_size, rank, device, batch_run=False):
model.eval()
total_examples = total_loss = 0
roc_auc_sum = 0.0
roc_auc_count = 0
    for i, batch in enumerate(data_loader):
        # Check whether the 'edge_ab' edge type is empty in this batch; reset each
        # iteration so one empty batch does not zero out the rest of the epoch.
        edge_index = batch[('node_a', 'edge_ab', 'node_b')].get('edge_index')
        skip = edge_index is not None and edge_index.size(1) == 0
if i % 1000 == 0 and rank == 0:
print(f'batch {i}/{len(data_loader)}')
batch = batch.to(device=device)
edge_label_index_dict, edge_label_dict = process_edges(batch, device, rank)
raw_pred, raw_pred_rest = model(batch.x_dict, batch.edge_index_dict, edge_label_index_dict, batch.edge_attr_dict)
tar = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' in k or 'rev_edge_ab' in k], dim=0)
tar_rest = torch.cat([v for k, v in edge_label_dict.items() if 'edge_ab' not in k and 'rev_edge_ab' not in k], dim=0)
tar = tar.to(device=device)
tar_rest = tar_rest.to(device=device)
loss_gt = criterion(torch.nn.functional.relu(raw_pred), tar)
loss_rest = criterion(torch.nn.functional.relu(raw_pred_rest), tar_rest)
lmda = 0.99
loss = lmda * loss_gt + (1 - lmda) * loss_rest
loss_val = loss.item()
# Empty edge
if skip:
loss_val = 0
total_examples += batch_size
total_loss += float(loss_val) * batch_size
# Compute ROC AUC for this batch and add to sum
preds = torch.sigmoid(raw_pred.detach()).cpu().numpy()
tars = tar.cpu().numpy().astype(int)
if len(preds) > 0 and len(tars) > 0 and skip is False:
roc_auc_sum += roc_auc_score(tars, preds)
roc_auc_count += 1
        if batch_run:  # smoke test: stop after the first batch
break
# Compute the average ROC AUC
roc_auc_avg = roc_auc_sum / roc_auc_count if roc_auc_count > 0 else 0.0
loss_val = total_loss / total_examples
return loss_val, roc_auc_avg
def setup():
# initialize the process group
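    # NCCL collectives default to a 30-minute timeout (the 1800000 ms in the watchdog errors)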
dist.init_process_group("nccl")
def cleanup():
dist.destroy_process_group()
@record
def run(new_data, epochs):
rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
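    # NOTE: LOCAL_RANK equals the global rank only on a single node; a multi-node
    # run would need to read RANK as well.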
set_seed()
free_gpu_cache(print_gpu=True, rank=rank)
print(f"Running basic DDP example on rank {rank}.")
setup()
volume_mount_dir = ''
dataset_path = os.path.join(volume_mount_dir, 'datasets/graph_data_v2.pt')
checkpoint_path = os.path.join(volume_mount_dir, 'checkpoints/')
device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
print("Device is ", device)
if rank == 0 and new_data==True:
# Only process the data in the first process.
train_data, val_data, test_data = data_process(dataset_path, rank)
torch.save((train_data, val_data, test_data), 'datasets/processed_data.pt')
print("data sets saved")
# Create data loaders on rank 0 for testing or other purposes, but they won't be saved or shared.
train_loader, val_loader, test_loader = data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size)
train_batch = next(iter(train_loader))
train_batch = train_batch.to(device=device)
torch.save((train_batch), 'datasets/batch_data.pt')
print("data saved")
dist.barrier()
if rank != 0 or new_data==False:
train_data, val_data, test_data = torch.load('datasets/processed_data.pt')
free_gpu_cache(print_gpu=True, rank=rank)
train_batch = torch.load('datasets/batch_data.pt')
print("batch data loaded")
free_gpu_cache(print_gpu=True, rank=rank)
dist.barrier()
if rank == 0:
print("all data loaded")
# Recreate the data loaders in each process.
train_loader, val_loader, test_loader = data_loaders(batch_size, train_data, val_data, test_data, num_neighbors, rank, world_size)
if rank == 0:
print("dataloaders created")
free_gpu_cache(print_gpu=True, rank=rank)
train_batch = train_batch.to(rank)
edge_label_index_dict, _ = process_edges(train_batch, device, rank)
if rank == 0:
print('edge dictionaries created')
free_gpu_cache(print_gpu=True, rank=rank)
dist.barrier()
print(f"# Train Batches Rank {rank}: {len(train_loader)}\n")
print(f"# Val Batches: Rank {rank}: {len(val_loader)}\n")
print(f"# Test Batches: Rank {rank}: {len(test_loader)}\n")
for i in range(len(aws_dict['heads'])):
run_name = f'AWS_RUN_{i}'
print(f'{run_name} on rank {rank}\n')
if rank == 0:
wandb.init(name=run_name,
notes='aws dist',
project="gene_disease_dist_aws_manual_v2",
entity="gurugecl")
num_lin = aws_dict['num_lin'][i]
num_conv = aws_dict['num_conv'][i]
heads = aws_dict['heads'][i]
decode_channels = aws_dict['decode_channels'][i]
hidden_channels = aws_dict['hidden_channels'][i]
output_channels = aws_dict['output_channels'][i]
if rank==0:
print("num_lin: ", num_lin)
print("num_conv: ", num_conv)
print("heads: ", heads)
print("hidden_channels: ", hidden_channels)
print("decode_channels: ", decode_channels)
print("output_channels: ", output_channels)
model_filename = f"completed_model_AWS_RUN_{i}.pt" # change this to the name of the model you want to load
model_file_path = os.path.join(checkpoint_path, model_filename)
if os.path.isfile(model_file_path):
print("EXISTING MODEL")
            model, start_epoch = load_checkpoint_model(checkpoint_path, hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, train_batch.metadata(), device, model_filename=model_filename)
model = model.to(device)
else:
if rank == 0:
print("NEW MODEL")
start_epoch = 0
model = Model(hidden_channels, output_channels, decode_channels, dropout, heads, num_conv, num_lin, aggr, train_batch.metadata()).to(device)
if rank == 0:
print('model loaded')
# Perform dummy forward pass for correct initialization
with torch.no_grad():
model.forward(train_batch.x_dict, train_batch.edge_index_dict, edge_label_index_dict, train_batch.edge_attr_dict)
if rank == 0:
print("dummy forward")
dist.barrier()
# Apply DDP
model = DDP(model, device_ids=[rank], output_device=rank) #, gradient_as_bucket_view=True)
if rank == 0:
print("DDP")
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.BCEWithLogitsLoss()
        # Move the targets onto the same device as the predictions inside the criterion
        if torch.cuda.device_count() > 1:
            criterion = CriterionParallel(criterion)
torch.autograd.set_detect_anomaly(True)
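        # NOTE: anomaly detection adds significant overhead; enable only while debugging.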
start_time = time.time()
dist.barrier()
for epoch in np.arange(start_epoch, epochs): #10
print(f'rank {rank} epoch {epoch}\n')
print(f'train rank {rank}\n')
if rank==0:
epoch_start_time = time.time()
train_loss, train_roc = train(train_loader, model, optimizer, criterion, batch_size, rank, device, batch_run=False)
print(f'val rank {rank}\n')
val_loss, val_roc = test(val_loader, model, criterion, batch_size, rank, device, batch_run=False)
print(f'test rank {rank}\n')
test_loss, test_roc = test(test_loader, model, criterion, batch_size, rank, device, batch_run=True)
comb_roc = train_roc * (1/3) + val_roc * (2/3)
print(f'epoch {epoch} complete on rank {rank}\n')
if rank == 0:
epoch_end_time = time.time()
print(f'Epoch: {epoch:03d}, LOSS: (Train: {train_loss:.4f}, Val: {val_loss:.4f}, Test: {test_loss:.4f}), ROC: (Train: {train_roc:.4f}, '
f'Val: {val_roc:.4f}, Test: {test_roc:.4f}, Comb: {comb_roc:.4f})')
epoch_time = epoch_end_time - epoch_start_time
wandb.log({"Epoch": epoch,
"Train Loss":train_loss,
"Train ROC": train_roc,
"Val Loss": val_loss,
"Val ROC": val_roc,
"Test Loss": test_loss,
"Test ROC": test_roc,
"Comb_ROC": comb_roc,
"Epoch Time": epoch_time})
                torch.save(model.module.state_dict(), f'checkpoints/checkpoint_{run_name}_epoch_{epoch}.pt')  # unwrap DDP so the keys match a bare Model
print("model saved")
print(f'Run: {run_name} Epoch: {epoch} Execution time: {epoch_time} seconds')
with open(f'checkpoints/time_{run_name}.txt', 'a') as f: # Open file in append mode
f.write(f'Run: {run_name} Epoch: {epoch} Time: {epoch_time} seconds\n')
dist.barrier() # for each epoch
free_gpu_cache(print_gpu=True, rank=rank)
if rank == 0:
end_time = time.time()
total_time = end_time - start_time # Calculate execution time
print(f'Total execution time: {total_time} seconds')
with open(f'checkpoints/time_{run_name}.txt', 'a') as f: # Open file in append mode
f.write(f'Run: {run_name} Total Time: {total_time} seconds\n')
            torch.save(model.module, f'checkpoints/completed_model_{run_name}.pt')  # save the underlying model, not the DDP wrapper
            wandb.finish()
dist.barrier() # for each aws config
print(f"Clean up run on rank: {rank}", flush=True)
cleanup()
print(f"Clean up completed on rank {rank}", flush=True)
if __name__ == '__main__':
world_size = torch.cuda.device_count()
print('Using', world_size, 'GPUs!')
run(new_data, epochs)
if instance_stop:
print("STOPPING INSTANCE")
instance_id = get_instance_id()
stop_instance(instance_id)