Federated learning using custom model

Hi guys! I am trying to build a federated learning model. In my scenario, I have 3 workers and an orchestrator. The workers start the training and at the end of each training round, the models are being sent to the orchestrator, the orchestrator calculates the federated average and sends back the new model, the workers train on that new model etc. The custom network is an AutoEncoder that I have built from scratch.

Unfortunately I am getting this error message from the workers: RuntimeError: forward() is missing value for argument ‘inputs’. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor) which is weird because my forward function is defined as follows, inside the AE class:

class AutoEncoder(nn.Module):

def __init__(self, code_size):
    super().__init__()
    self.code_size = code_size
    
    # Encoder specification
    self.enc_cnn_1 = nn.Conv2d(3, 10, kernel_size=5)
    self.enc_cnn_2 = nn.Conv2d(10, 20, kernel_size=5)
    self.enc_linear_1 = nn.Linear(53 * 53 * 20, 50)
    self.enc_linear_2 = nn.Linear(50, self.code_size)
    
    # Decoder specification
    self.dec_linear_1 = nn.Linear(self.code_size, 160)
    self.dec_linear_2 = nn.Linear(160, IMAGE_SIZE)
    
def forward(self, images):
    code = self.encode(images)
    out = self.decode(code)
    return out, code

def encode(self, images):
    code = self.enc_cnn_1(images)
    code = F.selu(F.max_pool2d(code, 2))
    
    code = self.enc_cnn_2(code)
    code = F.selu(F.max_pool2d(code, 2))
    code = code.view([images.size(0), -1])
    code = F.selu(self.enc_linear_1(code))

    code = self.enc_linear_2(code)
    return code

def decode(self, code):
    out = F.selu(self.dec_linear_1(code))
    out = F.sigmoid(self.dec_linear_2(out))
    out = out.view([code.size(0), 3, IMAGE_WIDTH, IMAGE_HEIGHT])
    return out

Loss function (cross entropy)

@torch.jit.script
def loss_fn(inputs, outputs):
    return torch.nn.functional.mse_loss(input=inputs, target=outputs)


def set_gradients(model, finetuning):
    """Helper function to exclude all gradients from training
    used for transfer learning in feature extract mode; See: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

    Args:
        model (torch.nn.Module): model object.
        finetuning (bool):  if true, nothing will be changed; transfer learning will be used in finetuning mode, i.e., all gradients are trained;
                            if false, all gradients get excluded from training, used in feature extract mode
    """

    if not finetuning:
        for param in model.parameters():
            param.requires_grad = False
def initialize_model():
    model = AutoEncoder(code_size)
    set_gradients(model, False)
    return model
async def train_model_on_worker(
    worker: websocket_client.WebsocketClientWorker,
    traced_model: torch.jit.ScriptModule,
    dataset_key: str,
    batch_size: int,
    curr_round: int,
    lr: float,
):

    traced_model.train()
    print("train mode on")
    train_config = sy.TrainConfig(
        model=traced_model,
        loss_fn=loss_fn,
        batch_size=batch_size,
        shuffle=True,
        epochs=1,
        optimizer="Adam",
        optimizer_args={"lr": lr}
    )
logger.info(worker.id + " send trainconfig")
    train_config.send(worker)
    print("Model sent to the worker")
    logger.info(worker.id + " start training")
    await worker.async_fit(dataset_key=DATASET_KEY, return_ids=[0])
    logger.info(worker.id + " training done")
    results = dict()
logger.info(worker.id + " get model")
    model = train_config.model_ptr.get().obj

    results["worker_id"] = worker.id
    results["model"] = model

    return results
def validate_model(identifier, model, dataloader, criterion):
model.eval() # changes the mode of the model, in evaluation mode we don't have dropout

    loss = []
    for i, (inputs,_) in enumerat(dataloader):
        print("validation mode on")
        #with torch.set_grad_enabled(False):
        outputs, code  = model(Variable(inputs)) #a tensor with 2 values: one for leak and one for no leak
        loss = criterion(outputs, inputs)
           
        loss = loss.sqrt()
        loss.append(loss.item())
    print("Loss = %.3f" % loss.data)
async def main():
    args = define_and_get_arguments()
    hook = sy.TorchHook(torch) #with this we can override some pytorch methods with pysyft

    # Create WebsocketClientWorkers using IDs, Ports and IP addresses from arguments
    worker_instances = []
    for i in range(len(args.workers) // 3):
        j = i * 3
        worker_instances.append(websocket_client.WebsocketClientWorker(
            id=args.workers[j], port=args.workers[j + 1], host=args.workers[j + 2], hook=hook, verbose=args.verbose))

    model = initialize_model()

    # optional loading of predefined model weights (= dictionary):
    if args.basic_model:
        model.load_state_dict(torch.load(args.basic_model))

    # model serialization (creating an object of type ScriptModule):
    model.eval()
    traced_model = torch.jit.trace(model, torch.rand([1, 3, 224, 224], dtype=torch.float)) #we need to change the form of the model in order to make it 
    #serialisable and send it to the workers

    # Data / picture transformation:
    data_transforms = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor()
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Create validation dataset and dataloader
    validation_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'val'), data_transforms)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Create test dataset and dataloader
    test_dataset = datasets.ImageFolder(os.path.join(args.dataset_path, 'test'), data_transforms)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)

    # Lists to plot loss and acc after training
    train_loss_values = []
    train_acc_values = []
    val_loss_values = []
    val_acc_values = []

    np.set_printoptions(formatter={"float": "{: .0f}".format})

    for curr_round in range(1, args.training_rounds + 1):
        logger.info("Training round %s/%s", curr_round, args.training_rounds)
        print("entered training ")
        # reduce learn rate every 5 training rounds (adaptive learnrate)
        lr = args.lr * (0.1 ** ((curr_round - 1) // 5))
        completed, pending = await asyncio.wait(
            [
                *[
                    train_model_on_worker(
                        worker=worker,
                        traced_model=traced_model,
                        dataset_key=DATASET_KEY,
                        batch_size=args.batch_size,
                        curr_round=curr_round,
                        lr=lr,
                    )
                    for worker in worker_instances
                ]
            ],
            timeout=40
        )
        
        results = []
        for entry in completed:
            print("entry")
            print(entry)
            results.append(entry.result())

        for entry in pending:
            entry.cancel()

        new_worker_instances = []
        for entry in results:
            for worker in worker_instances:
                if (entry["worker_id"] == worker.id):
                    new_worker_instances.append(worker)

        worker_instances = new_worker_instances


        # Setup the loss function
        criterion = torch.nn.functional.mse_loss()
        #optimizer = optimizer_cls(autoencoder.parameters(), lr=lr)

        # Federate models (note that this will also change the model in models[0]
        models = {}
        for worker in results:
            if worker["model"] is not None:
                models[worker["worker_id"]] = worker["model"]

        logger.info("aggregation")
        traced_model = utils.federated_avg(models)
        logger.info("aggregation done")

        # Evaluate federated model
        logger.info("Validate..")
        loss = validate_model("Federated", traced_model, validation_dataloader, criterion)
        logger.info("Validation done")
        val_loss_values.append(loss)
        #val_acc_values.append(acc)
if __name__ == "__main__":
    # Logging setup
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    FORMAT = "%(asctime)s | %(message)s"
    logging.basicConfig(filename='logs/orchestrator_' + date_time + '.log', format=FORMAT)
    logger = logging.getLogger("orchestrator")
    logger.setLevel(level=logging.INFO)

    asyncio.get_event_loop().run_until_complete(main())

The code of the workers:
def load_dataset(dataset_path):
    """Helper function for setting up the local datasets.
    
    Args:
        dataset_path (string):  path to dataset, images must be arranged in this way
                                dataset_path/train/class1/xxx.jpg
                                dataset_path/train/class2/yyy.jpg
    """

    data_transform = transforms.Compose([
        transforms.RandomResizedCrop(INPUT_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    dataset = datasets.ImageFolder(os.path.join(dataset_path, 'train'), data_transform)

    return dataset


def start_websocket_server(id, port, dataset, verbose):
    """Helper function for spinning up a websocket server.
    
    Args:
        id (str or id): the unique id of the worker.
        port (int): the port on which the server should be run.
        dataset: dataset, which the worker should provide.
        verbose (bool): a verbose option - will print all messages sent/received to stdout.
    """

    hook = sy.TorchHook(torch)
    server = WebsocketServerWorker(id=id, host="0.0.0.0", port=port, hook=hook, verbose=verbose)
    server.add_dataset(dataset, key=DATASET_KEY)
    server.start()

    return server

    def _fit(self, model, dataset_key, loss_fn):
        logger = logging.getLogger("worker")
        logger.info(dataset_key)
        print("dataset key")
        model.train()
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=4)
        #data_loader = self._create_data_loader(
            #dataset_key=dataset_key, shuffle=self.train_config.shuffle
        #)
        print("worker")
        print(data_loader)

        loss = None
        iteration_count = 0

        for _ in range(self.train_config.epochs):
            for data in enumerate(data_loader):
                # Set gradients to zero
                self.optimizer.zero_grad()

                # Update model
                output,code = model(data)
                logger.info(data)
                logger.info(output)
                loss = loss_fn(data, output)
                loss.backward()
                self.optimizer.step()

                # Update and check interation count
                iteration_count += 1
                if iteration_count >= self.train_config.max_nr_batches >= 0:
                    break

        return model



if __name__ == "__main__":

    # Parse args
    args = define_and_get_arguments()

    # Logging setup
    date_time = datetime.now().strftime("%m-%d-%Y_%H:%M:%S")
    FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d, p:%(process)d) - %(message)s"
    logging.basicConfig(filename='logs/worker_' + args.id + '_' + date_time + '.log', format=FORMAT)
    logger = logging.getLogger("worker")
    logger.setLevel(level=logging.INFO)

    # Load train dataset
    dataset = load_dataset(args.dataset_path)

    # Start server
    server = start_websocket_server(
        id=args.id,
        port=args.port,
        dataset=dataset,
        verbose=args.verbose,
    )

Does anyone have an idea about what the problem is?

My guess is, that this line of code is raising the issue:

# Setup the loss function
criterion = torch.nn.functional.mse_loss()

Since you are using the functional API, you would have to pass the values directly as:

loss = torch.nn.functional.mse_loss(output, target)

If you want to create the criterion as a module, use this:

criterion = nn.MSELoss()

Let me know, if that helps.

Hi! Thank you very much for your answer. I tried both of the approaches but I am still getting the same error

Are you seeing this error only in the federated learning setup or also in isolation without this utility?

EDIT: Could you post the current stack trace, please, as a small dummy model seems to work using your code (after removing the federated learning code).

Thank you for your help! So the version without the federated learning set up is working perfectly. This is a stack trace I get from the logs of the workers:

ERROR base_events.py(l:1615, p:10310) - Task exception was never retrieved
future: <Task finished coro=<WebsocketServerWorker._producer_handler() done, defined at /home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/workers/websocket_server.py:95> exception=RuntimeError(“forward() is missing value for argument ‘inputs’. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor)”)>
Traceback (most recent call last):
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/workers/websocket_server.py”, line 113, in _producer_handler
response = self._recv_msg(message)
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/workers/websocket_server.py”, line 124, in _recv_msg
return self.recv_msg(message)
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/workers/base.py”, line 292, in recv_msg
response = self._message_routermsg_type
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/workers/base.py”, line 412, in execute_command
response = getattr(_self, command_name)(*args, **kwargs)
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/federated/federated_client.py”, line 87, in fit
return self._fit(model=model, dataset_key=dataset_key, loss_fn=loss_fn)
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/syft/federated/federated_client.py”, line 121, in _fit
loss = loss_fn(target=target, pred=output)
File “/home/niki/miniconda3/envs/orchestrator/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 541, in call
result = self.forward(*input, **kwargs)
RuntimeError: forward() is missing value for argument ‘inputs’. Declaration: forward(ClassType self, Tensor inputs, Tensor outputs) -> (Tensor)

That’s still weird and based on your experiments it seems the error is somehow created in the federated learning setup.
Could you remove the jit.script call and check the code again?

Also, if you are still stuck, could you please link me some tutorial on how to create your federated learning setup, so that I could have a look?