the-ai-summer / self-attention-cv
Implementation of various self-attention mechanisms focused on computer vision. Ongoing repository.
Home Page: https://theaisummer.com/
License: MIT License
Thank you very much for the code. However, when I run test_TransUnet.py, it raises the error below. Why is that?
```
Traceback (most recent call last):
  File "self-attention-cv/tests/test_TransUnet.py", line 14, in <module>
    test_TransUnet()
  File "/self-attention-cv/tests/test_TransUnet.py", line 11, in test_TransUnet
    y = model(a)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "self-attention-cv\self_attention_cv\transunet\trans_unet.py", line 88, in forward
    y = self.project_patches_back(y)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\dell.conda\envs\myenv\lib\site-packages\torch\nn\functional.py", line 1692, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

Process finished with exit code 1
```
Could you please help me solve it? Thank you.
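The last frames of the traceback come from `F.linear`, which contracts only the last dimension of its input, so that dimension must equal the layer's `in_features`. The pure-Python sketch below mimics that rule so the mismatch can be reasoned about without running the model; the helper name and the example sizes are hypothetical and are not taken from `trans_unet.py`.

```python
from typing import Tuple


def linear_output_shape(input_shape: Tuple[int, ...], in_features: int,
                        out_features: int) -> Tuple[int, ...]:
    """Return the output shape of y = x @ W.T + b for x of `input_shape`.

    Mirrors the rule nn.Linear enforces: only the last dimension is
    contracted, so it must equal in_features; leading dims pass through.
    """
    if not input_shape:
        raise ValueError("input must have at least one dimension")
    if input_shape[-1] != in_features:
        raise ValueError(
            f"last input dim {input_shape[-1]} != in_features {in_features}"
        )
    return tuple(input_shape[:-1]) + (out_features,)


# Hypothetical sizes: 64 tokens of width 512 projected to 256 channels.
print(linear_output_shape((2, 64, 512), in_features=512, out_features=256))
# A token width that disagrees with in_features fails the same way F.linear does.
try:
    linear_output_shape((2, 64, 768), in_features=512, out_features=256)
except ValueError as err:
    print("mismatch:", err)
```

In practice, printing `y.shape` right before `self.project_patches_back(y)` and comparing its last dimension with that layer's `in_features` shows which side of the mismatch is wrong.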
Thanks for the great work! I noticed a nice implementation of this paper (https://arxiv.org/abs/2103.10504) here:
https://github.com/tamasino52/UNETR/blob/main/unetr.py
It would be great if this could also be included in your repo, since it comes with lots of other great features, so we can explore more.
Thanks ~
Hi,
Can you please explain why patch_dim is set to 1 in the TransUNet class? Thank you in advance!
Looking forward to your reply
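For context, the TransUNet paper describes its hybrid encoder as applying the patch embedding to 1×1 patches extracted from a CNN feature map rather than to raw image patches. With `patch_dim=1`, every spatial position of that feature map becomes one token whose width equals the channel count. The sketch below computes token count and flattened token width for a given patch size; the feature-map size and channel count are illustrative assumptions, not values read from the repo.

```python
def patch_tokens(height: int, width: int, channels: int, patch_dim: int):
    """Return (num_tokens, token_width) for non-overlapping square patches.

    Each patch of patch_dim x patch_dim positions is flattened across all
    channels, so one token holds channels * patch_dim**2 values.
    """
    if height % patch_dim or width % patch_dim:
        raise ValueError("feature map must be divisible by patch_dim")
    num_tokens = (height // patch_dim) * (width // patch_dim)
    token_width = channels * patch_dim * patch_dim
    return num_tokens, token_width


# Illustrative: a 224x224 image downsampled 16x by a CNN gives a 14x14 map.
print(patch_tokens(14, 14, channels=1024, patch_dim=1))   # (196, 1024)
print(patch_tokens(14, 14, channels=1024, patch_dim=2))   # (49, 4096)
```

One plausible reading of the design: the CNN has already downsampled the image, so 1×1 patches on a 14×14 map give the same 196 tokens a plain ViT gets from 16×16 pixel patches on a 224×224 image, while a larger `patch_dim` would shrink the token grid further.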
Thank you for your work and the clear explanations. As you know, ViT does not perform well when trained on small datasets, so I am combining ResNet34 with Pyramid Vision Transformer v2 (PVT v2) to improve on it. The architectures of ViT and PVT v2 are completely different. Could you please help me implement it?
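One way to pair the two backbones (an assumption, not something this repo provides) is to fuse their features stage by stage, which only works if the four stages have matching spatial resolutions. The sketch below checks that, assuming PVT v2's overlapping patch embedding uses kernel 2S-1, padding S-1 and stride S (S=4 for stage 1, S=2 afterwards) and the standard ResNet34 stem (7×7 stride-2 conv, 3×3 stride-2 max-pool).

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution or max-pool (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def pvt_v2_stage_sizes(img_size: int):
    """Side length after each assumed PVT v2 overlapping patch embedding."""
    sizes, size = [], img_size
    for stride in (4, 2, 2, 2):
        size = conv_out(size, kernel=2 * stride - 1, stride=stride, padding=stride - 1)
        sizes.append(size)
    return sizes


def resnet34_stage_sizes(img_size: int):
    """Side length after ResNet34 layer1..layer4."""
    size = conv_out(img_size, kernel=7, stride=2, padding=3)   # stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)       # max-pool
    sizes = [size]                                              # layer1 keeps resolution
    for _ in range(3):                                          # layer2..layer4 halve it
        size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes.append(size)
    return sizes


print(pvt_v2_stage_sizes(224))    # [56, 28, 14, 7]
print(resnet34_stage_sizes(224))  # [56, 28, 14, 7]
```

Since both pyramids land on strides 4, 8, 16 and 32, per-stage fusion (for example concatenation followed by a 1×1 conv) is at least shape-compatible; channel widths would still need projecting to a common size.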
The code currently only works on the CPU. When I try to run it on the GPU, I get the following error in relative_pos_enc_qkv.py. I tried moving the inputs to the GPU, but it still does not work.
```
/usr/local/lib/python3.7/dist-packages/self_attention_cv/pos_embeddings/relative_pos_enc_qkv.py in forward(self)
     36
     37     def forward(self):
---> 38         all_embeddings = torch.index_select(self.relative, 1, self.relative_index_2d)  # [head_planes , (dim*dim)]
     39
     40         all_embeddings = rearrange(all_embeddings, ' c (x y) -> c x y', x=self.dim)

RuntimeError: Input, output and indices must be on the current device
```
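A likely cause is that `self.relative` is an `nn.Parameter`, which `model.to("cuda")` moves, while `self.relative_index_2d` is a plain tensor attribute, which it does not. Registering the index tensor as a buffer makes it follow the module. The class below is a hypothetical minimal reproduction of the `index_select` pattern, not the repo's code; the way the index is built is an illustrative assumption.

```python
import torch
import torch.nn as nn


class RelativeIndexLookup(nn.Module):
    """Hypothetical module reproducing the index_select pattern.

    The index tensor is registered as a buffer, so module.to(device) moves it
    together with the learnable parameter. A plain attribute
    (self.relative_index_2d = index) would stay on the CPU.
    """

    def __init__(self, head_planes: int, dim: int):
        super().__init__()
        self.dim = dim
        # One learnable embedding per relative offset in [-(dim-1), dim-1].
        self.relative = nn.Parameter(torch.randn(head_planes, 2 * dim - 1))
        # Offset j - i for every (i, j) pair, shifted to start at 0.
        coords = torch.arange(dim)
        index = (coords[None, :] - coords[:, None]) + dim - 1
        self.register_buffer("relative_index_2d", index.reshape(-1))

    def forward(self) -> torch.Tensor:
        emb = torch.index_select(self.relative, 1, self.relative_index_2d)
        return emb.reshape(-1, self.dim, self.dim)  # [head_planes, dim, dim]


device = "cuda" if torch.cuda.is_available() else "cpu"
module = RelativeIndexLookup(head_planes=8, dim=5).to(device)
print(module().shape, module.relative_index_2d.device)
```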
Hello!
Thanks for sharing this nice repo :)
I'm trying to use ViT to do regression on images. I'd like to predict 6 floats per image.
My understanding is that I'd need to simply define the network as
```python
from self_attention_cv import ViT

vit = ViT(img_dim=128,
          in_channels=3,
          patch_dim=16,
          num_classes=6,
          dim=512)
```
and during training call `vit(x)` and compute the loss with MSE instead of cross-entropy.
The network actually runs but it doesn't seem to converge. Is there something obvious I am missing?
Many thanks!
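Two things that are commonly checked when MSE regression stalls are the scale of the targets (six floats with large or very different ranges let the biggest one dominate the loss) and the learning rate; ViTs trained from scratch on small datasets are also known to converge slowly. The sketch below standardizes each target column and runs one MSE step; the tiny `nn.Sequential` model is a hypothetical stand-in for `ViT(..., num_classes=6)`, and the data are random.

```python
import torch
import torch.nn as nn


def standardize(targets: torch.Tensor, eps: float = 1e-8):
    """Scale each regression output column to zero mean and unit variance."""
    mean = targets.mean(dim=0, keepdim=True)
    std = targets.std(dim=0, keepdim=True) + eps
    return (targets - mean) / std, mean, std


def regression_step(model: nn.Module, images: torch.Tensor,
                    targets_norm: torch.Tensor,
                    optimizer: torch.optim.Optimizer) -> float:
    """One optimisation step with MSE on standardized targets."""
    model.train()
    optimizer.zero_grad()
    preds = model(images)                     # [batch, 6]
    loss = nn.functional.mse_loss(preds, targets_norm)
    loss.backward()
    optimizer.step()
    return loss.item()


# Hypothetical stand-in for the ViT: flatten + linear, same [batch, 6] output.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 6))
images = torch.randn(16, 3, 8, 8)
raw_targets = torch.randn(16, 6) * 100 + 50   # large, offset values
targets_norm, mean, std = standardize(raw_targets)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(regression_step(model, images, targets_norm, optimizer))
# Predictions map back to the original units with preds * std + mean.
```

The statistics should be computed on the training set only and reused for validation and inference.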
I see you have recently added the TimeSformer model to this repository. In the paper, they initialize the model from ViT weights pretrained on ImageNet. Does your implementation offer this too? Thanks!
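Whether the repo's TimeSformer ships such an option is for the maintainers to answer. A generic way to do this kind of initialization in PyTorch is to copy every pretrained tensor whose name and shape match and leave the rest (for example temporal attention layers a 2D ViT does not have) randomly initialized. The sketch below assumes the checkpoint's key names already line up with the target model; a real ViT checkpoint would most likely need a renaming map first. The `ModuleDict` models are hypothetical stand-ins.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, pretrained: dict) -> list:
    """Copy tensors whose name and shape match; return the loaded keys.

    Parameters without a counterpart, or with a different shape (e.g. a
    classification head with another number of classes), are left untouched.
    """
    own_state = model.state_dict()
    matched = {
        k: v for k, v in pretrained.items()
        if k in own_state and own_state[k].shape == v.shape
    }
    model.load_state_dict(matched, strict=False)
    return sorted(matched)


# Hypothetical stand-ins: a 2D "ViT" and a video model with an extra temporal layer.
source = nn.ModuleDict({"embed": nn.Linear(4, 8), "head": nn.Linear(8, 10)})
target = nn.ModuleDict({"embed": nn.Linear(4, 8), "temporal": nn.Linear(8, 8),
                        "head": nn.Linear(8, 3)})
print(load_matching_weights(target, source.state_dict()))
# ['embed.bias', 'embed.weight']
```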
I was wondering whether you've implemented an example that uses the network for a 3D medical segmentation task. If the network only outputs the center slice of a patch, we would need a wrapper function that iterates over all patches in an image to get the final prediction for the entire volume. From the original paper, I assume they sample 10 patches at random from each image during training, but it's not clear how they pieced the predictions together at test time.
Your thoughts on this would be greatly appreciated!
See:
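The wrapper described in this issue first needs the set of positions whose center blocks cover every voxel. The pure-Python sketch below enumerates those positions under the assumption that the network predicts a cubic center block of side `center`; the surrounding context fed to the network is handled separately by padding the volume. The function name is hypothetical, and this is not confirmed to be the paper's test-time procedure.

```python
from itertools import product
from typing import Iterator, Tuple


def center_patch_origins(volume_shape: Tuple[int, int, int],
                         center: int) -> Iterator[Tuple[int, int, int]]:
    """Yield the corner of every center block needed to tile a 3D volume.

    Blocks of side `center` are placed with stride `center`; if an axis is
    not a multiple of `center`, one extra block is shifted back so it ends
    exactly at the border (it overlaps its neighbour instead of running past).
    """
    starts_per_axis = []
    for size in volume_shape:
        if size < center:
            raise ValueError("volume smaller than the center block")
        starts = list(range(0, size - center + 1, center))
        if starts[-1] + center < size:        # cover the remainder
            starts.append(size - center)
        starts_per_axis.append(starts)
    yield from product(*starts_per_axis)


origins = list(center_patch_origins((10, 8, 8), center=4))
print(len(origins))   # 12 (3 * 2 * 2)
```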
Thank you very much for your contribution. As a novice, I have a question. In tranf3dseg, the output of the model is the segmentation prediction for the center patch, so how can I get the segmentation of the whole input image? I am looking forward to any reply.
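One common answer is sliding-window inference: pad the volume, run the model on a context block around each center block, and write only the center prediction into a full-size output. The sketch below assumes a hypothetical interface (a model mapping a `[1, C, context, context, context]` block to `[1, num_classes, center, center, center]` logits) and uses a 1×1×1-conv stand-in instead of the repo's model, so shapes and names are illustrative only.

```python
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F


def predict_full_volume(model: nn.Module, volume: torch.Tensor, context: int,
                        center: int, num_classes: int) -> torch.Tensor:
    """Sliding-window inference keeping only each window's center prediction.

    Hypothetical interface: `volume` is [C, D, H, W] with D, H, W divisible
    by `center`; `model` maps a [1, C, context, context, context] block to
    logits [1, num_classes, center, center, center] for the block's center.
    """
    if (context - center) % 2:
        raise ValueError("context - center must be even")
    _, d, h, w = volume.shape
    if any(s % center for s in (d, h, w)):
        raise ValueError("volume sides must be divisible by center")
    margin = (context - center) // 2
    padded = F.pad(volume, [margin] * 6)   # zero-pad W, H and D on both sides
    out = volume.new_zeros(num_classes, d, h, w)
    with torch.no_grad():
        for z, y, x in product(range(0, d, center), range(0, h, center),
                               range(0, w, center)):
            # In padded coordinates this block is centered on out[z:z+center, ...].
            block = padded[:, z:z + context, y:y + context, x:x + context]
            out[:, z:z + center, y:y + center, x:x + center] = model(block.unsqueeze(0))[0]
    return out.argmax(dim=0)               # label map [D, H, W]


class CenterCropStandIn(nn.Module):
    """Hypothetical stand-in for the real network: 1x1x1 conv, then center crop."""

    def __init__(self, in_ch: int, num_classes: int, context: int, center: int):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, num_classes, kernel_size=1)
        self.margin, self.center = (context - center) // 2, center

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        m, c = self.margin, self.center
        return self.conv(x)[:, :, m:m + c, m:m + c, m:m + c]


volume = torch.randn(2, 8, 8, 8)
model = CenterCropStandIn(in_ch=2, num_classes=3, context=6, center=4)
labels = predict_full_volume(model, volume, context=6, center=4, num_classes=3)
print(labels.shape)   # torch.Size([8, 8, 8])
```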
What is the meaning of qkv_channels?
I am trying to use AxialAttention on the GPU, but I get an error. Can you give me some tips about using AxialAttention on the GPU?
Thanks!
Error:
RuntimeError: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0
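This message most likely comes from a `masked_fill` whose mask tensor was created on the CPU while the attention scores live on the GPU. The usual fixes are to build the mask with `device=x.device`, register it as a buffer, or move it right before use. The function below is a hypothetical masked attention sketch showing the last option; it is not the repo's AxialAttention code.

```python
from typing import Optional

import torch


def masked_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Scaled dot-product attention for [batch, tokens, dim] inputs.

    `mask` is a boolean [tokens, tokens] tensor where True hides a position.
    It is moved to the scores' device before masked_fill, which avoids the
    "mask on cpu and self on cuda:0" error.
    """
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if mask is not None:
        mask = mask.to(scores.device)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores.softmax(dim=-1) @ v


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(2, 5, 16, device=device)
# Causal mask deliberately built on the CPU; the function relocates it.
causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)
out = masked_attention(x, x, x, causal)
print(out.shape, out.device)
```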
I am wondering whether, when using the LinformerEncoder for example, I have to add the positional encoding myself or whether that is already done. From the source files it doesn't seem to be there, but I'm not sure how to include it, since the positional encodings seem to need the query, which isn't available when passing data directly to the LinformerEncoder. I may well be missing something; any help would be great. An example using positional encoding would be useful.
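Relative positional embeddings are computed from the query inside attention, but an absolute, learned embedding (as in ViT) needs no query: it is added to the tokens once before the encoder. Since Linformer projects along a fixed sequence length, a fixed-size embedding fits naturally. The sketch below shows that pattern; `nn.TransformerEncoder` is a stand-in for `LinformerEncoder`, whose exact constructor arguments are not assumed here.

```python
import torch
import torch.nn as nn


class PositionalTokens(nn.Module):
    """Add a learned absolute position embedding to a token sequence.

    The embedding depends only on the token index, so no query is needed:
    it is added to the input once, before the encoder.
    """

    def __init__(self, num_tokens: int, dim: int):
        super().__init__()
        self.pos_emb = nn.Parameter(torch.randn(1, num_tokens, dim) * 0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens + self.pos_emb          # broadcast over the batch


num_tokens, dim = 64, 128
# Stand-in for LinformerEncoder: any module mapping [batch, tokens, dim] to the same shape.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True), num_layers=2)
model = nn.Sequential(PositionalTokens(num_tokens, dim), encoder)
x = torch.randn(8, num_tokens, dim)
print(model(x).shape)   # torch.Size([8, 64, 128])
```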
Hi,
Thank you for your effort and time in implementing this. I have a quick question: I want to get the segmentation for the full image, not just for the middle token. Would it be correct to change self.tokens
to self.p
here:
and change this:
to
`y = self.mlp_seg_head(y)`
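The specific lines referenced in this question did not survive, so the sketch below does not reproduce them. It only illustrates the general idea: apply the segmentation head to every output token instead of just the center one, then restore the token grid. The head's output layout (class logits for all voxels of a patch, flattened) is a hypothetical assumption; the last dimension would still need rearranging into voxels, for example with `einops.rearrange`.

```python
import torch
import torch.nn as nn


def segment_all_tokens(tokens: torch.Tensor, head: nn.Module,
                       grid: tuple) -> torch.Tensor:
    """Apply a per-token segmentation head to every token and restore the grid.

    tokens: [batch, num_tokens, dim] with num_tokens == product of grid
    head:   maps dim -> num_classes * patch_voxels (hypothetical layout)
    returns [batch, *grid, num_classes * patch_voxels]
    """
    batch, num_tokens, _ = tokens.shape
    expected = 1
    for g in grid:
        expected *= g
    if num_tokens != expected:
        raise ValueError(f"{num_tokens} tokens do not fill grid {grid}")
    logits = head(tokens)                     # nn.Linear acts on the last dim
    return logits.reshape(batch, *grid, logits.shape[-1])


# Hypothetical sizes: a 3x3x3 token grid, 2 classes, 4x4x4 voxels per patch.
tokens = torch.randn(2, 27, 64)
head = nn.Linear(64, 2 * 4 ** 3)
out = segment_all_tokens(tokens, head, (3, 3, 3))
print(out.shape)   # torch.Size([2, 3, 3, 3, 128])
```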