PyTorch for causal models

Hi all,

I'm interested in coding up a causal model using PyTorch, following the ideas in this paper: https://arxiv.org/pdf/2101.10943.pdf

Fundamentally, these models take an input array and pass it through one or more sets of layers that end in 3 outputs. The architecture can vary, but typically the outputs are (1) the propensity for treatment (i.e. a softmax estimating the probability of treatment from the covariates in the input array), (2) the predicted outcome with treatment and (3) the predicted outcome without treatment. The difference between the last two is the estimated treatment effect.
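
A minimal sketch of that three-output architecture, assuming a binary treatment and a continuous outcome. The class name `ThreeHeadNet`, the layer sizes, and the choice of a single logit plus sigmoid for the propensity (equivalent to a two-class softmax when treatment is binary) are my own assumptions, not taken from the paper.

```python
import torch
import torch.nn as nn


class ThreeHeadNet(nn.Module):
    """Hypothetical sketch: shared layers followed by three output heads."""

    def __init__(self, n_features: int, hidden: int = 64):
        super().__init__()
        # Shared representation of the covariates (layer sizes are arbitrary choices)
        self.shared = nn.Sequential(
            nn.Linear(n_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # Head 1: logit of P(treated | x); torch.sigmoid turns it into a probability
        self.propensity_head = nn.Linear(hidden, 1)
        # Heads 2 and 3: predicted outcome without / with treatment
        self.y0_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.y1_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, x: torch.Tensor):
        z = self.shared(x)
        # squeeze(-1) turns each (batch, 1) output into a (batch,) vector
        prop_logit = self.propensity_head(z).squeeze(-1)
        y0_hat = self.y0_head(z).squeeze(-1)
        y1_hat = self.y1_head(z).squeeze(-1)
        return prop_logit, y0_hat, y1_hat


torch.manual_seed(0)
model = ThreeHeadNet(n_features=10)
x = torch.randn(8, 10)  # batch of 8 instances with 10 covariates each
prop_logit, y0_hat, y1_hat = model(x)
print(prop_logit.shape, y0_hat.shape, y1_hat.shape)
print(y1_hat - y0_hat)  # per-instance estimated treatment effect (untrained here)
```

Sharing the lower layers lets the covariate representation learn from every instance, while each outcome head can be trained only on the instances from its own arm.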

However, for each instance in the data we only observe one of these outcomes (either the treated one or the untreated one).
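
To make the "only one outcome per instance" point concrete, here is a small illustration with made-up numbers. It assumes `t` is a float tensor of 0/1 treatment indicators and uses plain tensors as stand-ins for the two heads' predictions: `torch.where` picks the prediction for the arm that was actually observed, the loss is computed on that alone, and autograd then sends zero gradient to the unobserved prediction.

```python
import torch

# Made-up mini-batch of 4 instances: t = 1 means treated, 0 means untreated
t = torch.tensor([1.0, 0.0, 1.0, 0.0])
y = torch.tensor([2.5, 1.0, 3.0, 0.5])  # the one outcome actually observed per row

# Stand-ins for the outputs of the two outcome heads (leaf tensors so grads are visible)
y0_hat = torch.zeros(4, requires_grad=True)
y1_hat = torch.zeros(4, requires_grad=True)

# Select the prediction for the observed (factual) arm of each row
y_factual = torch.where(t == 1, y1_hat, y0_hat)
loss = ((y_factual - y) ** 2).mean()
loss.backward()

print(y0_hat.grad)  # non-zero only where t == 0
print(y1_hat.grad)  # non-zero only where t == 1
```

In a real network the same thing happens one step further back: rows where a head's prediction was not selected contribute nothing to that head's parameter gradients.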

So my question is: how can I build a model that trains with target values for only 2 of its 3 outputs on any given instance (the propensity plus either the treated route or the untreated route)?

Presumably I would need some sort of mapping function that turns the loss calculation (and therefore backprop) on or off for the treatment/no-treatment layers. How could I do this?
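
One possible way to do the masking: compute the outcome loss for both heads and multiply each row by the treatment indicator, so the head whose outcome is unobserved contributes exactly zero loss and zero gradient for that row. This is a sketch under assumptions: binary treatment, squared-error outcome loss, and a hypothetical weighting parameter `alpha` for the propensity term. The data are synthetic (true effect 2.0), and `nn.Linear(d, 3)` is only a stand-in for any model that returns three outputs per row.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def causal_loss(prop_logit, y0_hat, y1_hat, t, y, alpha=1.0):
    """Propensity loss on every row + outcome loss only on the observed arm.

    t: float tensor of 0/1 treatment indicators; it doubles as the mask.
    alpha: hypothetical weight on the propensity term.
    """
    # Every instance has a propensity target: the treatment it actually received
    prop_loss = F.binary_cross_entropy_with_logits(prop_logit, t)
    # Squared error for both heads, then zero out the head whose outcome is unobserved
    se0 = (y0_hat - y) ** 2
    se1 = (y1_hat - y) ** 2
    outcome_loss = (t * se1 + (1 - t) * se0).mean()
    return outcome_loss + alpha * prop_loss


# Synthetic data, made up for illustration: the true treatment effect is 2.0
torch.manual_seed(0)
n, d = 256, 5
x = torch.randn(n, d)
t = torch.bernoulli(torch.sigmoid(x[:, 0]))   # treatment depends on x[:, 0]
y = x[:, 1] + 2.0 * t + 0.1 * torch.randn(n)  # only the factual outcome exists

# Stand-in model: column 0 = propensity logit, 1 = y0_hat, 2 = y1_hat
model = nn.Linear(d, 3)
opt = torch.optim.Adam(model.parameters(), lr=0.02)

for step in range(1000):
    out = model(x)
    loss = causal_loss(out[:, 0], out[:, 1], out[:, 2], t, y)
    opt.zero_grad()
    loss.backward()
    opt.step()

# Average treatment effect estimated from the two outcome heads
with torch.no_grad():
    out = model(x)
    ate = (out[:, 2] - out[:, 1]).mean()
print(f"final loss {loss.item():.3f}, estimated ATE {ate.item():.2f}")
```

Because each row contributes exactly one outcome term, averaging over the whole batch still gives the mean error on the observed outcomes. One caveat of multiplying by a mask: a `NaN` in the unobserved head's prediction would still reach the loss (`0 * NaN` is `NaN`), whereas selecting with `torch.where` before computing the loss keeps it out of the forward pass.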

Any suggestions welcome