
federated-learning's Introduction

Federated Learning

This is a partial reproduction of the paper Communication-Efficient Learning of Deep Networks from Decentralized Data.
So far, only the experiments on MNIST and CIFAR-10 (both IID and non-IID) have been reproduced.

Note: The scripts will be slow without the implementation of parallel computing.




The MLP and CNN models are produced by:


Federated learning with MLP and CNN is produced by:


See the arguments in

For example:

python --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0

--all_clients for averaging over all client models

NB: for CIFAR-10, num_channels must be 3.
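
One way to avoid a channel mismatch is to derive the channel count from the dataset name instead of typing it by hand. The sketch below is illustrative only: the helpers `channels_for` and `check_channels` and the `DATASET_CHANNELS` table are hypothetical and not part of this repository. The values reflect that MNIST images are grayscale (1 channel) and CIFAR-10 images are RGB (3 channels).

```python
# Hypothetical lookup table: dataset name (as passed to --dataset) -> image channels.
DATASET_CHANNELS = {"mnist": 1, "cifar": 3}


def channels_for(dataset):
    """Return the number of image channels for a known dataset name."""
    try:
        return DATASET_CHANNELS[dataset.lower()]
    except KeyError:
        raise ValueError(f"unknown dataset: {dataset!r}") from None


def check_channels(dataset, num_channels):
    """Fail early if num_channels does not match the dataset."""
    expected = channels_for(dataset)
    if num_channels != expected:
        raise ValueError(
            f"{dataset} images have {expected} channel(s), got num_channels={num_channels}"
        )
    return num_channels


if __name__ == "__main__":
    print(channels_for("mnist"), channels_for("cifar"))
```

Calling `check_channels("cifar", 1)` raises a `ValueError` before any model is built, which is easier to read than a shape error raised inside the first convolution.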



Results are shown in Table 1 and Table 2, with the parameters C=0.1, B=10, E=5.

Table 1. Results after 10 epochs of training with a learning rate of 0.01

Model         Acc. of IID    Acc. of Non-IID
FedAVG-MLP    94.57%         70.44%
FedAVG-CNN    96.59%         77.72%

Table 2. Results after 50 epochs of training with a learning rate of 0.01

Model         Acc. of IID    Acc. of Non-IID
FedAVG-MLP    97.21%         93.03%
FedAVG-CNN    98.60%         93.81%


Acknowledgements go to youkaichao.


McMahan, Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas. Communication-Efficient Learning of Deep Networks from Decentralized Data. In Artificial Intelligence and Statistics (AISTATS), 2017.

Cite As

Shaoxiong Ji. (2018, March 30). A PyTorch Implementation of Federated Learning. Zenodo.

federated-learning's People


ax-22, shaoxiongji, xiorcale



federated-learning's Issues

The dataset seems to be in trouble

Hi, when I ran your code locally, the program reported an error while downloading the test dataset. The dataset website cannot be accessed normally.

Max number of clients

What is the max number of clients that can be selected in each round of training using this code?
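
A snippet quoted in a later issue on this page computes the per-round client count as `m = max(int(args.frac * args.num_users), 1)` and samples clients without replacement. Under that reading, the count is bounded by `num_users` and reaches it when `frac` is 1. The following is a minimal sketch of that calculation, not the repository's code; the function names and the example values are hypothetical, and it assumes `0 < frac <= 1`.

```python
import random


def clients_per_round(frac, num_users):
    """Number of clients sampled in one round: at least 1, at most num_users."""
    return max(int(frac * num_users), 1)


def sample_clients(frac, num_users, seed=None):
    """Pick distinct client ids for one round (sampling without replacement)."""
    rng = random.Random(seed)
    m = clients_per_round(frac, num_users)
    return rng.sample(range(num_users), m)


if __name__ == "__main__":
    for frac in (0.001, 0.1, 1.0):
        print(frac, clients_per_round(frac, 100))
```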




Round   0, Average loss 0.133
Round   1, Average loss 0.097
Round   2, Average loss 0.084
Round   3, Average loss 0.063
Round   4, Average loss 0.075
Round   5, Average loss 0.057
Round   6, Average loss 0.041
Round   7, Average loss 0.049
Round   8, Average loss 0.076
Round   9, Average loss 0.056
Training accuracy: 74.83
Testing accuracy: 75.21


Round   0, Average loss 0.128
Round   1, Average loss 0.068
Round   2, Average loss 0.099
Round   3, Average loss 0.060
Round   4, Average loss 0.057
Round   5, Average loss 0.070
Round   6, Average loss 0.069
Round   7, Average loss 0.057
Round   8, Average loss 0.066
Round   9, Average loss 0.049
Training accuracy: 78.18
Testing accuracy: 78.39

cifar transform

Hello, and thanks for your nice code. I think the accuracy can be improved with a new transform for CIFAR:

        from torchvision import datasets, transforms

        trans_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),  # Normalize operates on tensors, so convert first
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trans_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_train)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_test)

Can multiprocessing speed up the training?

First of all, thank you for your contribution.

I don't understand the statement "Note: The scripts will be slow without the implementation of parallel computing."
What does "parallel computing" mean?
As far as I understand from the code below, the local training of each client runs sequentially.

    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        if args.all_clients:
            w_locals[idx] = copy.deepcopy(w)

What do you think about multiprocessing with each process corresponding to each client?

Run time error for (without gpu)

When I was running this code, using the command as you suggested,

python --dataset mnist --model cnn --epochs 50 --gpu -1 --num_channels 1

It raised the following error:

(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
(conv2_drop): Dropout2d(p=0.5)
(fc1): Linear(in_features=320, out_features=50, bias=True)
(fc2): Linear(in_features=50, out_features=10, bias=True)
0%| | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
File "", line 122, in
w, loss = local.update_weights(net=copy.deepcopy(net_glob))
File "C:\Users\lliubb\PycharmProjects\DistributedLearning_LLM\Fed
Avg\", line 50, in update_weights
for batch_idx, (images, labels) in enumerate(self.ldr_train):
File "C:\Users\lliubb\PycharmProjects\Federated-Learning\venv\lib
\site-packages\torch\utils\data\", line 314, in __next
batch = self.collate_fn([self.dataset[i] for i in indices])
File "C:\Users\lliubb\PycharmProjects\Federated-Learning\venv\lib
\site-packages\torch\utils\data\", line 314, in
batch = self.collate_fn([self.dataset[i] for i in indices])
File "C:\Users\lliubb\PycharmProjects\DistributedLearning_LLM\Fed
Avg\", line 21, in getitem
image, label = self.dataset[self.idxs[item]]
File "C:\Users\lliubb\PycharmProjects\Federated-Learning\venv\lib
\site-packages\torchvision\datasets\", line 68, in getite

img, target = self.train_data[index], self.train_labels[index]
IndexError: only integers, slices (:), ellipsis (...), None and
long or byte Variables are valid indices (got numpy.float64)

Can you give me some hints on how to solve this?
I do not have a GPU and I am using Python 3.6 on a Windows system.



split dataset

How did you partition the dataset between clients? Is it done automatically (if so, by which script?) or manually?
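
This page does not say which script performs the partitioning. As a generic illustration of an automatic IID split (not the repository's implementation), the sketch below shuffles the sample indices and deals them out in equal-sized, disjoint shards. The function name `iid_partition` and its parameters are hypothetical.

```python
import random


def iid_partition(num_items, num_users, seed=0):
    """Map each user id to a disjoint, equal-sized list of sample indices."""
    rng = random.Random(seed)
    indices = list(range(num_items))
    rng.shuffle(indices)  # a random order makes every shard an IID sample
    shard = num_items // num_users  # leftover items are dropped if not divisible
    return {u: indices[u * shard:(u + 1) * shard] for u in range(num_users)}


if __name__ == "__main__":
    parts = iid_partition(num_items=100, num_users=10)
    print({u: len(idxs) for u, idxs in parts.items()})
```

A non-IID split would instead group indices by label before dealing them out, so that each client sees only a few classes.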

Getting Runtime Error

When I try to run the code with the following command:
python --dataset mnist --model cnn --epochs 50 --gpu -1
(since I have no gpu)
I get the following error message:

(conv1): Conv2d(3, 10, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
(conv2_drop): Dropout2d(p=0.5)
(fc1): Linear(in_features=320, out_features=50, bias=True)
(fc2): Linear(in_features=50, out_features=10, bias=True)
0%| | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
File "", line 122, in
w, loss = local.update_weights(net=copy.deepcopy(net_glob))
File "/federated-learning-master/FedAvg/", line 55, in update_weights
log_probs = net(images)
File "/miniconda/envs/fedlearn/lib/python3.6/site-packages/torch/nn/modules/", line 357, in call
result = self.forward(*input, **kwargs)
File "/federated-learning-master/FedAvg/", line 38, in forward
x = F.relu(F.max_pool2d(self.conv1(x), 2))
File "/home/santanu/miniconda/envs/fedlearn/lib/python3.6/site-packages/torch/nn/modules/", line 357, in call
result = self.forward(*input, **kwargs)
File "/miniconda/envs/fedlearn/lib/python3.6/site-packages/torch/nn/modules/", line 282, in forward
self.padding, self.dilation, self.groups)
File "/miniconda/envs/fedlearn/lib/python3.6/site-packages/torch/nn/", line 90, in conv2d
return f(input, weight, bias)
RuntimeError: Given groups=1, weight[10, 3, 5, 5], so expected input[10, 1, 28, 28] to have 3 channels, but got 1 channels instead

Any suggestions on how to fix it?

issues of running python --dataset mnist --num_channels 1 --model cnn --epochs 50 --gpu 0

When I tried to run python --dataset mnist --num_channels 1 --model cnn --epochs 50 --gpu 0, I got the following error.

Jians-Air:FedAvg jiansun$ python --dataset mnist --num_channels 1 --model cnn --epochs 50 --gpu 0
Traceback (most recent call last):
File "", line 11, in
from torchvision import datasets, transforms
File "/Library/Python/2.7/site-packages/torchvision/", line 1, in
from torchvision import models


Why does using a larger num_workers in DataLoader make no difference? How can I increase GPU utilization?

Pytorch CrossEntropy function contains softmax

Hi, thanks for your nice code.

However, I found a bug in your code: you apply the CrossEntropy loss after a softmax activation, but PyTorch's CrossEntropy loss takes raw logits as input because it applies log-softmax internally.

After removing the softmax activation, I'm able to improve the MLP from 90% to 95%.
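
To illustrate why the extra softmax hurts, here is a small pure-Python sketch (no PyTorch needed) of cross-entropy defined as log-softmax followed by negative log-likelihood, which is how the traceback elsewhere on this page shows `cross_entropy` being computed. The helper names and the example logits are made up for illustration.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def softmax(scores):
    return [math.exp(v) for v in log_softmax(scores)]


def cross_entropy(scores, target):
    """Cross-entropy on raw scores: log-softmax, then negative log-likelihood."""
    return -log_softmax(scores)[target]


logits = [4.0, 1.0, -2.0]  # a confident, correct prediction for class 0
loss_on_logits = cross_entropy(logits, 0)
loss_on_probs = cross_entropy(softmax(logits), 0)  # softmax applied twice

if __name__ == "__main__":
    print(f"logits as input:        {loss_on_logits:.4f}")
    print(f"probabilities as input: {loss_on_probs:.4f}")
```

In this example the loss computed from probabilities is much larger than the loss computed from logits, even though the prediction is confident and correct. Because probabilities lie in [0, 1], the second softmax flattens them, so the loss cannot approach zero and the gradients shrink.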

fixture 'net_g' not found

When I run the "", an error appears:
============================= test session starts ==============================
platform linux -- Python 3.6.9, pytest-5.3.1, py-1.8.0, pluggy-0.13.1 -- /home/anaconda3/envs/pytorch/bin/python3.6
cachedir: .pytest_cache
rootdir: /home/federated-learning-master
collecting ... collected 1 item ERROR [100%]
test setup failed
file /home/federated-learning-master/, line 19
def test(net_g, data_loader):
E fixture 'net_g' not found

  available fixtures: cache, capfd, capfdbinary, caplog, capsys, capsysbinary, doctest_namespace, monkeypatch, pytestconfig, record_property, record_testsuite_property, record_xml_attribute, recwarn, tmp_path, tmp_path_factory, tmpdir, tmpdir_factory
  use 'pytest --fixtures [testpath]' for help on them.


How can I solve it?


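I have recently been studying federated learning with the help of your code and came across something that confuses me. In the paper, each global epoch randomly selects a fraction of all clients to update (not all clients take part in the update). As I read the paper, aggregation is described over the models of all clients, so the models of clients that did not take part in the update are also averaged. The aggregation step in the code, however, only averages the models of the clients that took part in the update. Is there a problem with the code, or is my understanding wrong?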

for iter in range(args.epochs):
    w_locals, loss_locals = [], []
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    for idx in idxs_users:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
    # update global weights
    w_glob = FedAvg(w_locals)

    # copy weight to net_glob
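
The difference between the two aggregation modes discussed here can be sketched with scalar "weights". This is an illustration under stated assumptions, not the repository's logic: the function `aggregate` is hypothetical, and it assumes that in the all-clients mode every client that did not train this round contributes the current global weight unchanged.

```python
def mean(values):
    return sum(values) / len(values)


def aggregate(global_w, updates, num_users, all_clients):
    """Average client weights.

    updates maps client id -> weight after local training (a scalar here).
    """
    if all_clients:
        # Assumption: non-participants contribute the current global weight.
        per_client = [updates.get(u, global_w) for u in range(num_users)]
        return mean(per_client)
    # Otherwise only the clients that trained this round are averaged.
    return mean(list(updates.values()))


if __name__ == "__main__":
    updates = {0: 1.0, 1: 1.0}  # 2 of 4 clients trained this round
    print(aggregate(0.0, updates, num_users=4, all_clients=False))
    print(aggregate(0.0, updates, num_users=4, all_clients=True))
```

With two of four clients moving from 0.0 to 1.0, averaging only the participants gives 1.0, while averaging over all clients gives 0.5, so the all-clients mode takes a smaller step away from the previous global model.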

Testing accuracy is very low

First, thank you for your code.
I ran it, but the result is not satisfactory.
Training accuracy: 43.00
Testing accuracy: 43.00

my cmd:

python --dataset cifar --num_channels 1 --model cnn --epochs 10 --gpu 0 --iid

look forward to your reply.
best wishes~

About the results of the code

python --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0
Hi, in addition to the command above, how do I run the program to get results for non-IID data?

Runtime error on cuda

bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)

(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
(conv2_drop): Dropout2d(p=0.5)
(fc1): Linear(in_features=320, out_features=50, bias=True)
(fc2): Linear(in_features=50, out_features=10, bias=True)

/opt/conda/lib/python3.6/site-packages/torchvision/datasets/ UserWarning: train_labels has been renamed targets
warnings.warn("train_labels has been renamed targets")
Traceback (most recent call last):
File "", line 113, in
w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
File "/code/models/", line 48, in train
loss = self.loss_func(log_probs, labels)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/", line 489, in call
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/", line 904, in forward
ignore_index=self.ignore_index, reduction=self.reduction)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/", line 1970, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/", line 1790, in nll_loss
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of backend CUDA but got backend CPU for argument 'weight'

I get the above error, only when trying to run it on CUDA.

About the implementation of

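I think the averaging is wrong when the data distribution is non-IID. It should be changed to a sample-weighted average:

    import copy

    def FedAvg(w, dict_len):
        # w: list of client state dicts; dict_len: number of samples per client
        w_avg = copy.deepcopy(w[0])
        for k in w_avg.keys():
            # weight each client's parameters by its sample count
            w_avg[k] = w_avg[k] * dict_len[0]
            for i in range(1, len(w)):
                w_avg[k] += w[i][k] * dict_len[i]
            # normalize by the total number of samples
            w_avg[k] = w_avg[k] / sum(dict_len)
        return w_avg

Here dict_len is a list containing the number of samples held by each client.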
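
As a worked example of the proposal above, the sketch below compares a plain average with a sample-weighted average on scalar parameters. It is a simplified stand-in that uses plain floats instead of tensors; the function names and the numbers are hypothetical.

```python
def weighted_average(states, sizes):
    """Average per-key values, weighting each client by its sample count."""
    total = sum(sizes)
    return {
        k: sum(state[k] * n for state, n in zip(states, sizes)) / total
        for k in states[0]
    }


def plain_average(states):
    """Unweighted average: every client counts equally."""
    return weighted_average(states, [1] * len(states))


states = [{"w": 1.0}, {"w": 3.0}]  # parameters from two clients
sizes = [100, 300]                 # the second client holds three times the data

if __name__ == "__main__":
    print(plain_average(states))            # (1 + 3) / 2
    print(weighted_average(states, sizes))  # (1*100 + 3*300) / 400
```

The plain average gives w = 2.0, while the weighted average gives w = 2.5, pulled toward the client with more samples. When all clients hold the same number of samples, the two coincide.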
