Is there a way to apply an operation to a single sample in a batch?

I am trying to update a particular mask based on the values seen in each sample. So far, I can update the mask based on the entire batch, but it only updates after all the samples have passed through.

Is there some way I can get in there and call an event or make something happen after each sample in a batch as opposed to the whole batch?

Example pseudocode, assuming x is a tensor of size [32, 1000]:

def forward(self, x):
    x = F.relu(self.fc1(x))
    x = x * self.mask                # apply the current mask
    x = torch.sigmoid(x)             # (F.sigmoid is deprecated)
    self.mask = self.mask.update(x)  # update the mask from this batch's output
    return x

This code updates successfully, and on the next batch it will apply the updated mask, but I’m trying to get it to update after every sample, not after the entire batch.

You can use torch.where(mask_for_update, updated_tensor, original_tensor) to take values from updated_tensor only where mask_for_update is True and keep original_tensor everywhere else. If you want to update given samples in the batch, you could use a mask of shape batch x 1 with the to-be-updated samples set to True.
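For example, with placeholder tensors (the names here are not from your code), a batch x 1 boolean mask broadcasts across the feature dimension:

import torch

original_tensor = torch.zeros(4, 6)
updated_tensor = torch.ones(4, 6)

# Rows set to True take the updated values; the rest keep the original ones.
mask_for_update = torch.tensor([True, False, True, False]).unsqueeze(1)  # [4, 1]

result = torch.where(mask_for_update, updated_tensor, original_tensor)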

Best regards

Thomas

I’m not sure I’m following. I’ve gone into more detail about my issue below, and I wonder if your approach works for it. If so, I’m not sure how to apply your suggestion here…

Specifically - I’d like to do the following:

For every sample, I am interested in the top-k activated units, but each top-k unit selected for a particular sample should be avoided by the following samples in the batch, so that the top-k activations are exclusive to each sample.

This is why I am inhibiting units with a mask before passing through top-k, and the inhibitions are based on the top-k activations of the previous sample. So within a forward pass, it ends up like this:

x = activation

# run this chunk on each sample, not on the whole batch.
x = x * inhibition_based_on_prev_top_k   # initialized as ones_like(x)
x = get_topk(x)
inhibition_based_on_prev_top_k.update(x)


return x
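To be concrete, for a single sample s the step I want to run looks something like this (get_topk_mask and the inhibition update are simplified stand-ins for my actual helpers):

import torch

def get_topk_mask(s, k):
    # Stand-in helper: binary mask with ones at the k largest activations of s.
    mask = torch.zeros_like(s)
    _, idx = torch.topk(s, k)
    mask[idx] = 1.0
    return mask

# s: activations of one sample; inhibition: ones, minus the units already
# claimed by earlier samples in this batch.
s = torch.rand(1000)
inhibition = torch.ones(1000)

s = s * inhibition                           # suppress already-claimed units
topk_mask = get_topk_mask(s, 10)
inhibition = inhibition * (1.0 - topk_mask)  # inhibit these units for the next sample
s = s * topk_mask                            # keep only the top-k activations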

Are you looking for a for loop over the batch?

Can’t he use hooks for this in the forward pass?
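Something like this, as a rough sketch (the model and layer here are just placeholders), though a forward hook fires once per forward call, i.e. once per batch, so any per-sample update would still need a loop inside the hook:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(225, 225))  # placeholder model, not the one above

def mask_hook(module, inputs, output):
    # The hook sees the whole batch at once; returning a tensor
    # replaces the module's output.
    return output

handle = model[0].register_forward_hook(mask_hook)
out = model(torch.randn(32, 225))
handle.remove()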

Here’s the approach I ended up taking.

def forward(self, x):
    """
    Args:
        x: (tensor) sized torch.Size([32, 225])
    """
    x = F.leaky_relu(self.fc1(x)) # [32, 225]

    for i in range(x.size(0)):
        s = x[i, :] # [225]
        s = s * self.phi() # phi initialized to ones [1, 225]
        s = get_top_k(s, 10, mask_type="binary")
        self.phi.update(s, 0.1618)
        x[i, :] = s
    return x

Does that work through backward, though?
If not, you could append each s to a list and torch.stack it to build the value to return.

You know, it actually worked fine. I needed to add clone() to avoid any in-place operations… (not actually sure why this worked)

x = F.leaky_relu(self.fc1(x)) # [32, 225]

for i in range(x.size(0)):
    s = x[i, :]
    s = s.clone() * self.phi().clone()
    s = get_top_k(s, 10, mask_type="binary")
    self.phi.update(s, 0.1618)
    x[i, :] = s

and

loss.backward(retain_graph=True)

It sure was slow though!

I did try stack as you suggested (with unbind as well), but it didn’t really seem to give me any boost and was a little more verbose…

s_list = []

x = F.leaky_relu(self.fc1(x)) # [32, 225]

samples = torch.unbind(x)
for s in samples:
    s = s.clone() * self.phi().clone()
    s = get_top_k(s, 10, mask_type="binary")
    self.phi.update(s, 0.1618)
    s_list.append(s)
x = torch.stack(s_list).squeeze()

Really slow, but adding clone() helped backward() make sense of it all, though I am not sure why.
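My best guess at why clone() mattered: phi.update() presumably changes the mask in place after it was saved for backward by the multiplication, and cloning hands autograd an untouched copy instead. A tiny standalone sketch of that failure mode (made-up tensors, not the model above):

import torch

w = torch.ones(3, requires_grad=True)
mask = torch.ones(3)

y = w * mask          # mask is saved, since it is needed to compute dL/dw
mask.mul_(0.5)        # in-place update bumps mask's version counter
# y.sum().backward()  # would raise: "... modified by an inplace operation"

# Cloning before the multiplication keeps an untouched copy for backward:
w = torch.ones(3, requires_grad=True)
mask = torch.ones(3)
y = w * mask.clone()  # the clone, not mask itself, is saved for backward
mask.mul_(0.5)        # no longer affects what backward needs
y.sum().backward()    # works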