# Backpropagation with different computational graphs in same batch

I have the following issue, and I am trying to write my own code for it. I know that there is such a thing as a stochastic computation graph, and that http://pytorch.org/docs/master/distributions.html is useful for that. However, I want to write my own code for something similar.

The model itself is stochastic: whether or not a layer is used depends upon some random variable, which in turn depends on the data point. The simplest example of such a stochastic computation graph is one where the probability of using a layer does not depend upon the data point. I will stick to explaining this simplest case, with a single hidden layer.
Suppose the original computational graph is as follows (X is the input feature, and softmax is used for predictions):

original: X -> ReLU(Linear()) -> Softmax()

Now say that I use a random variable Y, which takes the value 0 or 1. I use the ReLU(Linear()) layer only if Y = 1; if Y = 0, I don't use it. Thus, with Z denoting the middle layer's output:

X -> Z = ReLU(Linear(X)) if Y = 1, else (if Y = 0) Z = X; then Z -> Softmax()

Note that the exact computational graph therefore depends upon the data point. I wish to perform backpropagation through these computational graphs (a different one for each data point).
I have the following two questions:

1. For implementing the above idea, can I do the following by defining my own class?
```python
import numpy as np
import torch.nn as nn

class MyClass(nn.Module):
    def __init__(self, in_features, num_classes, p):
        super().__init__()
        # The skip path feeds X straight into softmax, so this sketch
        # needs in_features == num_classes for the shapes to match.
        self.linear = nn.Linear(in_features, num_classes)
        self.relu = nn.ReLU()
        # Note: nn.CrossEntropyLoss applies log-softmax internally, so for
        # training with it one would return the pre-softmax values instead.
        self.softmax = nn.Softmax(dim=1)
        self.prob = p

    def forward(self, x):
        # Draw Y ~ Bernoulli(p) once per forward pass (per batch, not per sample)
        sample = np.random.binomial(1, self.prob)
        if sample == 1:
            out = self.softmax(self.relu(self.linear(x)))
        else:
            out = self.softmax(x)
        return out
```
Let us say that I instantiate my model as model.
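For concreteness, a hypothetical instantiation (the sizes are placeholders, with in_features equal to num_classes as the skip path requires) would be:

```python
model = MyClass(in_features=10, num_classes=10, p=0.5)  # placeholder sizes
```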
2. Can I then simply use the optimizer as shown on the PyTorch website (http://pytorch.org/tutorials/beginner/pytorch_with_examples.html)? For example, can I use the following method for training?

```python
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model
    y_pred = model(x)
    # Compute the loss
    loss = criterion(y_pred, y)
    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
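For the batched case in the title, where each data point draws its own Y, what I have in mind is something like the following rough sketch (stochastic_forward is a hypothetical helper; it reuses the submodules of MyClass above and keeps the in_features == num_classes assumption). The idea is to sample one Bernoulli mask entry per row and blend the two paths, so that autograd sends gradients through whichever path each sample actually used. Would this be a valid way to get a different graph per data point within one batch?

```python
import torch

def stochastic_forward(model, x):
    # One Bernoulli draw per row of the batch, shape (batch, 1)
    mask = torch.bernoulli(torch.full((x.size(0), 1), model.prob))
    z = model.relu(model.linear(x))  # hidden-layer path
    # Rows with mask == 1 take the hidden layer; rows with mask == 0 skip it.
    # Both paths are computed for the whole batch, which wastes some compute
    # but keeps a single batched graph; gradients only flow into model.linear
    # for the rows where mask == 1.
    return model.softmax(mask * z + (1 - mask) * x)
```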