Optimize code to use GPU

Hello, I want to replicate the Structure Inference Net. What I have so far, based on their Tensorflow code, are the following:

My edge box layer implementation:

class EdgesExtractor(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, batch_boxes, image_sizes):
        n_boxes = batch_boxes[0].shape[0]
        edge_boxes = torch.empty((len(batch_boxes), n_boxes**2, 12), device=self.dummy_param.device)

        for batch_idx, rois in enumerate(batch_boxes):
            ious = box_ops.box_iou(rois,rois)
            im_info = image_sizes[batch_idx]


            union_boxes = []
            for i in range(n_boxes):
                for j in range(n_boxes):
                    iou = ious[i][j].item()

                    if iou < 0.6:
                        box = []

                        cx1 = (rois[i][0] + rois[i][2]) * 0.5
                        cy1 = (rois[i][1] + rois[i][3]) * 0.5
                        w1 = (rois[i][2] - rois[i][0]) * 1.0
                        h1 = (rois[i][3] - rois[i][1]) * 1.0

                        if w1 < 0:
                            w1 = 0
                        if h1 < 0:
                            h1 = 0

                        s1 = w1 * h1

                        cx2 = (rois[j][0] + rois[j][2]) * 0.5
                        cy2 = (rois[j][1] + rois[j][3]) * 0.5
                        w2 = (rois[j][2] - rois[j][0]) * 1.0
                        h2 = (rois[j][3] - rois[j][1]) * 1.0

                        if w2 < 0:
                            w2 = 0
                        if h2 < 0:
                            h2 = 0

                        s2 = w2 * h2

                        box.append(w1 / (im_info[0] + 1))
                        box.append(h1 / (im_info[1] + 1))
                        box.append(s1 / ((im_info[0] + 1) * (im_info[1] + 1)))

                        box.append(w2 / (im_info[0] + 1))
                        box.append(h2 / (im_info[1] + 1))
                        box.append(s2 / ((im_info[0] + 1) * (im_info[1] + 1)))

                        box.append((cx1 - cx2) / (w2 + 1))
                        box.append((cy1 - cy2) / (h2 + 1))

                        box.append(pow((cx1 - cx2) / (w2 + 1), 2))
                        box.append(pow((cy2 - cy1) / (h2 + 1), 2))

                        box.append(math.log((w1 + 1) / (w2 + 1)))
                        box.append(math.log((h1 + 1) / (h2 + 1)))

                    else:
                        box = [0] * 12
                        #index += 1

                    union_boxes.append(box)
            edge_boxes[batch_idx] = torch.FloatTensor(union_boxes).to(self.dummy_param.device)
        return edge_boxes

My structure_inference_spmm implementation:

class StructureInference(nn.Module):
    def __init__(
        self,
    ):
        super().__init__()
        self.n_steps = 2
        self.n_inputs = 4096

        self.n_hidden_o = 4096
        self.n_hidden_e = 4096

        self.Concat_w = nn.Parameter(torch.randn((self.n_inputs*2), 1))
        self.u = nn.Parameter(torch.randn(12, 1))
        nn.init.xavier_uniform_(self.u)
        nn.init.xavier_uniform_(self.Concat_w)

        self.E_cell = nn.GRU(self.n_hidden_e, self.n_hidden_e)
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, mes, mns):
        batch_size = len(mes)
        mns = mns.view(batch_size, int(mns.shape[0]/batch_size), int(mns.shape[1]))
        n_boxes = mns[0].shape[0]
        his = torch.empty(batch_size, n_boxes, self.n_inputs, device=self.dummy_param.device)
        for i in range(batch_size):
            ofe = mes[i]

            fo = mns[i]

            fe = ofe.view(n_boxes**2,12)


            PE = torch.matmul(fe, self.u)
            PE = PE.view(n_boxes, n_boxes)
            hi = fo

            for t in range(self.n_steps):
                X = hi.repeat(n_boxes, 1)
                X = X.view(n_boxes**2, self.n_inputs)
                Y = hi


                Y1 = hi.repeat(1, n_boxes)
                Y1 = Y1.view(n_boxes**2, self.n_inputs)


                Y2 = hi.repeat(n_boxes,1)
                Y2 = Y2.view(n_boxes**2, self.n_inputs)


                VE =  F.tanh(torch.matmul(torch.cat((Y1,Y2), dim=1),  self.Concat_w).view(n_boxes, n_boxes))


                E = torch.multiply(PE, VE)

                Z = F.softmax(E, dim=1)

                X = X.view(n_boxes, n_boxes, self.n_inputs)

                M = Z.unsqueeze(2) * X

                M = torch.max(M, 1).values

                einput = M.squeeze()

                ho2, hi2 = self.E_cell(einput.unsqueeze(0), hi.unsqueeze(0))

                hi = hi2.squeeze()

            his[i] = hi
        return his

I’m trying to optimize this code blocks to utilize the GPU but it seems like my implementation is incorrect.

Take for example, this code block:

    def __init__(
        self
    ):
        super().__init__()
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, batch_boxes, image_sizes):
        n_boxes = batch_boxes[0].shape[0]
        edge_boxes = torch.empty((len(batch_boxes), n_boxes**2, 12), device=self.dummy_param.device)
        for batch_idx, proposals in enumerate(batch_boxes):
            ious = box_ops.box_iou(proposals,proposals)

            o_cx2 = (proposals[:, 0] + proposals[:, 2]) * 0.5
            o_cy2 = (proposals[:, 1] + proposals[:, 3]) * 0.5
            o_w2 = (proposals[:, 2] - proposals[:, 0]) * 1.0
            o_h2 = (proposals[:, 3] - proposals[:, 1]) * 1.0

            # set negatives of w2 and h2 to 0
            w2_neg_inds = (o_w2 < 0).nonzero()
            h2_neg_inds = (o_h2 < 0).nonzero()

            o_w2[w2_neg_inds] = 0
            o_h2[h2_neg_inds] = 0

            o_s2 = o_w2 * o_h2

            im_info = image_sizes[batch_idx]

            rolling_idx = 0
            for idx, proposal in enumerate(proposals):
                cur_box_ious = ious[idx]
                inds_g_6s = (cur_box_ious >= 0.6).nonzero().squeeze()
                inds_6s = (cur_box_ious < 0.6).nonzero().squeeze()

                cx2 = o_cx2[inds_6s]
                cy2 = o_cy2[inds_6s]
                w2 = o_w2[inds_6s]
                h2 = o_h2[inds_6s]
                s2 = o_s2[inds_6s]

                cx1 = (proposal[0] + proposal[2]) * 0.5
                cy1 = (proposal[1] + proposal[3]) * 0.5
                w1 = (proposal[2] - proposal[0]) * 1.0
                h1 = (proposal[3] - proposal[1]) * 1.0

                if w1 < 0:
                    w1 = 0
                if h1 < 0:
                    h1 = 0

                s1 = w1 * h1

                num_inds_6s = inds_6s.shape[0]
                edge_boxes_inds_6s = (rolling_idx+inds_6s)
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,0] = (w1 / (im_info[0] + 1)).repeat(num_inds_6s)
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,1] = (h1 / (im_info[1] + 1)).repeat(num_inds_6s)
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,2] = (s1 / ((im_info[0] + 1) * (im_info[1] + 1))).repeat(num_inds_6s)

                edge_boxes[batch_idx][edge_boxes_inds_6s][:,3] = (w2 / (im_info[0] + 1))
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,4] = (h2 / (im_info[1] + 1))
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,5] = (s2 / ((im_info[0] + 1) * (im_info[1] + 1)))

                edge_boxes[batch_idx][edge_boxes_inds_6s][:,6] = ((cx1 - cx2) / (w2 + 1))
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,7] = ((cy1 - cy2) / (h2 + 1))

                edge_boxes[batch_idx][edge_boxes_inds_6s][:,8] = pow((cx1 - cx2) / (w2 + 1), 2)
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,9] = pow((cy1 - cy2) / (h2 + 1), 2)

                edge_boxes[batch_idx][edge_boxes_inds_6s][:,10] = torch.log((w1 + 1) / (w2 + 1))
                edge_boxes[batch_idx][edge_boxes_inds_6s][:,11] = torch.log((h1 + 1) / (h2 + 1))


                edge_boxes_inds_g_6s = rolling_idx+inds_g_6s
                edge_boxes[batch_idx][edge_boxes_inds_g_6s] = torch.FloatTensor([0]*12).to(self.dummy_param.device)

                rolling_idx += proposals.shape[0]

        return edge_boxes