Initialization of Nodes in mps

Hello all,

I’ve been running some GAN tests both on my local machine (an Apple M1, using the mps backend) and on a remote server (using cuda), and I recently realized that the very same network generates decent MNIST digits locally with mps, whereas it fails to do so on the remote cuda machine.

How are nodes initialized in the mps build of PyTorch? I’m asking so that I can apply the same initialization in the tests I run on the server.

FYI:
torch versions on my local machine (successful):
torch 1.13.0.dev20220708
torchaudio 0.13.0.dev20220708
torchvision 0.14.0.dev20220708

torch versions on the remote server (unsuccessful):
torch 1.13.1
torchaudio 0.13.1
torchvision 0.14.1

In case the network code is needed, here it is:

#!/usr/bin/env python
# coding: utf-8



import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import random
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from datetime import datetime
import os
from matplotlib.pyplot import figure
import shutil

figure(figsize=(12, 10), dpi=120)

plt.rcParams['image.cmap'] = 'gray'
batch_size = 100

NUMBER_OF_CLASSES = 10
NUMBER_OF_CLUSTER_CENTROIDS = 32
NO_TRAINING = False

device = torch.device('mps')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

def augment_data(data):
    # One plainly normalized view and one randomly augmented view of the batch
    transform = transforms.Compose([transforms.Normalize((0.5,), (0.5,))])
    transform_aug = transforms.Compose([transforms.RandomResizedCrop(size=28, scale=(0.2, 1.0)),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.Normalize((0.5,), (0.5,))])

    return transform(data), transform_aug(data)
    
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()

        self.main = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
        )
        
        self.real_or_fake_layer = nn.Sequential(
            nn.Linear(128, 1),
        )

        self.image_feature_layer = nn.Sequential(
            nn.Linear(128, NUMBER_OF_CLUSTER_CENTROIDS),
            nn.Sigmoid() 
        )
        self.centroids = nn.Linear(NUMBER_OF_CLASSES, NUMBER_OF_CLUSTER_CENTROIDS)
    
    def forward(self, input, eye):
        common_part = self.main(input)
        real_or_fake = self.real_or_fake_layer(common_part)
        image_feature = self.image_feature_layer(common_part)
        cluster_centroid = self.centroids(eye)

        return real_or_fake, image_feature, cluster_centroid


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()

        self.main = nn.Sequential(
            nn.Linear(138, 512),   # input: 128-dim noise + 10-dim one-hot class code
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, input):
        step = self.main(input)
        return step

def show_images(images, epoch, folder_path):
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))

    for index, image in enumerate(images):
        plt.subplot(sqrtn, sqrtn, index + 1)
        plt.imshow(image.reshape(28, 28))

    # Save once after the full grid is drawn instead of re-saving per subplot
    path = folder_path + str(epoch) + ".png"
    plt.savefig(path)

transform = transforms.Compose([transforms.ToTensor()])
train_set = datasets.MNIST('./MNIST_DATA/train/', train=True, download=True,transform=transform)
test_set = datasets.MNIST('./MNIST_DATA/test/', train=False, download=True,transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)

CrosEntr = nn.CrossEntropyLoss()

from scipy.optimize import linear_sum_assignment

def get_stat(generated_label, num_of_gt_labels, gt_labels):
    c_stat = np.zeros([num_of_gt_labels,num_of_gt_labels])
    for i in range(len(gt_labels)):
        gt_idx = int(gt_labels[i])
        c_stat[gt_idx][generated_label[i]] += 1
    return c_stat

def get_match(stat):
    _, col_ind = linear_sum_assignment(stat.max()-stat)
    return col_ind

def get_acc(stat, col_ind, over):
    tot = 0
    for i in range(stat.shape[0]):
        tot += stat[i][col_ind[i]]
    return tot/(np.sum(stat)/over)

def get_nmi(stat):
    n,m = stat.shape
    pij = stat/np.sum(stat)
    pi = np.sum(pij, 1)
    pj = np.sum(pij, 0)
    enti = sum([-pi[i]*np.log2(pi[i]+1e-6) for i in range(n)])
    entj = sum([-pj[i]*np.log2(pj[i]+1e-6) for i in range(m)])
    mi = 0
    for i in range(n):
        for j in range(m):
            mi += pij[i][j]*(np.log2(pij[i][j]/(pi[i]*pj[j]+1e-6)+1e-6))
    return mi/max(enti, entj)

epochs=200
lr =  0.0005
TEMP = 0.1
HYPER_PARA_ADV = 0.5
HYPER_PARA_INFO_LOSS = 5
HYPER_PARA_CONTR = 0.2
HYPER_PARA_ENTROPY = 0.2

now = datetime.now()
print("now =", now)
dt_string = now.strftime("%Y_%m_%d_%H%M%S")
print("date and time =", dt_string)

# Define the results folder unconditionally (it is used later when saving
# checkpoints) and create it if it does not exist yet
save_folder_path = "./results/" + dt_string + "/"
if not os.path.exists(save_folder_path) and not NO_TRAINING:
    os.makedirs(save_folder_path)
    shutil.copy2("Sandbox_CIFAR10.py", save_folder_path)
# TRAINING

#############################################################################################




Gen = generator().to(device)
Dis = discriminator().to(device)
print(Gen)
print(Dis)
optimizerG = optim.Adam(Gen.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = optim.Adam(Dis.parameters(), lr=lr, betas=(0.5, 0.999))
best_acc = 0
best_nmi = 0

#############################################################################################
eye = torch.eye(NUMBER_OF_CLASSES).to(device)
for epoch in range(epochs):
    epoch += 1
    if NO_TRAINING:
        print("There won't be any training, go to top cell to change NO_TRAINING to False")
        break
    for times, data in enumerate(train_loader):
            times += 1
            # 1. Prepare Inputs
            real_inputs, aug_real_inputs = augment_data(data[0])
            real_inputs, aug_real_inputs = real_inputs.to(device), aug_real_inputs.to(device)
            real_inputs, aug_real_inputs = real_inputs.view(-1, 784), aug_real_inputs.view(-1,784)



            noise = torch.randn(real_inputs.shape[0], 128) 
            noise = noise.to(device)                                                                                
            
            rand_c = torch.zeros(real_inputs.shape[0], NUMBER_OF_CLASSES).to(device)      
            rand_idx = [i for i in range(NUMBER_OF_CLASSES)]                                        
            random.shuffle(rand_idx)                                                                                  
            for i, element in enumerate(rand_c):
                element[rand_idx[i % NUMBER_OF_CLASSES]] = 1
            noise = torch.cat((noise, rand_c),1)
            fake_inputs = Gen(noise)
            
            # 2. Train Discriminator
                # 2.1 Pass Images
            optimizerD.zero_grad()

            real_outputs, real_img_feat, real_class_emb = Dis(real_inputs, eye) 
            fake_outputs, fake_img_feat, fake_class_emb = Dis(fake_inputs, eye) 

                # 2.2 Adversarial Loss

            # Hinge adversarial loss for the discriminator
            real_adv_loss = HYPER_PARA_ADV * (nn.ReLU()(1 - real_outputs).mean())
            fake_adv_loss_D = HYPER_PARA_ADV * (nn.ReLU()(1 + fake_outputs).mean())


                # 2.3 Info Loss
            f = F.normalize(fake_img_feat, p=2, dim=1)
            c = F.normalize(fake_class_emb, p=2, dim=1)
            class_dist = torch.cat([torch.matmul(f, c[i].unsqueeze(-1))/TEMP for i in range(NUMBER_OF_CLASSES)],1)
            info_loss = HYPER_PARA_INFO_LOSS* CrosEntr(class_dist, torch.argmax(rand_c, 1))

            
            #     # 2.4 Real Image Contrastive Loss
            # if add_con_loss:
            labels = torch.cat([torch.arange(batch_size) for i in range(2)], dim=0)
            labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
            # labels stays on the CPU here; the boolean-mask indexing below is
            # done on the CPU, and labels is rebuilt on the device afterwards

            _,real_aug_img_feat, _ = Dis(aug_real_inputs, eye) 
            feat = torch.cat([real_img_feat, real_aug_img_feat], 0)
            feat = F.normalize(feat, p=2, dim=1)
            similarity_matrix = torch.matmul(feat, feat.T)
            similarity_matrix = similarity_matrix.detach().cpu()
        
            mask = torch.eye(feat.shape[0], dtype=torch.bool)
            labels = labels[~mask].view(labels.shape[0], -1)
            similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

            positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1).to(device)
            negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1).to(device)
            logits = torch.cat([positives, negatives], dim=1)
            labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device)
            logits = logits/TEMP


            contrastive_loss = HYPER_PARA_CONTR* CrosEntr(logits, labels)

            #     # 2.5 Entropy Regularizations
            rr = F.normalize(real_img_feat, p=2, dim=1)

            class_dist_real = torch.cat([torch.matmul(rr, c[i].unsqueeze(-1))/TEMP for i in range(NUMBER_OF_CLASSES)], 1)
            class_dist_real_sm = F.softmax(class_dist_real, 1)
            entropy_reg_1 = -1*HYPER_PARA_ENTROPY*(class_dist_real_sm*torch.log(class_dist_real_sm)).sum(1).mean()
            entropy_reg_2 = 1*HYPER_PARA_ENTROPY*(class_dist_real_sm.mean(0)*torch.log(class_dist_real_sm.mean(0))).sum()


            """ZERO GRAD HAS BEEN MOVED UP"""
                # 2.6 Backward Pass


            D_loss =  real_adv_loss + fake_adv_loss_D + info_loss + contrastive_loss + entropy_reg_1 + entropy_reg_2 
            D_loss.backward()
            optimizerD.step()

            noise = torch.randn(real_inputs.shape[0], 128) 
            noise = noise.to(device)                                                                                

            rand_c = torch.zeros(real_inputs.shape[0], NUMBER_OF_CLASSES).to(device)  
            rand_idx = [i for i in range(NUMBER_OF_CLASSES)]                                           
            random.shuffle(rand_idx)                                                                                
            for i, element in enumerate(rand_c):
                element[rand_idx[i % NUMBER_OF_CLASSES]] = 1
            noise = torch.cat((noise, rand_c),1)
            
            fake_inputs = Gen(noise)

            optimizerG.zero_grad()
            # 4 Train Generator
            fake_outputs, fake_img_feat, fake_class_emb_2 = Dis(fake_inputs, eye) 

                # 4.1 Adversarial Loss 

            fake_adv_loss_G =  -HYPER_PARA_ADV*fake_outputs.mean()
            
                # 4.2 Info Loss
            f_2 = F.normalize(fake_img_feat, p=2, dim=1)
            c_2 = F.normalize(fake_class_emb_2, p=2, dim=1)
            class_dist_2 = torch.cat([torch.matmul(f_2, c_2[i].unsqueeze(-1))/TEMP for i in range(NUMBER_OF_CLASSES)],1)
            info_loss_2 = HYPER_PARA_INFO_LOSS * CrosEntr(class_dist_2, torch.argmax(rand_c, 1))
                

            G_loss = fake_adv_loss_G + info_loss_2
            
            """ZERO HAS BEEN MOVED TO THE TOP"""
            G_loss.backward()
            optimizerG.step()

    
    # TEST PERFORMANCE
    pred_c = []
    pred_c_cpu = []
    real_c = []
    with torch.no_grad():
        for image, label in test_loader:
 
                real_img_test = image.to(device).view(-1, 784)
                _, feat_test, class_emb_test = Dis(real_img_test, eye)
                f_test = F.normalize(feat_test, p=2, dim=1)
                c_test = F.normalize(class_emb_test, p=2, dim=1)

                class_dist_test = torch.cat(
                    [torch.from_numpy(np.dot(f_test.cpu(), c_test[i].cpu())).float().to(device).unsqueeze(-1) / TEMP
                     for i in range(NUMBER_OF_CLASSES)], 1)
                # class_dist_candas = torch.cat([torch.matmul(f, c[i]).unsqueeze(-1)/TEMP for i in range(NUMBER_OF_CLASSES)], 1) ### WRONG: FOR SOME REASON IT CALCULATES WRONG

                pred_c += list(torch.argmax(class_dist_test, 1))
                pred_c_cpu += list(torch.argmax(class_dist_test.cpu(), 1))
                real_c += list(label.cpu().numpy())
                
        c_table = get_stat(pred_c, NUMBER_OF_CLASSES, real_c)
        idx_map = get_match(c_table)
        cur_acc = get_acc(c_table, idx_map, 1)
        cur_nmi = get_nmi(c_table[:10,:])
    
        if cur_acc > best_acc:
            best_acc = cur_acc
            best_nmi = cur_nmi
            torch.save(Gen, save_folder_path + 'Best_Generator.pth')
            torch.save(Dis, save_folder_path + 'Best_Discriminator.pth')

    with open(save_folder_path+"0_Test_Performance.txt",'a') as f:
        f.write("Epoch: {}, Test Accuracy: {:.3f}, Test NMI: {:.3f}\n".format(epoch, cur_acc, cur_nmi))

    imgs_numpy = (fake_inputs.data.cpu().numpy()+1.0)/2.0
    print(imgs_numpy.size)
    print(imgs_numpy[:16].size)
    show_images(imgs_numpy[:16], epoch, save_folder_path)

I assume “nodes” refers to the trainable parameters?
If so, the parameters are initialized in each module via its reset_parameters() method, e.g. this code is used to initialize an nn.Linear layer.
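
For reference, the default nn.Linear initialization looks roughly like this (a paraphrase of reset_parameters() from the PyTorch source at the time; check your installed version, and note that _calculate_fan_in_and_fan_out is a private helper):

import math
import torch.nn as nn
import torch.nn.init as init

# Rough paraphrase of nn.Linear.reset_parameters(): Kaiming-uniform weights
# with a=sqrt(5) and a fan-in-scaled uniform bias.
def linear_default_init(layer: nn.Linear) -> None:
    init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
    if layer.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(layer.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(layer.bias, -bound, bound)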

Hello @ptrblck, yes, that’s what I meant (sorry for the vague lingo, I’m new to ML) 🙂.

And also, thank you for the feedback. Is your link for the mps build? If so, it’s a bit strange.

After a few tests, I found that when I initialize the Linear layer weights on the cuda device with xavier_normal_, I get the same results as in the mps tests. If I don’t do that, the generated images are not OK: the discriminator’s real/fake losses reach 0 and the generator loss keeps increasing.
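
For reference, this is a minimal sketch of the workaround I mean (the helper name init_weights is just illustrative, and zeroing the biases is my arbitrary choice):

import torch.nn as nn

# Re-initialize every nn.Linear layer with Xavier-normal weights after the
# models are built.
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

Gen.apply(init_weights)
Dis.apply(init_weights)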

Yes, I would assume the same reset_parameters method is called for all backends.
It’s interesting to see that different inits seem to be needed for different backends.
Would it be possible to run the same code on the CPU to check which init works (assuming it’s not terribly slow)?
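As a quick sanity check, something like this might also help (a sketch, assuming the parameters are created on the CPU before the .to(device) call, as in your script): with a fixed seed, the fresh initialization should be identical on both machines, so you can compare the raw weights directly.

import torch

torch.manual_seed(0)
ref = generator()                  # parameters are created on the CPU first
print(ref.main[0].weight[0, :5])   # compare this slice across both machines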

I would love to; however, somehow, some way, I started to encounter these errors:

On the first run, I get this:

NameError                                 Traceback (most recent call last)
/Users/candasunal/Master/Thesis/c3-gan/C3-GAN/Sandbox_CIFAR10_default.py in <module>()
      6 # In[1]:
----> 9 import torch
      10 from torchvision import datasets
      11 from torch.utils.data import DataLoader

File /opt/anaconda3/envs/torch-nightly/lib/python3.8/site-packages/torch/__init__.py:814, in <module>
    807         __all__.append(name)
    809 ################################################################################
    810 # Import interface functions defined in Python
    811 ################################################################################
    812 
    813 # needs to be after the above ATen bindings so we can overwrite from Python side
--> 814 from .functional import *  # noqa: F403
    817 ################################################################################
    818 # Remove unnecessary members
    819 ################################################################################
    821 del _StorageBase

File /opt/anaconda3/envs/torch-nightly/lib/python3.8/site-packages/torch/functional.py:7, in <module>
      5 import torch
      6 from torch._C import _add_docstr
----> 7 import torch.nn.functional as F
      8 from ._lowrank import svd_lowrank, pca_lowrank
      9 from .overrides import (
     10     has_torch_function, has_torch_function_unary, has_torch_function_variadic,
     11     handle_torch_function)

File /opt/anaconda3/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/__init__.py:1, in <module>
----> 1 from .modules import *  # noqa: F403
      2 from .parameter import (
      3     Parameter as Parameter,
      4     UninitializedParameter as UninitializedParameter,
      5     UninitializedBuffer as UninitializedBuffer,
      6 )
      7 from .parallel import DataParallel as DataParallel

File /opt/anaconda3/envs/torch-nightly/lib/python3.8/site-packages/torch/nn/modules/__init__.py:2, in <module>
      1 from .module import Module
----> 2 from .linear import Identity, Linear, Bilinear, LazyLinear
      3 from .conv import Conv1d, Conv2d, Conv3d, \
      4     ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, \
      5     LazyConv1d, LazyConv2d, LazyConv3d, LazyConvTranspose1d, LazyConvTranspose2d, LazyConvTranspose3d
...
--> 551 uniform = _make_deprecate(uniform_)
    552 normal = _make_deprecate(normal_)
    553 constant = _make_deprecate(constant_)

NameError: name 'uniform_' is not defined

On the second run, I get this one:

NameError                                 Traceback (most recent call last)
/Users/candasunal/Master/Thesis/c3-gan/C3-GAN/Sandbox_CIFAR10_default.py in <module>()
      6 # In[1]:
----> 9 import torch
      10 from torchvision import datasets
      11 from torch.utils.data import DataLoader

File /opt/anaconda3/envs/torch-nightly/lib/python3.8/site-packages/torch/__init__.py:233, in <module>
    219         raise ImportError(textwrap.dedent('''
    220             Failed to load PyTorch C extensions:
    221                 It appears that PyTorch has loaded the `torch/_C` folder
   (...)
    229                 or by running Python from a different directory.
    230             ''').strip()) from None
    231     raise  # If __file__ is not None the cause is unknown, so just re-raise.
--> 233 for name in dir(_C):
    234     if name[0] != '_' and not name.endswith('Base'):
    235         __all__.append(name)

NameError: name '_C' is not defined

Keep in mind that I didn’t install any torch library between then and now, and I ran the same code twice and got two different errors.

I don’t know how PyTorch could break without any changes. My guess is that your current environment either has multiple (conflicting) PyTorch versions installed, or your current working directory contains script files with the same names PyTorch uses, e.g. a torch.py, which can also result in import errors.
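
A quick way to check for that kind of shadowing or version conflict is to print where torch is actually imported from (run this from the same directory and environment as the failing script):

import torch

# If this points into your project directory rather than site-packages,
# a local file is shadowing the real package.
print(torch.__file__)
print(torch.__version__)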