Validation step fails: RuntimeError: Given groups=1, weight of size 32 4 3 3, expected input[1, 3, 770, 770] to have 4 channels, but got 3 channels instead

I’m trying to train a Unet with an efficientnet-b4 encoder using segmentation_models_pytorch, and I’m running into this error during the validation step.

I am using 4-channel input images (RGB dstack-ed with a depth image).
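A minimal sketch of this kind of 4-channel stacking, for illustration only (not my exact pipeline). It assumes the RGB image is an H x W x 3 array and the depth map is H x W; the helper names `make_rgbd` and `to_chw` and the float32 dtype are hypothetical choices. Since the traceback shows a 3-channel batch reaching the model during validation, the same stacking presumably has to happen in the validation dataset as well as the training one.

```python
import numpy as np


def make_rgbd(rgb: np.ndarray, depth: np.ndarray) -> np.ndarray:
    """Stack an H x W x 3 RGB image with an H x W depth map into H x W x 4."""
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError(f"expected H x W x 3 RGB image, got {rgb.shape}")
    if depth.shape != rgb.shape[:2]:
        raise ValueError(f"depth shape {depth.shape} does not match RGB {rgb.shape[:2]}")
    # np.dstack turns the 2-D depth map into H x W x 1 and concatenates along axis 2
    return np.dstack([rgb.astype(np.float32), depth.astype(np.float32)])


def to_chw(image: np.ndarray) -> np.ndarray:
    """Move channels first (C x H x W), the layout PyTorch conv layers expect."""
    return np.ascontiguousarray(image.transpose(2, 0, 1))


# Usage with dummy data at the resolution seen in the traceback
rgb = np.zeros((770, 770, 3), dtype=np.uint8)
depth = np.ones((770, 770), dtype=np.float32)
x = to_chw(make_rgbd(rgb, depth))
print(x.shape)  # (4, 770, 770)
```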

train: 81%|▊| 53486/65912 [1:24:38<19:39, 10.53it/
valid: 0%| | 16/16478 [00:02<42:09, 6.51it/s, val_loss_mask: 0.7111, val_loss: 0.7111, val_mask_micr

Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/bhaktatejas922/internal-geometry-ml/src/train.py", line 199, in <module>
    main(cfg)
  File "/home/bhaktatejas922/internal-geometry-ml/src/train.py", line 187, in main
    **cfg.training.fit,
  File "/home/bhaktatejas922/internal-geometry-ml/src/training/runner.py", line 304, in fit
    verbose=verbose,
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 49, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/bhaktatejas922/internal-geometry-ml/src/training/runner.py", line 335, in evaluate
    output = self._feed_batch(batch)
  File "/home/bhaktatejas922/internal-geometry-ml/src/training/runner.py", line 35, in wrapped
    res = f(*args, **kwargs)
  File "/home/bhaktatejas922/internal-geometry-ml/src/training/runner.py", line 153, in _feed_batch
    output = self.model(*input)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/segmentation_models_pytorch/base/model.py", line 15, in forward
    features = self.encoder(x)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/segmentation_models_pytorch/encoders/efficientnet.py", line 50, in forward
    x = self._swish(self._bn0(self._conv_stem(x)))
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/bhaktatejas922/.local/lib/python3.6/site-packages/efficientnet_pytorch/utils.py", line 271, in forward
    x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 32 4 3 3, expected input[1, 3, 770, 770] to have 4 channels, but got 3 channels instead
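For reading the error: a PyTorch Conv2d weight has shape (out_channels, in_channels / groups, kernel_h, kernel_w), and the input is (N, C, H, W). So the encoder's stem conv (`_conv_stem`) was built for 4 input channels, while the batch it received has 1 image with 3 channels at 770 x 770. A minimal sketch of a pre-forward check that fails earlier with a clearer message; `check_input_channels` is a hypothetical helper, not part of segmentation_models_pytorch or PyTorch.

```python
from typing import Sequence


def check_input_channels(weight_shape: Sequence[int],
                         input_shape: Sequence[int],
                         groups: int = 1) -> None:
    """Raise a descriptive error if an NCHW input does not fit a Conv2d weight."""
    if len(weight_shape) != 4 or len(input_shape) != 4:
        raise ValueError("expected 4-D weight and 4-D NCHW input shapes")
    # weight is (out_channels, in_channels // groups, kH, kW)
    expected = weight_shape[1] * groups
    # input is (N, C, H, W)
    actual = input_shape[1]
    if actual != expected:
        raise ValueError(
            f"conv expects {expected} input channels (weight {tuple(weight_shape)}), "
            f"but batch has {actual} (input {tuple(input_shape)})"
        )


# Shapes copied from the traceback
try:
    check_input_channels((32, 4, 3, 3), (1, 3, 770, 770))
except ValueError as err:
    print(err)
```

In the runner this could be called just before `self.model(*input)` in `_feed_batch`, passing `self.model.encoder._conv_stem.weight.shape` and the batch tensor's shape; that attribute path is inferred from the traceback, not checked against the library source.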