volo's Introduction

VOLO: Vision Outlooker for Visual Recognition (TPAMI, arXiv)

This is a PyTorch implementation of our paper. We present Vision Outlooker (VOLO), which achieves state-of-the-art (SOTA) performance on ImageNet classification and Cityscapes segmentation without any extra training data.

ImageNet top-1 accuracy comparison with state-of-the-art CNN-based and Transformer-based models. All results are reported at each model's best test resolution. Our VOLO-D5 achieved SOTA performance on ImageNet without extra data as of June 2021.

(Updating... code and models for downstream tasks such as semantic segmentation are coming soon.)

You may also be interested in our new MLP-like model, Vision Permutator, and our Token Labeling training objective for Vision Transformers.

Reference

@article{yuan2022volo,
  title={Volo: Vision outlooker for visual recognition},
  author={Yuan, Li and Hou, Qibin and Jiang, Zihang and Feng, Jiashi and Yan, Shuicheng},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2022},
  publisher={IEEE}
}

1. Requirements

torch>=1.7.0; torchvision>=0.8.0; timm==0.4.5; tlt==0.1.0; pyyaml; apex-amp
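
You can quickly confirm the pinned versions before running anything (a minimal check, nothing repo-specific):

import torch, torchvision, timm

# the repo pins timm==0.4.5 and tlt==0.1.0; newer versions may break the scripts
print(torch.__version__, torchvision.__version__, timm.__version__)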

Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
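
With this layout in place, here is a quick sanity check that the tree is readable (a hedged sketch using torchvision's generic ImageFolder; main.py builds its own loaders):

from torchvision import datasets, transforms

# point ImageFolder at the train/ split laid out as above
train_set = datasets.ImageFolder(
    "/path/to/imagenet/train",
    transform=transforms.Compose([transforms.Resize(256),
                                  transforms.CenterCrop(224),
                                  transforms.ToTensor()]))
print(len(train_set), "images,", len(train_set.classes), "classes")  # expect 1000 classes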

Directory structure in this repo:

│volo/
├──figures/
├──loss/
│  ├── __init__.py
│  ├── cross_entropy.py
├──models/
│  ├── __init__.py
│  ├── volo.py
├──utils/
│  ├── __init__.py
│  ├── utils.py
├──LICENSE
├──README.md
├──distributed_train.sh
├──main.py
├──validate.py

2. VOLO Models

Model         #Params  Image resolution  Top-1 Acc.  Download
volo_d1       27M      224               84.2        here
volo_d1 ↑384  27M      384               85.2        here
volo_d2       59M      224               85.2        here
volo_d2 ↑384  59M      384               86.0        here
volo_d3       86M      224               85.4        here
volo_d3 ↑448  86M      448               86.3        here
volo_d4       193M     224               85.7        here
volo_d4 ↑448  193M     448               86.8        here
volo_d5       296M     224               86.1        here
volo_d5 ↑448  296M     448               87.0        here
volo_d5 ↑512  296M     512               87.1        here

All the pretrained models can also be downloaded from Baidu Yun (password: ttbp).

Usage

Instructions on how to use our pre-trained VOLO models:

from models.volo import *
from utils import load_pretrained_weights 

# create model
model = volo_d1()

# load the pretrained weights
# change num_classes based on dataset; the same weights work for different image
# sizes, as we interpolate the position embedding for different image sizes.
load_pretrained_weights(model, "/path/to/pretrained/weights", use_ema=False, 
                        strict=False, num_classes=1000)  
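
As a quick end-to-end check, here is a minimal inference sketch built on the snippet above. The preprocessing values are the common ImageNet defaults, not necessarily the exact ones used in validate.py, and the image path is a placeholder; in eval mode the model is assumed to return plain class logits.

import torch
from PIL import Image
from torchvision import transforms

from models.volo import volo_d1
from utils import load_pretrained_weights

model = volo_d1()
load_pretrained_weights(model, "/path/to/pretrained/weights", use_ema=False,
                        strict=False, num_classes=1000)
model.eval()

# common ImageNet-style preprocessing (assumed; match your training config)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("/path/to/image.jpg").convert("RGB")).unsqueeze(0)  # [1, 3, 224, 224]
with torch.no_grad():
    logits = model(img)
print(logits.argmax(dim=1))  # predicted ImageNet class index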

We also provide a Colab notebook that runs through the steps to perform inference with VOLO.

3. Validation

To evaluate our VOLO models, run:

python3 validate.py /path/to/imagenet  --model volo_d1 \
  --checkpoint /path/to/checkpoint --no-test-pool --apex-amp --img-size 224 -b 128

Change --img-size from 224 to 384 or 448 for other image resolutions. For example, to evaluate volo_d5 at 512 (87.1 top-1), run:

python3 validate.py /path/to/imagenet  --model volo_d5 \
  --checkpoint /path/to/volo_d5_512 --no-test-pool --apex-amp --img-size 512 -b 32
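
For reference, validate.py drives timm's evaluation loop; a minimal hand-rolled top-1 check on the val/ split could look like the sketch below (standard ImageNet preprocessing assumed, not the script's exact pipeline):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from models.volo import volo_d1
from utils import load_pretrained_weights

model = volo_d1().cuda().eval()
load_pretrained_weights(model, "/path/to/checkpoint", use_ema=False,
                        strict=False, num_classes=1000)

tfm = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224),
                          transforms.ToTensor(),
                          transforms.Normalize([0.485, 0.456, 0.406],
                                               [0.229, 0.224, 0.225])])
loader = DataLoader(datasets.ImageFolder("/path/to/imagenet/val", tfm),
                    batch_size=128, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        preds = model(images.cuda()).argmax(dim=1).cpu()
        correct += (preds == targets).sum().item()
        total += targets.numel()
print("top-1: %.1f%%" % (100.0 * correct / total))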

4. Train

As we use token labeling, please download the token-labeling data from Google Drive or Baidu Yun (password: y6j2); details about token labeling are here.

For each VOLO model, we first train it at image size 224 and then finetune at 384 or 448/512:

Train volo_d1 at 224 and finetune at 384: 8 GPUs, batch_size=1024 (8 × 128), about 19 GB GPU memory per GPU with apex-amp (mixed-precision training).

Train volo_d1 at 224 for 310 epochs (acc = 84.2):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d1 --img-size 224 \
  -b 128 --lr 1.6e-3 --drop-path 0.1 --apex-amp \
  --token-label --token-label-size 14 --token-label-data /path/to/token_label_data

Finetune at 384 for 30 epochs from the 224-pretrained checkpoint; final acc = 85.2 at 384:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d1 --img-size 384 \
  -b 64 --lr 8.0e-6 --min-lr 4.0e-6 --drop-path 0.1 --epochs 30 --apex-amp \
  --weight-decay 1.0e-8 --warmup-epochs 5  --ground-truth \
  --token-label --token-label-size 24 --token-label-data /path/to/token_label_data \
  --finetune /path/to/pretrained_224_volo_d1/

Train volo_d2 at 224 and finetune at 384: 8 GPUs, batch_size=1024, about 27 GB GPU memory per GPU with apex-amp (mixed-precision training).

Train volo_d2 at 224 for 300 epochs (acc = 85.2):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d2 --img-size 224 \
  -b 128 --lr 1.0e-3 --drop-path 0.2 --apex-amp \
  --token-label --token-label-size 14 --token-label-data /path/to/token_label_data

Finetune at 384 for 30 epochs from the 224-pretrained checkpoint; final acc = 86.0 at 384:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d2 --img-size 384 \
  -b 48 --lr 8.0e-6 --min-lr 4.0e-6 --drop-path 0.2 --epochs 30 --apex-amp \
  --weight-decay 1.0e-8 --warmup-epochs 5  --ground-truth \
  --token-label --token-label-size 24 --token-label-data /path/to/token_label_data \
  --finetune /path/to/pretrained_224_volo_d2/

Train volo_d3 at 224 and finetune at 448:

Train volo_d3 at 224 for 300 epochs (acc = 85.4):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d3 --img-size 224 \
  -b 128 --lr 1.0e-3 --drop-path 0.5 --apex-amp \
  --token-label --token-label-size 14 --token-label-data /path/to/token_label_data

Finetune at 448 for 30 epochs from the 224-pretrained checkpoint; final acc = 86.3 at 448:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet \
  --model volo_d3 --img-size 448 \
  -b 30 --lr 8.0e-6 --min-lr 4.0e-6 --drop-path 0.5 --epochs 30 --apex-amp \
  --weight-decay 1.0e-8 --warmup-epochs 5  --ground-truth \
  --token-label --token-label-size 28 --token-label-data /path/to/token_label_data \
  --finetune /path/to/pretrained_224_volo_d3/

5. Acknowledgement

We gratefully acknowledge the support of the NVIDIA AI Tech Center (NVAITC) for this research project, especially the great help with GPU technology from Terry Jianxiong Yin (NVAITC) and Qingyi Tao (NVAITC).

Related projects: T2T-ViT, Token_labeling, pytorch-image-models, official ImageNet example

LICENSE

This repo is under the Apache-2.0 license. For commercial use, please contact the authors.

volo's People

Contributors

houqb, yuanli2333

volo's Issues

Increasing GPU memory in every epoch when running volo-d2 without token labeling.

Hi, thanks for sharing VOLO, nice work.
I used:

export CUDA_VISIBLE_DEVICES=1,4,5,6
python -m torch.distributed.launch --nproc_per_node=4 main.py /path/to/dataset \
  --model volo_d2 --img-size 224 \
  -b 100 --lr 1.0e-3 --drop-path 0.2 --epochs 300 --native-amp \
  --finetune ./d2_224_85.2.pth.tar

GPU memory kept increasing when I trained volo-d2 from the pretrained model without token labeling on my own dataset. I added no tricks, and after about 15 epochs it nearly ran out of memory.

SystemError: returned NULL without setting an error

When I run the VOLO code, I get the following error:

/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [6,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [40,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [1,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Traceback (most recent call last):
  File "main.py", line 951, in <module>
    main()
  File "main.py", line 665, in main
    optimizers=optimizers)
  File "main.py", line 800, in train_one_epoch
    create_graph=second_order)
  File "/opt/conda/lib/python3.6/site-packages/timm/utils/cuda.py", line 22, in __call__
    scaled_loss.backward(create_graph=create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f9d92671420> returned NULL without setting an error

Can you give me some advice to fix it?

semantic segmentation

Hello, thanks very much for sharing the code for your tremendous research!

For semantic segmentation, did you just run evaluation with multiple square tiles to handle the non-square resolution of Cityscapes? Can you share any details, like decoder head architecture?

RuntimeError: The size of tensor a (28) must match the size of tensor b (14) at non-singleton dimension 2

When using the pre-trained model VOLO-D4-448, the following error occurs:
Traceback (most recent call last):
  File "F:/volo-main/main1_all_complete.py", line 416, in <module>
    main()
  File "F:/volo-main/main1_all_complete.py", line 168, in main
    train_loss, train_accuracy = train(train_loader, model, loss_f, optimizer, epoch, args)
  File "F:/volo-main/main1_all_complete.py", line 239, in train
    logits, aux, auxx = model(image)
  File "D:\Python36\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "F:\volo-main\models\volo.py", line 614, in forward
    x = self.forward_tokens(x)
  File "F:\volo-main\models\volo.py", line 579, in forward_tokens
    x = x + self.pos_embed
RuntimeError: The size of tensor a (28) must match the size of tensor b (14) at non-singleton dimension 2
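
For context: load_pretrained_weights (with strict=False) interpolates the position embedding to the model's image size, so loading a checkpoint by hand into a model built for a different resolution can trigger exactly this mismatch. A minimal sketch of the kind of resize involved (function name hypothetical; VOLO's pos_embed has shape [1, H, W, C]):

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, new_hw):
    # bicubically resize a [1, H, W, C] position embedding to new_hw
    pe = pos_embed.permute(0, 3, 1, 2)                        # [1, C, H, W]
    pe = F.interpolate(pe, size=new_hw, mode="bicubic", align_corners=False)
    return pe.permute(0, 2, 3, 1)                             # [1, new_h, new_w, C]

pe = torch.randn(1, 14, 14, 512)
print(resize_pos_embed(pe, (28, 28)).shape)  # torch.Size([1, 28, 28, 512])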

Confusion in class OutlookAttention module

In class OutlookAttention there is self.v = nn.Linear(dim, dim, bias=qkv_bias), and the input x of this class has shape B, H, W, C = x.shape. My question is how v = self.v(x).permute(0, 3, 1, 2)  # B, C, H, W can run without an exception, since the matrix multiplication here is [B, H, W, C] × [dim, dim]. Also, in the original paper, Algorithm 1 implements v_pj = nn.Linear(C, C), but in your code C is replaced with dim. Thanks!
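
For readers puzzled by the same point: nn.Linear broadcasts over all leading dimensions and contracts only the last one, so a [B, H, W, C] input works directly. A minimal check:

import torch
import torch.nn as nn

# nn.Linear multiplies every length-C vector along the last dim by its [C, C] weight,
# leaving the B, H, W dimensions untouched
x = torch.randn(2, 14, 14, 192)       # B, H, W, C
v = nn.Linear(192, 192)(x)            # -> [2, 14, 14, 192]
print(v.permute(0, 3, 1, 2).shape)    # torch.Size([2, 192, 14, 14])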

Some thoughts about volo

Thanks for your great work! After reading the paper, I have a question:
Can I think of VOLO as a "pixel-wise conditional conv" network?

The reasons are:

  • The weighted average and fold operations together in Fig. 2 are actually a conv operation, except that the "conv kernel" is generated by the outlook attention.
  • The outlook attention, i.e. the C -> k**4 operation, can be viewed as generating a "conv kernel" for each of the HxW pixels.

Combining these two points, I think VOLO really is like a "pixel-wise conditional conv" network; see the sketch below.
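
A minimal single-head, stride-1 sketch of this reading (simplified relative to the official models/volo.py, which uses multiple heads, attention scaling, and a stride): the attn projection predicts a k*k × k*k mixing matrix per pixel, which weights the unfolded neighbourhood and is folded back, i.e. a convolution whose kernel is generated per location.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleOutlookAttention(nn.Module):
    def __init__(self, dim, kernel_size=3):
        super().__init__()
        self.k = kernel_size
        self.v = nn.Linear(dim, dim)                # value projection
        self.attn = nn.Linear(dim, kernel_size**4)  # per-pixel "conv kernel" generator
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                        # x: [B, H, W, C]
        B, H, W, C = x.shape
        k = self.k
        v = self.v(x).permute(0, 3, 1, 2)                        # [B, C, H, W]
        v = F.unfold(v, k, padding=k // 2)                       # [B, C*k*k, H*W]
        v = v.reshape(B, C, k * k, H * W).permute(0, 3, 2, 1)    # [B, H*W, k*k, C]
        a = self.attn(x).reshape(B, H * W, k * k, k * k)         # the generated kernels
        a = a.softmax(dim=-1)
        out = a @ v                                              # weighted average: [B, H*W, k*k, C]
        out = out.permute(0, 3, 2, 1).reshape(B, C * k * k, H * W)
        out = F.fold(out, (H, W), k, padding=k // 2)             # fold back: [B, C, H, W]
        return self.proj(out.permute(0, 2, 3, 1))                # [B, H, W, C]

y = SimpleOutlookAttention(192)(torch.randn(2, 14, 14, 192))
print(y.shape)  # torch.Size([2, 14, 14, 192])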

Ablation study and official code problem

First, thanks for your contribution, this paper inspired me a lot.
However, I still have some questions as follows, I hope you can answer:

  1. About the code: I think the output shape of the Unfold operation is not right in your code, even though it raises no error.
  2. About the ablation study: it would be better to compare dynamic convolution with your outlook attention; they are very similar to each other, and the only real difference is how the weights are generated. I am very interested in this.
  3. Following your paper, I modified your code according to my own understanding:
    https://github.com/xingshulicc/Vision-In-Transformer-Model/blob/main/outlook_attention.py.
    I hope you can give me some advice on my code.
    Thanks again.

AttributeError: 'tuple' object has no attribute 'log_softmax'

Traceback (most recent call last):
  File "main.py", line 949, in <module>
    main()
  File "main.py", line 652, in main
    train_metrics = train_one_epoch(epoch,
  File "main.py", line 784, in train_one_epoch
    loss = loss_fn(output, target)
  File "C:\Program Files\Anaconda3\envs\yolov5-v4.0\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Program Files\Anaconda3\envs\yolov5-v4.0\lib\site-packages\torch\nn\modules\loss.py", line 961, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "C:\Program Files\Anaconda3\envs\yolov5-v4.0\lib\site-packages\torch\nn\functional.py", line 2468, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "C:\Program Files\Anaconda3\envs\yolov5-v4.0\lib\site-packages\torch\nn\functional.py", line 1605, in log_softmax
    ret = input.log_softmax(dim)
AttributeError: 'tuple' object has no attribute 'log_softmax'

Colab Notebook doesn't work & gives wrong results.

Hi creators, thanks for making a new SOTA model and for open-sourcing it.

I was trying the Colab notebook and it's throwing an error (screenshot omitted).

After going through the Usage section and adding those params, it gave me different and wrong results compared to the demo Colab notebook (screenshot omitted).

colab

Please add a Google Colab notebook for inference.

Finetune with 512 image size

Hello,

I am finetuning a model at an image size of 512 with --token-label-size 24. Is that label size enough for a 512 image size? Should I use a higher label size? How do I know the correct label size?

Thank you in advance!

UnboundLocalError: local variable 'input' referenced before assignment

When I run the main program with "python main.py ./data", the following error occurs:

D:\Python36\lib\site-packages\torchvision\transforms\transforms.py:258: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  "Argument interpolation should be of type InterpolationMode instead of int. "
Traceback (most recent call last):
  File "main.py", line 948, in <module>
    main()
  File "main.py", line 664, in main
    optimizers=optimizers)
  File "main.py", line 746, in train_one_epoch
    for batch_idx, (input, target) in enumerate(loader):
  File "D:\Python36\lib\site-packages\tlt\data\loader.py", line 105, in __iter__
    yield input, target
UnboundLocalError: local variable 'input' referenced before assignment

Please give me more advice; I look forward to your reply. Thank you very much.

Balanced class weight

Hello,

I was trying to compute "balanced" class weights. I see that there are two arguments:

parser.add_argument('--dense-weight', type=float, default=0.5,
                    help='Token labeling loss multiplier (default: 0.5)')
parser.add_argument('--cls-weight', type=float, default=1.0,
                    help='Cls token prediction loss multiplier (default: 1.0)')

How can I multiply the loss to get balanced class weights?

Thank you in advance
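
For context, here is a hedged sketch of how two such multipliers are typically combined in a token-labeling loss (the exact implementation lives in tlt's loss classes; the names below are illustrative). Per-class "balanced" weights would be passed to the underlying cross-entropy terms, which is separate from these two flags:

import torch.nn.functional as F

def token_label_loss(cls_logits, aux_logits, cls_target, token_target,
                     cls_weight=1.0, dense_weight=0.5, class_weights=None):
    # cls_logits:    [B, num_classes]      class-token prediction
    # aux_logits:    [B, N, num_classes]   per-patch predictions
    # class_weights: optional [num_classes] tensor for balanced weighting
    loss_cls = F.cross_entropy(cls_logits, cls_target, weight=class_weights)
    loss_aux = F.cross_entropy(aux_logits.flatten(0, 1), token_target.flatten(),
                               weight=class_weights)
    return cls_weight * loss_cls + dense_weight * loss_aux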

run_error

Traceback (most recent call last):
  File "main.py", line 960, in <module>
    main()
  File "main.py", line 670, in main
    optimizers=optimizers)
  File "main.py", line 779, in train_one_epoch
    label_size=args.token_label_size)
  File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 90, in mixup_target
    y1 = get_labelmaps_with_coords(target, num_classes, on_value=on_value, off_value=off_value, device=device, label_size=label_size)
  File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 64, in get_labelmaps_with_coords
    num_classes=num_classes, device=device)
  File "/opt/conda/lib/python3.6/site-packages/tlt/data/mixup.py", line 12, in get_featuremaps
    label_maps_topk_sizes[3]], 0, dtype=torch.float32, device=device)
RuntimeError: CUDA error: device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
  what(): NCCL error in: /pytorch/torch/lib/c10d/../c10d/NCCLUtils.hpp:136, unhandled cuda error, NCCL version 2.7.8

pre-trained model file is broken.

I downloaded a pre-trained model via the link in the document, but when I try to use it, it reports an error:

hhw@hhw-A01:~/workspace/DeepLearning/VOLO/pretrained_models$ tar -xf d1_384_85.2.pth.tar
tar: This does not look like a tar archive
tar: Skipping to next header
tar: Exiting with failure status due to previous errors
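
For what it's worth: despite the .pth.tar suffix (a timm naming convention), these checkpoints are plain torch.save files, not tar archives, so they are loaded directly rather than extracted:

import torch

# do not untar the file; load it as a regular PyTorch checkpoint
ckpt = torch.load("d1_384_85.2.pth.tar", map_location="cpu")
print(ckpt.keys() if isinstance(ckpt, dict) else type(ckpt))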

There is a problem when loading the pretrained weights

A problem happens when I load the pretrained weights you provided.

UnpicklingError                           Traceback (most recent call last)
<ipython-input> in <module>
      9 # as we interpolate the position embedding for different image size.
     10 load_pretrained_weights(model, "/home/featurize/work/checkpoints/archive/data.pkl", use_ema=False,
---> 11                         strict=False, num_classes=1000)

/cloud/volo/utils/utils.py in load_pretrained_weights(model, checkpoint_path, use_ema, strict, num_classes)
    140                             num_classes=1000):
    141     '''load pretrained weight for VOLO models'''
--> 142     state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes)
    143     model.load_state_dict(state_dict, strict=strict)
    144

/cloud/volo/utils/utils.py in load_state_dict(checkpoint_path, model, use_ema, num_classes)
     92     if checkpoint_path and os.path.isfile(checkpoint_path):
     93         # checkpoint = torch.load(checkpoint_path, map_location='cpu')
---> 94         checkpoint = torch.load(checkpoint_path)
     95         state_dict_key = 'state_dict'
     96         if isinstance(checkpoint, dict):

/environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    591                 return torch.jit.load(opened_file)
    592             return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
--> 593         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    594
    595

/environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    760                 "functionality.")
    761
--> 762     magic_number = pickle_module.load(f, **pickle_load_args)
    763     if magic_number != MAGIC_NUMBER:
    764         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: A load persistent id instruction was encountered, but no persistent_load function was specified.

size mismatch for pos_embed: copying a param with shape torch.Size([1, 14, 14, 768]) from checkpoint, the shape in current model is torch.Size([1, 14, 14, 512]).

When I use the pre-trained model D5-448, the following error appears:
Traceback (most recent call last):
  File "F:/volo-main/main1_all_complete.py", line 415, in <module>
    main()
  File "F:/volo-main/main1_all_complete.py", line 72, in main
    load_pretrained_weights(model, './path/to/pretrained/weights/d5_448_87.0.pth.tar', use_ema=False, strict=False, num_classes=1000)
  File "F:\volo-main\utils\utils.py", line 142, in load_pretrained_weights
    model.load_state_dict(state_dict, strict=strict)
  File "D:\Python36\lib\site-packages\torch\nn\modules\module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for VOLO:
  size mismatch for pos_embed: copying a param with shape torch.Size([1, 14, 14, 768]) from checkpoint, the shape in current model is torch.Size([1, 14, 14, 512]).
  size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 768]) from checkpoint, the shape in current model is torch.Size([1, 1, 512]).
  size mismatch for patch_embed.conv.0.weight: copying a param with shape torch.Size([128, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
  (... the same kind of size mismatch — checkpoint dims 128/384/768 vs. model dims 64/256/512 — is reported for every remaining parameter in patch_embed and network.0 through network.4 ...)
size mismatch for network.4.0.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.0.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.0.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for network.4.0.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.0.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.0.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.0.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.0.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for network.4.0.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for network.4.0.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.1.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for network.4.1.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.1.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.1.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for network.4.1.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for network.4.1.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.2.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for network.4.2.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.2.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.2.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for network.4.2.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for network.4.2.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.attn.qkv.weight: copying a param with shape torch.Size([2304, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.3.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for network.4.3.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for network.4.3.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for network.4.3.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for network.4.3.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for network.4.3.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.attn.kv.weight: copying a param with shape torch.Size([1536, 768]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
size mismatch for post_network.0.attn.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for post_network.0.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for post_network.0.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.0.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for post_network.0.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for post_network.0.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for post_network.0.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.attn.kv.weight: copying a param with shape torch.Size([1536, 768]) from checkpoint, the shape in current model is torch.Size([1024, 512]).
size mismatch for post_network.1.attn.q.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for post_network.1.attn.proj.weight: copying a param with shape torch.Size([768, 768]) from checkpoint, the shape in current model is torch.Size([512, 512]).
size mismatch for post_network.1.attn.proj.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.norm2.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.norm2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for post_network.1.mlp.fc1.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
size mismatch for post_network.1.mlp.fc1.bias: copying a param with shape torch.Size([3072]) from checkpoint, the shape in current model is torch.Size([1536]).
size mismatch for post_network.1.mlp.fc2.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([512, 1536]).
size mismatch for post_network.1.mlp.fc2.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for aux_head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([1000, 512]).
size mismatch for norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([512]).
size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([1000, 512]).

Token Labeling

Hi, thank you for the paper.

Do you have any numbers for how the networks perform without token labeling, only using MixToken and other augmentations?

AttributeError: 'tuple' object has no attribute 'log_softmax'

I was trying to train VOLO on some custom data using the following command:

!python3 main.py -data_dir "/content/dataset" --dataset "ImageFolder" --train-split "train" --val-split "valid" --num-classes 3 --epochs 100 --batch-size 64

Unfortunately, I keep on getting the following error:

Traceback (most recent call last):
  File "main.py", line 948, in <module>
    main()
  File "main.py", line 664, in main
    optimizers=optimizers)
  File "main.py", line 783, in train_one_epoch
    loss = loss_fn(output, target)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/timm/loss/cross_entropy.py", line 35, in forward
    loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 1768, in log_softmax
    ret = input.log_softmax(dim)
AttributeError: 'tuple' object has no attribute 'log_softmax'

Any advice on how I can go about fixing this or what is causing this error to occur?
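
A minimal sketch of what is likely happening, assuming the behavior of models/volo.py: in training mode VOLO's forward returns a tuple (class logits plus auxiliary token-labeling outputs) rather than a single tensor, which the plain timm cross-entropy loss cannot consume. Unpacking the class logits first avoids the error:

import torch
from models.volo import volo_d1

model = volo_d1(num_classes=3)
model.train()
x = torch.randn(2, 3, 224, 224)

out = model(x)
# in training mode the forward may return (class_logits, aux_logits, ...) as a tuple
logits = out[0] if isinstance(out, tuple) else out
print(logits.shape)  # expected: torch.Size([2, 3])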

The code for Semantic Segmentation?

Dear authors:
Thanks for your wonderful work! The results on the semantic segmentation task look nice; could the code and configs for semantic segmentation be published? Thanks!

Could you please release an ablation study that compares LV-ViT with / without Outlook Attention

This is great work that proposes a new form of attention. However, I want to gauge its benefit when everything is compared under the same conditions.

Could you please release an ablation study that compares Outlook Attention and self-attention under the same training policy, hyperparameters (network width, depth), and architecture (for example, ViT or LV-ViT)?

That way we can better understand the effectiveness of Outlook Attention.

The intuition behind the Outlook Attention generation does not seem to make sense

Hi,

Thanks for your work.

As you claimed, the generated $W_A$ can act as the weights that aggregate the local context.

However, $W_A$ is generated by a Linear layer along the channel dimension, which means its receptive field is 1: the neighboring context cannot be perceived while $W_A$ is being generated.

So how can $W_A$ encode the relationship information?
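
For concreteness, here is a single-head, stride-1 sketch of the mechanism under discussion (layer names and sizes are illustrative, not the repo's exact code): the weights do come from a receptive-field-1 Linear on each center token, but they are applied to the unfolded K x K neighborhood, so local context enters through the values rather than through the weight generation.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, H, W, C, K = 1, 4, 4, 8, 3
x = torch.randn(B, H, W, C)

attn_gen = nn.Linear(C, K ** 4)   # receptive field 1, as the issue notes
v_proj = nn.Linear(C, C)

# per-position attention weights over the K*K neighborhood
attn = attn_gen(x).reshape(B, H * W, K * K, K * K).softmax(dim=-1)

# gather each position's K*K neighbors as values
v = F.unfold(v_proj(x).permute(0, 3, 1, 2), kernel_size=K, padding=1)
v = v.reshape(B, C, K * K, H * W).permute(0, 3, 2, 1)       # B, H*W, K*K, C

out = attn @ v                                              # B, H*W, K*K, C
out = F.fold(out.permute(0, 3, 2, 1).reshape(B, C * K * K, H * W),
             output_size=(H, W), kernel_size=K, padding=1)  # B, C, H, W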

Comparison with DynamicConv

Hi,

Thanks for your work.

What's the main difference between VOLO and DynamicConv?

Though convolution is not explicitly used, convolution is equivalent to Unfold + matrix multiplication + Fold (or a view to the output shape).

An example is provided here: https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html
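
To make the cited equivalence concrete, here is a hedged sketch (shapes chosen arbitrarily) checking a plain convolution against Unfold followed by a matrix multiplication:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)                     # out_ch=4, in_ch=3, 3x3 kernel

ref = F.conv2d(x, w, padding=1)

cols = F.unfold(x, kernel_size=3, padding=1)    # (1, 3*3*3, 64)
out = (w.view(4, -1) @ cols).view(1, 4, 8, 8)   # matmul, then view to output shape

assert torch.allclose(ref, out, atol=1e-4)      # identical up to float rounding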

If there is no clear difference, I personally think the claim in the abstract that attention-based models are indeed able to outperform CNNs is not accurate. VOLO is more like a hybrid model based on attention and (strengthened) convolution.

When training on my own dataset, an error occurs after changing num_classes to match my categories; with the default value it also reports an error

AMP not enabled. Training in float32.
Using native Torch DistributedDataParallel.
Scheduled epochs: 310
/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:312: operator(): block: [0,0,0], thread: [15,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
Traceback (most recent call last):
  File "main.py", line 948, in <module>
    main()
  File "main.py", line 664, in main
    optimizers=optimizers)
  File "main.py", line 782, in train_one_epoch
    output = model(input)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 610, in forward
    self._sync_params()
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 1048, in _sync_params
    authoritative_rank,
  File "/root/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 979, in _distributed_broadcast_coalesced
    self.process_group, tensors, buffer_size, authoritative_rank
RuntimeError: CUDA error: device-side assert triggered
terminate called after throwing an instance of 'std::runtime_error'
  what(): NCCL error in: /pytorch/torch/lib/c10d/../c10d/NCCLUtils.hpp:136, unhandled cuda error, NCCL version 2.7.8
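
A hedged note rather than a confirmed diagnosis: this particular scatter/gather assertion usually fires when a target label falls outside [0, num_classes). A quick sanity check over the dataset labels before launching distributed training:

import torch

# illustrative labels; substitute the labels from your own dataset
labels = torch.tensor([0, 1, 2, 5])
num_classes = 3

bad = (labels < 0) | (labels >= num_classes)
print(labels[bad])  # any value printed here would trigger the device-side assert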

batch-size error

I am getting this error:
Training with a single process on 1 GPUs.
WARNING: Neither APEX or native Torch AMP is available, using float32. Install NVIDA apex or upgrade to PyTorch 1.6
Model volo_d1 created, param count: 25865120
Data processing configuration for current model + dataset:
input_size: (3, 224, 224)
interpolation: bicubic
mean: (0.485, 0.456, 0.406)
std: (0.229, 0.224, 0.225)
crop_pct: 0.96
AMP not enabled. Training in float32.
Scheduled epochs: 15
Traceback (most recent call last):
  File "main.py", line 948, in <module>
    main()
  File "main.py", line 502, in main
    batch_size=args.batch_size)
  File "/usr/local/lib/python3.7/dist-packages/timm/data/dataset_factory.py", line 29, in create_dataset
    ds = ImageDataset(root, parser=name, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/timm/data/dataset.py", line 31, in __init__
    parser = create_parser(parser or '', root=root, class_map=class_map)
  File "/usr/local/lib/python3.7/dist-packages/timm/data/parsers/parser_factory.py", line 22, in create_parser
    assert os.path.exists(root)
AssertionError
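
Despite the issue title, the traceback points at the dataset path rather than the batch size: timm asserts that the data directory exists before building the dataset. A minimal check, with the path as a placeholder assumption:

import os

root = "/path/to/imagenet"   # the data directory passed to main.py
print(os.path.exists(root))  # must print True, or create_dataset raises AssertionError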

Eq. 5 and the `fold` operation in the paper do not seem to be consistent

Eq. 5:

[Eq. 5 from the paper, shown as an image in the original issue]

In my opinion, this equation calculates the sum of the features in the neighborhood corresponding to (i, j).

But in the code:

volo/models/volo.py

Lines 94 to 95 in 1f67923

x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size,
           padding=self.padding, stride=self.stride)

F.fold(x, output_size=(H, W), ...) implements a different operation.

Let's start with a simple example.

a = torch.randperm(18).float().reshape(1, 2, 3, 3)  # B=1, C=2, H=W=3

a
Out[33]: 
tensor([[[[ 7., 14.,  0.],  # the first channel
          [ 8., 15., 16.],
          [ 6.,  4., 12.]],
         [[ 2., 17., 13.],  # the second channel
          [10.,  9., 11.],
          [ 5.,  3.,  1.]]]])

Collect the adjacent features in the local 3x3 region and stack the nine values along the channel dimension.

unfold_a = F.unfold(a, kernel_size=3, padding=1)  # 1 x C*K*K x H*W

Each column indexes a position in the H*W space; there are H*W = 9 columns.

Each row indexes one of the K*K neighbor positions within a channel, giving C*K*K = 18 rows.

Note that the rows are channel-major: the K*K kernel positions of the first channel are stacked first, and then those of the second channel follow in sequence.

For example, the fourth row [ 0., 7., 14., 0., 8., 15., 0., 6., 4.] collects the fourth neighbor of each window from the first channel plane.

Specifically, its first element 0 is the fourth value in the flattened K*K window ([[0, 0, 0], [0, 7, 14], [0, 8, 15]]) located at the upper-left corner of the padded a.

unfold_a
Out[35]: 
tensor([[
         # the first channel
         [ 0.,  0.,  0.,  0.,  7., 14.,  0.,  8., 15.],  
         [ 0.,  0.,  0.,  7., 14.,  0.,  8., 15., 16.],         
         [ 0.,  0.,  0., 14.,  0.,  0., 15., 16.,  0.],
         [ 0.,  7., 14.,  0.,  8., 15.,  0.,  6.,  4.],
         [ 7., 14.,  0.,  8., 15., 16.,  6.,  4., 12.],
         [14.,  0.,  0., 15., 16.,  0.,  4., 12.,  0.],
         [ 0.,  8., 15.,  0.,  6.,  4.,  0.,  0.,  0.],
         [ 8., 15., 16.,  6.,  4., 12.,  0.,  0.,  0.],
         [15., 16.,  0.,  4., 12.,  0.,  0.,  0.,  0.],
         # the second channel
         [ 0.,  0.,  0.,  0.,  2., 17.,  0., 10.,  9.],
         [ 0.,  0.,  0.,  2., 17., 13., 10.,  9., 11.],
         [ 0.,  0.,  0., 17., 13.,  0.,  9., 11.,  0.],
         [ 0.,  2., 17.,  0., 10.,  9.,  0.,  5.,  3.],
         [ 2., 17., 13., 10.,  9., 11.,  5.,  3.,  1.],
         [17., 13.,  0.,  9., 11.,  0.,  3.,  1.,  0.],
         [ 0., 10.,  9.,  0.,  5.,  3.,  0.,  0.,  0.],
         [10.,  9., 11.,  5.,  3.,  1.,  0.,  0.,  0.],
         [ 9., 11.,  0.,  3.,  1.,  0.,  0.,  0.,  0.]]])

For reference, here is the original data again:

a
Out[38]: 
tensor([[[[ 7., 14.,  0.],
          [ 8., 15., 16.],
          [ 6.,  4., 12.]],
         [[ 2., 17., 13.],
          [10.,  9., 11.],
          [ 5.,  3.,  1.]]]])

Folding back, as below, gives a result inconsistent with the equation in the paper:

fold_a = F.fold(unfold_a, output_size=(3, 3), kernel_size=3, padding=1)

fold_a
Out[40]: 
tensor([[[[ 28.,  84.,   0.],
          [ 48., 135.,  96.],
          [ 24.,  24.,  48.]],
         [[  8., 102.,  52.],
          [ 60.,  81.,  66.],
          [ 20.,  18.,   4.]]]])

For simplicity, we only look at the first channel plane.

[[ 28.,  84.,   0.],
  [ 48., 135.,  96.],
  [ 24.,  24.,  48.]]

Comparing with unfold_a above, the value of each element here is the accumulation of the different windows after their entries are returned to their original positions in the padded a:

  1. First, the values of the different windows are placed back at their original positions in sequence.
  2. Second, the values landing on the same element position are accumulated. For example, fold_a[0, 0, 0, 1] = 84 = unfold_a[0, 5, 0] + unfold_a[0, 4, 1] + unfold_a[0, 3, 2] + unfold_a[0, 2, 3] + unfold_a[0, 1, 4] + unfold_a[0, 0, 5]. These values come from different windows, not from the "same window" expressed in Eq. 5 of the paper.
  3. Third, the padding is removed.
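
The accumulation behavior described above can be confirmed directly: folding an unfold of all-ones yields the per-pixel window counts, and fold(unfold(x)) equals x scaled by those counts (a standalone sketch, independent of the repo):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 3, 3)
u = F.unfold(x, kernel_size=3, padding=1)
f = F.fold(u, output_size=(3, 3), kernel_size=3, padding=1)

# counts = how many sliding windows cover each pixel
ones = torch.ones_like(x)
counts = F.fold(F.unfold(ones, kernel_size=3, padding=1),
                output_size=(3, 3), kernel_size=3, padding=1)
print(counts[0, 0])                   # [[4., 6., 4.], [6., 9., 6.], [4., 6., 4.]]

assert torch.allclose(f, x * counts)  # fold sums across windows, not within one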

outlook forward pass example

Would you be able to write an example with dummy data for the OutlookAttention module's forward pass? I am trying to practice and understand each step that is taken, but I cannot reproduce the forward pass from the code in models/volo.py. Any help would be much appreciated.

Something like this, but for outlook attention instead of a transformer:

import torch
import torch.nn as nn

tm = nn.Transformer(nhead=16, num_encoder_layers=12)
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
out = tm(src, tgt)
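
A hedged sketch of such a dummy forward pass, assuming OutlookAttention in models/volo.py accepts channels-last input of shape (B, H, W, C) and that dim=192, num_heads=6 form a valid configuration:

import torch
from models.volo import OutlookAttention

attn = OutlookAttention(dim=192, num_heads=6, kernel_size=3, padding=1)
x = torch.rand(2, 28, 28, 192)    # B, H, W, C (channels-last)
y = attn(x)
print(y.shape)                    # expected: torch.Size([2, 28, 28, 192])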

volo-d1 training without token label data

Hi,

Congratulations on your excellent work and many thanks for making the code public. I have trained a model using the base settings and no token labels:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 /path/to/imagenet --model volo_d1 --img-size 224 -b 128 --lr 1.6e-3 --drop-path 0.1 --apex-amp

which reached a best accuracy of 81.72% after 310 epochs. I believe the expected best accuracy should be about 83.8%, which is quite a bit higher than what I am getting at the moment.

Can you see any issue with the command used to train the model? Any help would be really appreciated.

Best,
Michael

Question about computational complexity formulation of Outlooker Attention

[Eq. 8 from the paper, shown as an image in the original issue]

Greetings! Thanks for all your inspiring and excellent VOLO work!
While reading the paper, I have trouble comprehending formulation (8), which gives the computational complexity of Outlook Attention.
I tried to infer the cost from the PyTorch-like code provided above, but I cannot arrive at formulation (8).
Would you mind providing any insight into the calculation? Thanks a lot.

AttributeError: 'tuple' object has no attribute 'max'

When I use the pre-trained model for training, I find that the output becomes a tuple after passing through the model, so the model cannot continue training.

model = volo_d2()
load_pretrained_weights(model, './path/to/pretrained/weights/d2_224_85.2.pth.tar',
                        use_ema=False, strict=False, num_classes=2)
print(model)
......

def train(train_loader, model, criterion, optimizer, epoch, args, scheduler=None):
    print("train--------------")
    avg_loss = 0
    avg_acc = 0
    model.train()
    for batch_idx, (image, target) in enumerate(train_loader):
        # measure data loading time
        image, target = Variable(image.cuda()), Variable(target.cuda())
        print(type(image))
        image = image.cuda()
        target = target.cuda()
        optimizer.zero_grad()
        logits = model(image)

        # m = [t.cpu().numpy() for t in logits]
        # m = [o.cpu().detach() for o in m]
        # logits = torch.tensor(m)
        # logits = torch.tensor([item.cpu().detach().numpy() for item in logits]).cuda()

        print(type(logits))

        preds = logits.max(1, keepdim=True)[1]  # get the index of the max log-probability

AttributeError: 'tuple' object has no attribute 'max'
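
This looks like the same root cause as the log_softmax issue above. A hedged sketch, assuming the model factories forward keyword arguments to the VOLO constructor as in models/volo.py: either unpack the first tuple element, or disable the dense/MixToken outputs so the model returns a single tensor:

from models.volo import volo_d2

# assumption: volo_d2(**kwargs) forwards these flags to the VOLO constructor
model = volo_d2(return_dense=False, mix_token=False, num_classes=2)
# model(image) now yields a single logits tensor, so logits.max(1, keepdim=True)
# works; with the default flags, use logits = model(image)[0] instead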
