
lg-fedavg's Introduction

Federated Learning with Local and Global Representations

Pytorch implementation for federated learning with local and global representations.

Correspondence to:

Paper

Think Locally, Act Globally: Federated Learning with Local and Global Representations
Paul Pu Liang*, Terrance Liu*, Liu Ziyin, Ruslan Salakhutdinov, and Louis-Philippe Morency
NeurIPS 2019 Workshop on Federated Learning (distinguished student paper award). (*equal contribution)

If you find this repository useful, please cite our paper:

@article{liang2020think,
  title={Think locally, act globally: Federated learning with local and global representations},
  author={Liang, Paul Pu and Liu, Terrance and Ziyin, Liu and Salakhutdinov, Ruslan and Morency, Louis-Philippe},
  journal={arXiv preprint arXiv:2001.01523},
  year={2020}
}

Installation

First check that the requirements are satisfied:
Python 3.6
torch 1.2.0
torchvision 0.4.0
numpy 1.18.1
sklearn 0.20.0
matplotlib 3.1.2
Pillow 4.1.1

The next step is to clone the repository:

git clone https://github.com/pliang279/LG-FedAvg.git

Data

We run FedAvg and LG-FedAvg experiments on MNIST and CIFAR10. See our paper for a description of how we process and partition the data for federated learning experiments.
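As a rough illustration of what a non-IID split controlled by a flag like `--shard_per_user` can look like, the sketch below uses the label-sorted shard scheme common in FedAvg experiments. It is an assumption, not the repository's partitioning code: `shard_partition` is a hypothetical helper, the labels are synthetic, and the paper remains the reference for the actual procedure.

```python
import random
from collections import defaultdict


def shard_partition(labels, num_users, shard_per_user, rng):
    """Assign sample indices to users via label-sorted shards (illustrative only)."""
    # Sort indices by label so that each shard is dominated by one class.
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_users * shard_per_user
    # Samples that do not fill a whole shard are dropped in this sketch.
    shard_size = len(order) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    users = defaultdict(list)
    for s, shard in enumerate(shards):
        # Deal shards round-robin so every user gets shard_per_user of them.
        users[s % num_users].extend(shard)
    return dict(users)


# Synthetic labels: 1000 samples spread evenly over 10 classes.
labels = [i % 10 for i in range(1000)]
parts = shard_partition(labels, num_users=100, shard_per_user=2, rng=random.Random(0))
```

With two shards per user, each user in this toy split sees at most two classes, which is what makes the setting non-IID.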

FedAvg

Results can be reproduced by running the following:

MNIST

python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --results_save run1

CIFAR10

python main_fed.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --results_save run1
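For orientation, one FedAvg round under flags like `--num_users` and `--frac` can be sketched as follows. This is a minimal pure-Python illustration, not the repository's code: `clients_per_round`, `sample_clients`, and `average_weights` are hypothetical names, weights are plain lists of floats rather than tensors, and the unweighted mean is an assumption (FedAvg variants often weight by local dataset size).

```python
import random


def clients_per_round(num_users, frac):
    """Number of clients sampled each round (at least one)."""
    return max(int(frac * num_users), 1)


def sample_clients(num_users, frac, rng):
    """Pick a random subset of client indices without replacement."""
    return rng.sample(range(num_users), clients_per_round(num_users, frac))


def average_weights(client_weights):
    """Element-wise unweighted mean of per-client weight dicts."""
    keys = client_weights[0].keys()
    n = len(client_weights)
    return {
        k: [sum(vals) / n for vals in zip(*(w[k] for w in client_weights))]
        for k in keys
    }


rng = random.Random(0)
selected = sample_clients(num_users=100, frac=0.1, rng=rng)

# Stand-in for local training: each selected client returns its own weights.
local = [{"fc.weight": [float(i), 1.0], "fc.bias": [0.5]} for i in selected]
global_weights = average_weights(local)
```

With `--num_users 100 --frac 0.1`, this samples 10 clients per round; the averaged dict would then be sent back to the clients for the next round.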

LG-FedAvg

Results can be reproduced by first running the above commands for FedAvg and then running the following:

MNIST

python main_lg.py --dataset mnist --model mlp --num_classes 10 --epochs 200 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 3 --results_save run1 --load_fed best_400.pt

CIFAR10

python main_lg.py --dataset cifar10 --model cnn --num_classes 10 --epochs 200 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2 --results_save run1 --load_fed best_1200.pt
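The idea of keeping some parameters local while averaging the rest can be sketched as below. This is a toy illustration under stated assumptions, not the logic of `main_lg.py`: the helper names and the `encoder.*` / `head.*` keys are hypothetical, weights are plain lists of floats, and how `--num_layers_keep` maps to specific layers is not shown here.

```python
def average_global_keys(client_weights, global_keys):
    """Average only the entries named in global_keys across clients."""
    n = len(client_weights)
    return {
        k: [sum(vals) / n for vals in zip(*(w[k] for w in client_weights))]
        for k in global_keys
    }


def apply_global(local_weights, averaged):
    """Overwrite a client's shared entries; local entries stay untouched."""
    merged = dict(local_weights)
    merged.update(averaged)
    return merged


# Two toy clients; "encoder.*" is treated as local and "head.*" as global
# (both names are assumptions made for this example).
clients = [
    {"encoder.weight": [1.0, 2.0], "head.weight": [0.0, 4.0]},
    {"encoder.weight": [5.0, 6.0], "head.weight": [2.0, 0.0]},
]
global_keys = ["head.weight"]

averaged = average_global_keys(clients, global_keys)
updated = [apply_global(w, averaged) for w in clients]
```

Only the entries listed in `global_keys` need to be communicated, which is where the reduction in communicated parameters comes from.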

MTL

Results can be reproduced by running the following:

MNIST

python main_mtl.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 10 --num_layers_keep 5 --results_save run1

CIFAR10

python main_mtl.py --dataset cifar10 --model cnn --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --shard_per_user 2 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 5 --results_save run1

If you use this code, please cite our paper:

@article{liang2019_federated,
  title={Think Locally, Act Globally: Federated Learning with Local and Global Representations},
  author={Paul Pu Liang and Terrance Liu and Ziyin Liu and Ruslan Salakhutdinov and Louis-Philippe Morency},
  journal={ArXiv},
  year={2019},
  volume={abs/2001.01523}
}

Acknowledgements

This codebase was adapted from https://github.com/shaoxiongji/federated-learning.

lg-fedavg's People

Contributors

pliang279, terranceliu


lg-fedavg's Issues

CNNCifar.weight_keys

Nice Work!
But in Net.py CNNCifar.weight_keys, why are fc layers ahead of conv?

        self.weight_keys = [['fc1.weight', 'fc1.bias'],
                            ['fc2.weight', 'fc2.bias'],
                            ['fc3.weight', 'fc3.bias'],
                            ['conv2.weight', 'conv2.bias'],
                            ['conv1.weight', 'conv1.bias'],
                            ]

If I understand correctly, it should be ordered like this, because the convolutional layers run first and the fully connected (feedforward) layers after them:

        self.weight_keys = [['conv1.weight', 'conv1.bias'],
                            ['conv2.weight', 'conv2.bias'],
                            ['fc1.weight', 'fc1.bias'],
                            ['fc2.weight', 'fc2.bias'],
                            ['fc3.weight', 'fc3.bias'],
                            ]

Reproducing the CIFAR-10 non-IID results

Thank you for the open source project. I think this is a very important step in federated learning: it improves model performance while reducing the number of communicated parameters.

But when I run the command from the README:
python main_lgy.py --dataset cifar10 --model CNN --num_classes 10 --epochs 2000 --lr 0.1 --num_users 100 --frac 0.1 --local_ep 1 --local_bs 50 --num_layers_keep 2

I can't reproduce the accuracy reported in the paper. (I did not complete all 2000 rounds.)

Here are some of the results. They do not reach the accuracy of about 89.66 reported for CIFAR-10 in Table 1, and the New Test Acc shows a similar gap.

Round 297, Avg Loss 0.322, Loss (local): 0.338, Acc (local): 84.78, Loss (Avg): 2.29, Acc (Avg): 11.21, Loss (ens) 2.151, Acc: (ens) 24.88,
Round 298, Avg Loss 0.295, Loss (local): 0.338, Acc (local): 84.80, Loss (Avg): 2.29, Acc (Avg): 10.87, Loss (ens) 2.156, Acc: (ens) 24.39,
Round 299, Avg Loss 0.363, Loss (local): 0.348, Acc (local): 84.51, Loss (Avg): 2.29, Acc (Avg): 10.42, Loss (ens) 2.170, Acc: (ens) 23.85,
Round 300, Avg Loss 0.324, Loss (local): 0.336, Acc (local): 84.95, Loss (Avg): 2.29, Acc (Avg): 10.53, Loss (ens) 2.157, Acc: (ens) 24.84,
Round 301, Avg Loss 0.403, Loss (local): 0.338, Acc (local): 85.08, Loss (Avg): 2.29, Acc (Avg): 10.80, Loss (ens) 2.154, Acc: (ens) 24.28,
Round 302, Avg Loss 0.345, Loss (local): 0.338, Acc (local): 85.25, Loss (Avg): 2.29, Acc (Avg): 11.06, Loss (ens) 2.154, Acc: (ens) 24.38,
Round 303, Avg Loss 0.395, Loss (local): 0.340, Acc (local): 85.22, Loss (Avg): 2.29, Acc (Avg): 10.48, Loss (ens) 2.166, Acc: (ens) 23.43,
Round 304, Avg Loss 0.404, Loss (local): 0.337, Acc (local): 85.24, Loss (Avg): 2.29, Acc (Avg): 10.19, Loss (ens) 2.165, Acc: (ens) 23.34,
Round 305, Avg Loss 0.343, Loss (local): 0.332, Acc (local): 85.44, Loss (Avg): 2.29, Acc (Avg): 10.79, Loss (ens) 2.161, Acc: (ens) 23.24,
Round 306, Avg Loss 0.281, Loss (local): 0.331, Acc (local): 85.54, Loss (Avg): 2.29, Acc (Avg): 10.93, Loss (ens) 2.157, Acc: (ens) 23.86,
Round 307, Avg Loss 0.253, Loss (local): 0.332, Acc (local): 85.57, Loss (Avg): 2.29, Acc (Avg): 10.88, Loss (ens) 2.147, Acc: (ens) 24.66,
Round 308, Avg Loss 0.413, Loss (local): 0.330, Acc (local): 85.61, Loss (Avg): 2.29, Acc (Avg): 11.07, Loss (ens) 2.148, Acc: (ens) 24.40,
Round 309, Avg Loss 0.287, Loss (local): 0.333, Acc (local): 85.44, Loss (Avg): 2.29, Acc (Avg): 10.67, Loss (ens) 2.151, Acc: (ens) 24.78,
Round 310, Avg Loss 0.343, Loss (local): 0.332, Acc (local): 85.44, Loss (Avg): 2.29, Acc (Avg): 10.79, Loss (ens) 2.146, Acc: (ens) 23.75,
Round 311, Avg Loss 0.355, Loss (local): 0.331, Acc (local): 85.44, Loss (Avg): 2.29, Acc (Avg): 10.70, Loss (ens) 2.154, Acc: (ens) 23.51,
Round 312, Avg Loss 0.326, Loss (local): 0.333, Acc (local): 85.34, Loss (Avg): 2.29, Acc (Avg): 10.56, Loss (ens) 2.151, Acc: (ens) 24.29,
Round 313, Avg Loss 0.264, Loss (local): 0.333, Acc (local): 85.37, Loss (Avg): 2.29, Acc (Avg): 10.57, Loss (ens) 2.155, Acc: (ens) 23.30,
Round 314, Avg Loss 0.349, Loss (local): 0.334, Acc (local): 85.36, Loss (Avg): 2.29, Acc (Avg): 10.43, Loss (ens) 2.158, Acc: (ens) 22.82,
Round 315, Avg Loss 0.327, Loss (local): 0.333, Acc (local): 85.20, Loss (Avg): 2.29, Acc (Avg): 10.83, Loss (ens) 2.155, Acc: (ens) 23.32,

Is this because I have not run enough rounds, or am I misunderstanding something? I hope someone can help.

Thank you again for your outstanding contribution

Failed to converge when changing num_users and frac

Description

When I change num_users to 10 and frac to 0.3 with --iid, so that 3 clients are chosen each round, the model first improves and then gets worse.

Reproduce

$ python main_fed.py --dataset mnist --model mlp --num_classes 10 --epochs 1000 --lr 0.05 --num_users 10 --shard_per_user 2 --frac 0.3 --local_ep 1 --local_bs 8 --results_save run1 --iid

Output

device: cuda:0
MLP(
  (layer_input): Linear(in_features=784, out_features=512, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (layer_hidden1): Linear(in_features=512, out_features=256, bias=True)
  (layer_hidden2): Linear(in_features=256, out_features=256, bias=True)
  (layer_hidden3): Linear(in_features=256, out_features=128, bias=True)
  (layer_out): Linear(in_features=128, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)
Round 0, lr: 0.050000, [5 6 0]
Round   0, Average loss 2.038, Test loss 1.794, Test accuracy: 67.63
Round 1, lr: 0.050000, [6 4 5]
Round   1, Average loss 1.748, Test loss 1.611, Test accuracy: 85.05
Round 2, lr: 0.050000, [7 9 4]
Round   2, Average loss 1.761, Test loss 1.717, Test accuracy: 74.39
Round 3, lr: 0.050000, [7 4 9]
Round   3, Average loss 1.856, Test loss 1.843, Test accuracy: 61.74
Round 4, lr: 0.050000, [9 2 5]
Round   4, Average loss 1.948, Test loss 1.863, Test accuracy: 59.83
Round 5, lr: 0.050000, [2 6 7]
Round   5, Average loss 2.039, Test loss 1.990, Test accuracy: 47.11
Round 6, lr: 0.050000, [0 7 2]
Round   6, Average loss 2.025, Test loss 1.997, Test accuracy: 46.39
Round 7, lr: 0.050000, [4 3 2]
Round   7, Average loss 2.017, Test loss 2.104, Test accuracy: 35.68
Round 8, lr: 0.050000, [2 9 1]
Round   8, Average loss 2.128, Test loss 2.113, Test accuracy: 34.82
Round 9, lr: 0.050000, [2 7 5]
Round   9, Average loss 2.127, Test loss 2.190, Test accuracy: 27.09
Round 10, lr: 0.050000, [1 9 7]
Round  10, Average loss 2.194, Test loss 2.239, Test accuracy: 22.21
Round 11, lr: 0.050000, [0 2 3]
Round  11, Average loss 2.236, Test loss 2.186, Test accuracy: 27.53
Round 12, lr: 0.050000, [3 9 5]
Round  12, Average loss 2.188, Test loss 2.108, Test accuracy: 35.29
Round 13, lr: 0.050000, [3 6 5]
Round  13, Average loss 2.172, Test loss 2.237, Test accuracy: 22.45
Round 14, lr: 0.050000, [9 8 4]
Round  14, Average loss 2.258, Test loss 2.175, Test accuracy: 28.61
Round 15, lr: 0.050000, [2 7 1]
Round  15, Average loss 2.178, Test loss 2.161, Test accuracy: 29.99
Round 16, lr: 0.050000, [9 6 4]
Round  16, Average loss 2.192, Test loss 2.280, Test accuracy: 18.10
Round 17, lr: 0.050000, [2 4 0]
Round  17, Average loss 2.284, Test loss 2.125, Test accuracy: 33.60
Round 18, lr: 0.050000, [4 1 0]
Round  18, Average loss 2.226, Test loss 2.352, Test accuracy: 10.94
Round 19, lr: 0.050000, [6 0 7]
Round  19, Average loss 2.355, Test loss 2.352, Test accuracy: 10.94
Round 20, lr: 0.050000, [1 8 6]
Round  20, Average loss 2.351, Test loss 2.339, Test accuracy: 12.24
Round 21, lr: 0.050000, [1 2 3]
Round  21, Average loss 2.338, Test loss 2.339, Test accuracy: 12.24
Round 22, lr: 0.050000, [9 3 1]
Round  22, Average loss 2.340, Test loss 2.339, Test accuracy: 12.24
Round 23, lr: 0.050000, [4 2 0]
Round  23, Average loss 2.337, Test loss 2.339, Test accuracy: 12.24
Round 24, lr: 0.050000, [8 1 5]

Question about Eq. 1

Hi, thanks for your code. However, I noticed that there is no implementation of Eq. 1 in your paper. Eq. 1 is leveraged in the training of the local model in LG-FedAvg if I have the right understanding. Could you please show me the code of Eq. 1? Thanks.

Wrong loss function?

I notice that the code uses CrossEntropyLoss for local training:

self.loss_func = nn.CrossEntropyLoss()

And it accepts the log-probabilities as input:

loss = self.loss_func(log_probs, labels)

The output of CNN networks is also logsoftmax:

return F.log_softmax(x, dim=1)

But according to the PyTorch documentation, CrossEntropyLoss already applies log-softmax internally:
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

I think the loss should be calculated via NLLLoss instead if used with input after logsoftmax (https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss).

Or is there any reason why the code uses logsoftmax twice for calculating the loss?
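The effect of applying log-softmax twice can be checked numerically. Log-softmax outputs already exponentiate to a distribution that sums to 1, so a second log-softmax subtracts log(1) = 0 and leaves the values unchanged up to rounding. The pure-Python sketch below illustrates this; the helper names are hypothetical, no PyTorch is involved, and the argument applies only when the model output is a log-softmax (not a plain softmax).

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]


def cross_entropy(inputs, target):
    """Cross-entropy defined as log-softmax followed by NLL."""
    return nll(log_softmax(inputs), target)


logits = [2.0, -1.0, 0.5]
target = 0

log_probs = log_softmax(logits)
once = nll(log_probs, target)             # NLL applied to log-probabilities
twice = cross_entropy(log_probs, target)  # log-softmax applied a second time

difference = abs(once - twice)
```

Under these assumptions the two losses agree to floating-point precision, so the redundant log-softmax would cost extra computation without changing the loss value; using NLLLoss would still state the intent more directly.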
