How to deal with batches with a gated branch network (self-driving network)

I am currently implementing this paper (https://arxiv.org/abs/1710.02410) in PyTorch and I have run into an issue training the model.

The input data consists of an image, a measurement scalar, and a control scalar. The image gets passed to a convolutional module and the measurement scalar gets passed through a linear module. The outputs of these are concatenated. The final section of the model consists of several branches, and the control scalar determines which branch the concatenated tensor is passed to.
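For context, here's a rough skeleton of how the model is structured. The layer sizes and the 3-dimensional head output are placeholders, not the paper's exact configuration:

	import torch
	import torch.nn as nn

	class BranchedNetwork(nn.Module):
		def __init__(self):
			super().__init__()
			# Perception and measurement encoders (placeholder sizes)
			self.imageModule = nn.Sequential(
				nn.Conv2d(3, 32, kernel_size=5, stride=2), nn.ReLU(),
				nn.AdaptiveAvgPool2d(1), nn.Flatten(),
				nn.Linear(32, 512), nn.ReLU())
			self.measurementModule = nn.Sequential(
				nn.Linear(1, 128), nn.ReLU())
			self.jointSensoryModule = nn.Sequential(
				nn.Linear(512 + 128, 512), nn.ReLU())
			# One head per high-level command, plus a general fallback;
			# 3 outputs per head is a placeholder (e.g. steer/throttle/brake)
			def branch():
				return nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 3))
			self.follow_lane_branch = branch()
			self.left_branch = branch()
			self.right_branch = branch()
			self.straight_branch = branch()
			self.general_branch = branch()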

I tested it with a batch size of 1 and it works. However, my issue comes when I turn the batch size up. Since each sample in the batch goes to a (possibly different) branch, and all samples in a batch are run through the model in parallel, this causes an issue.

To avoid a bottleneck in the network during training, I've been attempting to implement multiprocessing, but it hasn't worked out. Here is where I'm at:

	def gated_branch_function(self, j, control):
		control = int(control.item())
		# Route the joint features to the branch selected by the control signal
		if control == 2 or control == 0:
			output = self.follow_lane_branch(j)
		elif control == 3:
			output = self.left_branch(j)
		elif control == 4:
			output = self.right_branch(j)
		elif control == 5:
			output = self.straight_branch(j)
		else:
			output = self.general_branch(j)
		return output


	def forward(self, input_data):
		''' Define variables '''
		input_image = input_data[0].permute(0,3,2,1).float()
		input_speed = input_data[1].unsqueeze(dim=1)
		control = input_data[2]

		''' Pass input data into network '''
		imageOutput = self.imageModule(input_image)
		speedOutput = self.measurementModule(input_speed)

		''' Joint sensory '''
		j = torch.cat([imageOutput, speedOutput], 1)
		j = self.jointSensoryModule(j)

		''' Branches '''
		# mp here is pathos.multiprocessing; this is the call that fails
		pool = mp.ProcessPool(nodes=os.cpu_count())
		output = pool.map(self.gated_branch_function, j, control)
		return torch.stack(output)

torch.multiprocessing didn't work out, so currently I am trying to use pathos for multiprocessing. I am getting an error, however:

TypeError: can't pickle SwigPyObject objects

So I am pretty stuck lol. My guess is that pathos is trying to pickle the bound method, which drags the whole module (and its CUDA handles) along with it. Is there an elegant way of doing this that I am missing? Any ideas?

Thanks!!

Update: I've got the following code working:

	def gated_branch_function(self, j_batch, control_batch):
		batch_outputs = []
		# Route each sample through its branch sequentially
		for j, control in zip(j_batch, control_batch):
			control = int(control.item())
			if control == 2 or control == 0:
				output = self.follow_lane_branch(j)
			elif control == 3:
				output = self.left_branch(j)
			elif control == 4:
				output = self.right_branch(j)
			elif control == 5:
				output = self.straight_branch(j)
			else:
				output = self.general_branch(j)
			batch_outputs.append(output)
		return torch.stack(batch_outputs)


	def forward(self, input_data):
		''' Define variables '''
		input_image = input_data[0].permute(0,3,2,1).float()
		input_speed = input_data[1].unsqueeze(dim=1)
		control = input_data[2]

		''' Pass input data into network '''
		imageOutput = self.imageModule(input_image)
		speedOutput = self.measurementModule(input_speed)

		''' Joint sensory '''
		j = torch.cat([imageOutput, speedOutput], 1)
		j = self.jointSensoryModule(j)

		''' Branches '''
		output = self.gated_branch_function(j, control)
		return output

However, this results in a bottleneck, so training is slow. Is there any way I could write gated_branch_function so that each sample in the batch gets sent to its respective branch in parallel?
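One idea I've been toying with is to group the samples by control value, so each branch only runs once per batch, then scatter the results back into batch order. Here's a rough, untested sketch:

	def gated_branch_function(self, j_batch, control_batch):
		branch_map = {0: self.follow_lane_branch, 2: self.follow_lane_branch,
					  3: self.left_branch, 4: self.right_branch,
					  5: self.straight_branch}
		controls = control_batch.long().view(-1)
		output = None
		for value in controls.unique().tolist():
			mask = controls == value
			branch = branch_map.get(value, self.general_branch)
			# One forward pass per branch, over every sample routed to it
			branch_out = branch(j_batch[mask])
			if output is None:
				output = branch_out.new_zeros(len(j_batch), *branch_out.shape[1:])
			output[mask] = branch_out
		return output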

You could try to add CUDA streams as described here, but you would have to take care of the synchronization to avoid race conditions.

I would recommend trying it out with a “similar” but simple model first and checking whether you get any performance gains at all.
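Something like this standalone sketch could serve as the test bed (dummy linear branches, random routing; all sizes are just placeholders):

	import time
	import torch
	import torch.nn as nn

	device = 'cuda'
	# Dummy stand-ins for the real branches
	branches = nn.ModuleList([nn.Linear(512, 3) for _ in range(4)]).to(device)
	j = torch.randn(64, 512, device=device)
	control = torch.randint(0, 4, (64,))

	torch.cuda.synchronize()
	t0 = time.time()
	out = []
	for sample, c in zip(j, control):
		s = torch.cuda.Stream()
		# Make sure the side stream sees the finished input
		s.wait_stream(torch.cuda.current_stream())
		with torch.cuda.stream(s):
			out.append(branches[int(c)](sample))
	torch.cuda.synchronize()
	print('streamed loop: {:.4f}s'.format(time.time() - t0))

If the streamed loop isn't faster than a plain loop on a toy setup like this, it won't help in the full model either.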

I tested it out on a simple model and got a slight increase in performance. I rewrote the function in my model and it's working:

	def gated_branch_function(self, j_batch, control_batch):
		batch_output = []
		for j, control in zip(j_batch, control_batch):
			control = int(control.item())
			# Pick the branch for this sample
			if control == 2 or control == 0:
				branch = self.follow_lane_branch
			elif control == 3:
				branch = self.left_branch
			elif control == 4:
				branch = self.right_branch
			elif control == 5:
				branch = self.straight_branch
			else:
				branch = self.general_branch
			# Run each sample on its own stream; the stream has to wait on the
			# default stream so that j is fully computed before the branch runs
			s = torch.cuda.Stream()
			s.wait_stream(torch.cuda.current_stream())
			with torch.cuda.stream(s):
				batch_output.append(branch(j))
		torch.cuda.synchronize()
		return torch.stack(batch_output)

I'm guessing this is the best I can do. Still slow, but at least a bit faster than before.