Simulator

This script shows the simplest way to run federated learning with the simulator.

Script

import argparse
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import openfed
from openfed.data import IIDPartitioner, PartitionerDataset

# Command-line interface: the only option is the path of the properties file
# describing this process (presumably generated per process by
# `openfed.tools.simulator` -- confirm against the launcher).
parser = argparse.ArgumentParser(description='Simulator')
parser.add_argument('--props', type=str, default='/tmp/aggregator.json')
args = parser.parse_args()

# `load` returns a collection of properties; this script drives exactly one
# federated role per process, so anything else is a configuration error.
# An explicit exception is used instead of `assert` so that the check is not
# stripped when Python runs with `-O`.
props = openfed.federated.FederatedProperties.load(args.props)
if len(props) != 1:
    raise ValueError(
        f'Expected exactly one federated properties entry in '
        f'{args.props!r}, got {len(props)}.')
props = props[0]

# Model and objective: one linear layer with 784 input features (the
# collaborator loop flattens every sample to 784 values) and 10 outputs,
# trained with cross-entropy.
network = nn.Linear(784, 10)
loss_fn = nn.CrossEntropyLoss()

# The aggregator uses a step size of 1.0 while collaborators use 0.1.
# NOTE(review): presumably 1.0 lets the aggregator apply the averaged update
# unscaled -- confirm against the OpenFed aggregation docs.
learning_rate = 1.0 if props.aggregator else 0.1
sgd = torch.optim.SGD(network.parameters(), lr=learning_rate)
fed_sgd = openfed.optim.FederatedOptimizer(sgd, props.role)

# `keep_vars=True` returns the live parameters instead of detached tensors,
# so the maintainer is handed the very tensors the network trains on.
tracked_state = network.state_dict(keep_vars=True)
maintainer = openfed.core.Maintainer(props, tracked_state)

# Register the functional hooks on this maintainer.
with maintainer:
    openfed.functional.device_alignment()
    if props.aggregator:
        # Every process except the aggregator itself is a collaborator.
        # NOTE(review): assumes `count_step` takes the number of uploads to
        # wait for per round -- confirm.
        expected_uploads = props.address.world_size - 1
        openfed.functional.count_step(expected_uploads)

# Number of federated rounds, shared by the aggregator and every collaborator.
rounds = 10
if maintainer.aggregator:
    # Aggregator side: hand control to the OpenFed API, configured with the
    # maintainer, the federated optimizer, the round count and the averaging
    # aggregation function.
    api = openfed.API(maintainer, fed_sgd, rounds,
                      openfed.functional.average_aggregation)
    api.run()
else:
    # Collaborator side: train locally on one randomly chosen partition of
    # MNIST per round and upload the result.
    total_parts = 100
    mnist = MNIST(r'/tmp/', True, ToTensor(), download=True)
    fed_mnist = PartitionerDataset(
        mnist, total_parts=total_parts, partitioner=IIDPartitioner())

    dataloader = DataLoader(
        fed_mnist, batch_size=10, shuffle=True, num_workers=0, drop_last=False)

    # The loop index doubles as the model version for the round.
    for version in range(rounds):
        # NOTE(review): presumably fetches the current global model without
        # uploading anything -- confirm against `Maintainer.step`.
        maintainer.update_version(version)
        maintainer.step(upload=False)

        # Sample from every partition that was created. The previous
        # hard-coded `randint(0, 9)` could only ever reach the first 10 of
        # the 100 parts.
        part_id = random.randrange(total_parts)
        fed_mnist.set_part_id(part_id)

        network.train()
        losses = []
        for x, y in dataloader:
            # Flatten each sample to the 784 features the linear layer expects.
            output = network(x.view(-1, 784))
            loss = loss_fn(output, y)

            fed_sgd.zero_grad()
            loss.backward()
            fed_sgd.step()
            losses.append(loss.item())
        # Mean training loss of this round, kept for inspection/logging.
        # Guarded so that a partition yielding no batches does not raise
        # ZeroDivisionError.
        mean_loss = sum(losses) / len(losses) if losses else float('nan')

        # Mark the end of the local training round on the optimizer.
        fed_sgd.round()

        # Publish the locally trained model as the next version.
        # NOTE(review): presumably `step(download=False)` only uploads --
        # confirm against `Maintainer.step`.
        maintainer.update_version(version + 1)
        maintainer.package(fed_sgd)
        maintainer.step(download=False)
        fed_sgd.clear_state_dict()

Copy and save this piece of code as examples/run.py.

Run

# Launch 6 processes (1 aggregator and 5 collaborators) to run the simulation.
!python -m openfed.tools.simulator --nproc 6 run.py
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[W ProcessGroupGloo.cpp:559] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
100%|███████████████████████████████████████████| 10/10 [00:01<00:00,  5.90it/s]