tsb0601 / emp-ssl
This repository contains the implementation for the paper "EMP-SSL: Towards Self-Supervised Learning in One Training Epoch."
Hi,
Congratulations on such wonderful work in the SSL domain.
My question concerns the experiment configurations. It seems that fixed patch sizes were used in the comparison and ablation experiments, which lets the model "see" image information far more times than image-based augmentation methods do. It would be reasonable to reduce the patch size as the patch count increases, i.e. to keep the pixel count of a single forward pass constant whether augmented images or patches are the input.
Thanks for your nice ideas.
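The constant-pixel-budget suggestion above can be sketched numerically. All numbers and the helper name below are illustrative assumptions, not values from the paper:

```python
import math

def patch_side(base_side, base_patches, num_patches):
    """Side length (in pixels) that keeps num_patches * side**2,
    the total pixel count of one forward pass, constant."""
    total = base_patches * base_side ** 2
    return int(round(math.sqrt(total / num_patches)))

# Quadrupling the patch count halves the side length:
# 16 patches of 32x32 carry as many pixels as 64 patches of 16x16.
print(patch_side(32, 16, 64))
```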
I just tested your code, and it turns out that if you use
.chunk(num_patches, dim=0)
you are actually splitting different samples into one group and encouraging them to be similar. What you should do is set dim=1 to group the augmented views of one sample.
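For anyone checking this, which groups come out of a chunk depends entirely on how the patch embeddings were concatenated. Here is a minimal sketch; the patch-major layout is an assumption about the code, and NumPy's `split` is used because it slices along an axis the same way `torch.chunk` does:

```python
import numpy as np

num_patches, batch, dim = 2, 3, 4
# Assumed layout: embeddings concatenated patch-major, i.e. the rows are
# [patch 0 of samples 0..2, then patch 1 of samples 0..2].
z = np.arange(num_patches * batch * dim).reshape(num_patches * batch, dim)

# Splitting along axis 0 groups consecutive rows: each group holds
# patch i of *every* sample, not the views of a single sample.
groups_dim0 = np.split(z, num_patches, axis=0)

# To group all augmented views of one sample instead, reshape and
# reorder so the sample index comes first:
per_sample = z.reshape(num_patches, batch, dim).transpose(1, 0, 2)
# per_sample[j] holds all patches of sample j, shape (num_patches, dim).
```

Whether `dim=0` is a bug therefore depends on which grouping the loss needs at that call site.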
Thank you for the code! I am now researching image representation learning without labels, so this model might be very useful. I have trained it for 200 epochs.
My issue: when the net is in test mode, I don't understand why patches are still needed. Can I set patches=1, which means using only the original image? I tried that, but top-1 accuracy decreased to 88%. Stranger still, when I only use Normalize and ToTensor, top-1 accuracy drops to 71%.
Hi
It's amazing work and I really appreciate your paper and open-source code.
However, when I tried to train the model on my server, I encountered an issue.
File "/home//research/empssl/main.py", line 162, in main
loss_TCR = cal_TCR(z_proj, criterion, num_patches)
File "/home//research/empssl/main.py", line 100, in cal_TCR
loss += criterion(z_list[i])
File "/home//software/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home//research/empssl/loss.py", line 90, in forward
return - self.compute_discrimn_loss(X.T)
File "/home//research/empssl/loss.py", line 86, in compute_discrimn_loss
logdet = torch.logdet(I + scalar * W.matmul(W.T))
RuntimeError
(…some final rows of the error output:)
extern "C"
__launch_bounds__(512, 4)
__global__ void reduction_prod_kernel(ReduceJitOp r){
r.run();
}
nvrtc: error: invalid value for --gpu-architecture (-arch)
Do you have any suggestions for it? Maybe I will try it on Google Colab.
#5 Hello, thanks for your contribution! But I met a problem when trying to run your code. Could you please help me solve it? Thanks!
0it [02:28, ?it/s]
Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/tqdm/std.py", line 1178, in __iter__
for obj in iterable:
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
data = self._next_data()
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
return self._process_data(data)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
data.reraise()
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/datasets/cifar.py", line 120, in __getitem__
img = self.transform(img)
File "/home/ubuntu/Ferry/PycharmExperiment/EMP-SSL_main/dataset/aug.py", line 136, in __call__
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
File "/home/ubuntu/Ferry/PycharmExperiment/EMP-SSL_main/dataset/aug.py", line 136, in <listcomp>
augmented_x = [aug_transform(x) for i in range(self.num_patch)]
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 67, in __call__
img = t(img)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 823, in forward
i, j, h, w = self.get_params(img, self.scale, self.ratio)
File "/home/ubuntu/anaconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 787, in get_params
log_ratio = torch.log(torch.tensor(ratio))
RuntimeError: log_vml_cpu not implemented for 'Long'
python-BaseException
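The last frame of the traceback shows torchvision's RandomResizedCrop calling `torch.log(torch.tensor(ratio))`, and in older PyTorch builds `torch.log` is not implemented for integer ('Long') tensors. So this error typically means the `scale` or `ratio` tuple reaching the transform contains integers. A minimal sketch of the workaround (the tuple values here are hypothetical; upgrading torch/torchvision also resolves it):

```python
# Hypothetical integer tuples that would trigger
# "log_vml_cpu not implemented for 'Long'" in older PyTorch:
scale = (1, 1)
ratio = (1, 1)

# Fix: coerce to floats before handing them to the transform,
# e.g. transforms.RandomResizedCrop(32, scale=scale, ratio=ratio).
scale = tuple(float(s) for s in scale)
ratio = tuple(float(r) for r in ratio)

print(scale, ratio)  # both are now float tuples
```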
Thank you so much for providing such a valuable perspective. I have some doubts about the paper; could you please answer the following?
In Table III in Section 3.2, I see that, e.g., BYOL is not much different in wall-clock time from your method in the last row. What is the significance of reducing the number of epochs but not the time?
The second point: adding multiple positive crops per image in the data augmentation phase increases the hardware requirements. Have you considered any solution in this regard?
I am really confused by your claim of reducing training epochs by two orders of magnitude. Since you crop one image into 200 same-size patches, isn't that equivalent to very heavy augmentation? Even with far fewer epochs, the actual computational cost / GPU-hours do not decrease significantly compared to other methods (which is also reflected in Table 3). Did I miss something, or is that what you actually did?
Why are labels used in main.py? SSL should not need labels.