How to use amp with a normal optimizer and a sparse optimizer

Hi, I want to use AMP (automatic mixed precision) in my model:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

import dgl
import dgl.function as fn


class DGLGATNE(nn.Module):
    def __init__(
        self,
        num_nodes,
        embedding_size,
        embedding_u_size,
        edge_types,
        edge_type_count,
        dim_a,
    ):
        super(DGLGATNE, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_types = edge_types
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        self.node_embeddings = nn.Embedding(num_nodes, embedding_size, sparse=True)
        self.node_type_embeddings = nn.Embedding(
            num_nodes * edge_type_count, embedding_u_size, sparse=True
        )
        self.trans_weights = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, embedding_size)
        )
        self.trans_weights_s1 = Parameter(
            torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
        )
        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))

        self.reset_parameters()

    def reset_parameters(self):
        self.node_embeddings.weight.data.uniform_(-1.0, 1.0)
        self.node_type_embeddings.weight.data.uniform_(-1.0, 1.0)
        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    # embs: [batch_size, embedding_size]
    def forward(self, block):
        input_nodes = block.srcdata[dgl.NID]
        output_nodes = block.dstdata[dgl.NID]
        batch_size = block.number_of_dst_nodes()
        node_type_embed = []

        with block.local_scope():
            for i in range(self.edge_type_count):
                edge_type = self.edge_types[i]
                block.srcdata[edge_type] = self.node_type_embeddings(
                    input_nodes * self.edge_type_count + i
                )
                block.dstdata[edge_type] = self.node_type_embeddings(
                    output_nodes * self.edge_type_count + i
                )
                block.update_all(
                    fn.copy_u(edge_type, "m"), fn.sum("m", edge_type), etype=edge_type
                )
                node_type_embed.append(block.dstdata[edge_type])

            node_type_embed = torch.stack(node_type_embed, 1)
            tmp_node_type_embed = node_type_embed.unsqueeze(2).view(
                -1, 1, self.embedding_u_size
            )
            trans_w = (
                self.trans_weights.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.embedding_size)
            )
            trans_w_s1 = (
                self.trans_weights_s1.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.embedding_u_size, self.dim_a)
            )
            trans_w_s2 = (
                self.trans_weights_s2.unsqueeze(0)
                .repeat(batch_size, 1, 1, 1)
                .view(-1, self.dim_a, 1)
            )

            attention = (
                F.softmax(
                    torch.matmul(
                        torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
                        trans_w_s2,
                    )
                    .squeeze(2)
                    .view(-1, self.edge_type_count),
                    dim=1,
                )
                .unsqueeze(1)
                .repeat(1, self.edge_type_count, 1)
            )

            node_type_embed = torch.matmul(attention, node_type_embed).view(
                -1, 1, self.embedding_u_size
            )
            node_embed = self.node_embeddings(output_nodes).unsqueeze(1).repeat(
                1, self.edge_type_count, 1
            ) + torch.matmul(node_type_embed, trans_w).view(
                -1, self.edge_type_count, self.embedding_size
            )
            last_node_embed = F.normalize(node_embed, dim=2)

            return last_node_embed  # [batch_size, edge_type_count, embedding_size]

class NSLoss(nn.Module):
    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size

        # [ (log(i+2) - log(i+1)) / log(num_nodes + 1)]
        self.sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )
        self.weights = nn.Embedding(num_nodes, embedding_size, sparse=True)
        self.reset_parameters()

    def reset_parameters(self):
        self.weights.weight.data.normal_(std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        n = input.shape[0]
        log_target = torch.log(
            torch.sigmoid(torch.sum(torch.mul(embs, self.weights(label)), 1))
        )
        negs = (
            torch.multinomial(
                self.sample_weights, self.num_sampled * n, replacement=True
            )
            .view(n, self.num_sampled)
            .to(input.device)
        )
        noise = torch.neg(self.weights(negs))
        sum_log_sampled = torch.sum(
            torch.log(torch.sigmoid(torch.bmm(noise, embs.unsqueeze(2)))), 1
        ).squeeze()

        loss = log_target + sum_log_sampled
        return -loss.sum() / n

embeddings_params = list(map(id, model.node_embeddings.parameters())) + list(
    map(id, model.node_type_embeddings.parameters())
)
weights_params = list(map(id, nsloss.weights.parameters()))

optimizer = torch.optim.Adam(
    [
        {
            "params": filter(
                lambda p: id(p) not in embeddings_params, model.parameters()
            )
        },
        {
            "params": filter(
                lambda p: id(p) not in weights_params, nsloss.parameters()
            )
        },
    ],
    lr=1e-4,
)

sparse_optimizer = torch.optim.SparseAdam(
    [
        {"params": model.node_embeddings.parameters()},
        {"params": model.node_type_embeddings.parameters()},
        {"params": nsloss.weights.parameters()},
    ],
    lr=1e-4,
)

AMP in PyTorch 1.6 doesn't support sparse gradients (RuntimeError: Could not run 'aten::_amp_non_finite_check_and_unscale_' with arguments from the 'SparseCUDA' backend. 'aten::_amp_non_finite_check_and_unscale_' is only available for these backends: [CUDA, Autograd, Profiler, Tracer]), so I ran the following code:

with autocast():
    embs = model(block[0].to(device))[head_invmap]
    embs = embs.gather(
        1,
        block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]),
    )[:, 0]
    loss = nsloss(
        block[0].dstdata[dgl.NID][head_invmap].to(device),
        embs,
        tails.to(device),
    )

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.unscale_(sparse_optimizer)
sparse_optimizer.step()
scaler.update()

It caused an error:

Traceback (most recent call last):
  File "src/main_sparse.py", line 437, in <module>
    average_auc, average_f1, average_pr = train_model(training_data_by_type)
  File "src/main_sparse.py", line 401, in train_model
    num_workers,
  File "/data/shpx/notebooks/yusang/GATNE/src/utils.py", line 320, in evaluate
    ps, rs, _ = precision_recall_curve(y_true, y_scores)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 677, in precision_recall_curve
    sample_weight=sample_weight)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/metrics/_ranking.py", line 545, in _binary_clf_curve
    assert_all_finite(y_score)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/utils/validation.py", line 72, in inner_f
    return f(**kwargs)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/utils/validation.py", line 117, in assert_all_finite
    _assert_all_finite(X.data if sp.issparse(X) else X, allow_nan)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/sklearn/utils/validation.py", line 99, in _assert_all_finite
    msg_dtype if msg_dtype is not None else X.dtype)
ValueError: Input contains NaN, infinity or a value too large for dtype('float32').
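A quick finiteness check right before the evaluation call can narrow down whether the NaNs already come from the model output; a minimal sketch, assuming y_scores is the array that evaluate passes to precision_recall_curve:

import numpy as np

# Hypothetical debugging check (not part of the original code): count
# non-finite prediction scores produced during the AMP run before they
# reach sklearn's assert_all_finite.
y_scores = np.asarray(y_scores, dtype=np.float32)
finite_mask = np.isfinite(y_scores)
if not finite_mask.all():
    print(f"{(~finite_mask).sum()} of {y_scores.size} scores are NaN/Inf")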

Besides, I have tried the PyTorch nightly, which supports sparse gradients, but I also encountered an error:

Traceback (most recent call last):
  File "src/main_sparse.py", line 437, in <module>
    average_auc, average_f1, average_pr = train_model(training_data_by_type)
  File "src/main_sparse.py", line 343, in train_model
    scaler.step(sparse_optimizer)
  File "/home/yusang/miniconda3/envs/dev/lib/python3.7/site-packages/torch/cuda/amp/grad_scaler.py", line 306, in step
    assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
AssertionError: No inf checks were recorded for this optimizer.

Is there a way to use amp with the sparse optimizer?

I would recommend sticking to the nightly for now.
The error message indicates that the sparse_optimizer doesn't contain any parameters with scaled gradients. The same error would be raised in this small code example:

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda()
x = torch.randn(1, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
optimizer_unused = torch.optim.SGD([nn.Parameter(torch.randn(1))], lr=1.)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    out = model(x)
    loss = out.mean()

scaler.scale(loss).backward()
scaler.step(optimizer)         # works
scaler.step(optimizer_unused)  # your error
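For comparison, a minimal variant that avoids the assertion only passes the optimizer whose parameters received scaled gradients to the scaler and steps the unused one directly (a sketch of the pattern, which is essentially what the follow-up below asks about):

scaler.scale(loss).backward()
scaler.step(optimizer)    # parameters of `model` have scaled gradients
optimizer_unused.step()   # never saw scaled gradients, so step it outside the scaler
scaler.update()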

Do you mean that I can modify my code as follows?

with autocast():
    embs = model(block[0].to(device))[head_invmap]
    embs = embs.gather(
        1,
        block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2]),
    )[:, 0]
    loss = nsloss(
        block[0].dstdata[dgl.NID][head_invmap].to(device),
        embs,
        tails.to(device),
    )

scaler.scale(loss).backward()
scaler.step(optimizer)
# Although these parameters don't have any scaled gradients, I still need to optimize them.
sparse_optimizer.step()
scaler.update()

I think your sparse_optimizer contains parameters which were never used in the “amp forward pass”.
If you can verify that this is indeed the case, you should be able to call sparse_optimizer.step() directly without using the GradScaler object.
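A quick way to do that verification (a sketch, assuming the training step shown above) is to inspect the gradients of the embedding parameters right after backward():

scaler.scale(loss).backward()
# Hypothetical check: parameters that were never used in the autocast forward
# pass have no gradient; parameters with sparse gradients belong to sparse_optimizer.
for name, p in list(model.named_parameters()) + list(nsloss.named_parameters()):
    if p.grad is None:
        print(f"{name}: no gradient (not used in the forward pass)")
    elif p.grad.is_sparse:
        print(f"{name}: sparse gradient, handled by sparse_optimizer")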