qubvel / ttach
Image Test Time Augmentation with PyTorch!
License: MIT License
e-packages/ttach/functional.py", line 7, in rot90
return torch.rot90(x, k, (2, 3))
RuntimeError: Rotation dim0 out of range, dim0 = 2
A quick question: when we do TTA, are we also predicting on the original image? Is the final prediction the mean of the probabilities for the original image and for the augmented images produced by the transforms?
I am working on segmentation problems with a standard U-Net. When I run inference through the TTA library, I get an out-of-memory (OOM) error. I am using an NVIDIA 2080 Ti. Plain inference without TTA works fine. Does anyone know how to solve this problem?
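One common cause of OOM at inference time is that autograd is still tracking every forward pass, so memory grows with the number of augmentations. The following is a minimal sketch in plain PyTorch, not ttach's own implementation: it runs one augmentation at a time under `torch.no_grad()` and keeps only a running sum. The function name `low_memory_tta` and the stand-in model are hypothetical.

```python
import torch

def low_memory_tta(model, x, ks=(0, 1, 2, 3)):
    """Average predictions over 90-degree rotations, one augmentation at a time."""
    model.eval()
    total = None
    with torch.no_grad():  # no autograd graph is stored for any forward pass
        for k in ks:
            pred = model(torch.rot90(x, k, (2, 3)))  # predict on the rotated input
            pred = torch.rot90(pred, -k, (2, 3))     # rotate the mask back
            total = pred if total is None else total + pred  # running sum only
    return total / len(ks)

# Stand-in "segmentation model" (assumption): a 1x1 convolution.
model = torch.nn.Conv2d(3, 1, kernel_size=1)
x = torch.rand(1, 3, 64, 64)
out = low_memory_tta(model, x)
print(out.shape, out.requires_grad)
```

If memory is still tight, reducing the batch size or the number of transforms lowers the peak further, since only one augmented batch is on the device at any time.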
How to install it offline?
Thanks for your work! Does it support 5D data? For example, MRI data is BCHWD, and at test time B is 1. Thanks again!
Hello,
Thanks for your great work. My question may seem a little foolish. If I want to keep a fixed size such as 512x512 and implement scale TTA (e.g. first scale the 512 input by a factor, then resize back to 512), what should I do?
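One way to read this request is: rescale the input, predict, resize the prediction back to the original size, then average. Below is a rough sketch in plain PyTorch under that assumption. It does not use ttach, and the helper name `scale_tta`, the scale values, and the stand-in model are all hypothetical.

```python
import torch
import torch.nn.functional as F

def scale_tta(model, x, scales=(0.75, 1.0, 1.25)):
    """Average predictions over rescaled copies of x, each resized back to x's size."""
    h, w = x.shape[2:]
    total = None
    with torch.no_grad():
        for s in scales:
            size = (int(h * s), int(w * s))
            scaled = F.interpolate(x, size=size, mode="bilinear", align_corners=False)
            pred = model(scaled)
            # Bring the prediction back to the fixed output size before merging.
            pred = F.interpolate(pred, size=(h, w), mode="bilinear", align_corners=False)
            total = pred if total is None else total + pred
    return total / len(scales)

# Stand-in fully convolutional model (assumption), so any input size is accepted.
model = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
x = torch.rand(1, 3, 512, 512)
out = scale_tta(model, x)
print(out.shape)
```

This only works as written for models that accept variable input sizes. Some encoders need the scaled height and width to be divisible by a fixed stride.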
Thanks!
I use the example like this:
model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="tsharpen")
When I use this model to predict, I find that the model output contains "nan" values.
Looking at the source code of this project, I found that the tsharpen mode does:
x = x**0.5
Could negative values in the tensors passing through this operation cause the "nan"?
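The suspicion can be checked in isolation. This small sketch is independent of ttach and assumes the model outputs raw logits, which can be negative. It shows that a fractional power of a negative float tensor yields NaN in PyTorch, and that the same power is well-defined once the values are mapped to probabilities first.

```python
import torch

# Raw logits can be negative; a fractional power of a negative float is NaN.
logits = torch.tensor([-2.0, 0.0, 4.0])
sqrt_of_logits = logits ** 0.5

# Mapping to probabilities in [0, 1] first keeps the power well-defined.
probs = torch.sigmoid(logits)
sqrt_of_probs = probs ** 0.5

has_nan_logits = bool(torch.isnan(sqrt_of_logits).any())
has_nan_probs = bool(torch.isnan(sqrt_of_probs).any())
print(has_nan_logits, has_nan_probs)  # True False
```

If this is the cause, making the wrapped model return sigmoid or softmax outputs instead of logits would be one way to avoid the NaN values.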
Hello!
I am trying to augment a batch of one.
If I provide a tensor with shape (1,3,320,320):
augmented_image = transformer.augment_image(img_t.unsqueeze_(0))
I get:
ttach\functional.py:47, in scale(x, scale_factor, interpolation, align_corners)
45 def scale(x, scale_factor, interpolation="nearest", align_corners=None):
46 """scale batch of images by `scale_factor` with given interpolation mode"""
---> 47 h, w = x.shape[2:]
48 new_h = int(h * scale_factor)
49 new_w = int(w * scale_factor)
ValueError: too many values to unpack (expected 2)
If I provide a tensor with shape (3,320,320):
augmented_image = transformer.augment_image(img_t)
I get:
ttach\functional.py:7, in rot90(x, k)
5 def rot90(x, k=1):
6 """rotate batch of images by 90 degrees k times"""
----> 7 return torch.rot90(x, k, (2, 3))
RuntimeError: Rotation dim1 out of range, dim1 = 3
What should I do?
Thank you!
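Both traces are consistent with the tensor not being 4D when it reaches the transform. The first error needs more than four dimensions for `x.shape[2:]` to yield more than two values, and the second comes from a 3D tensor. Since `unsqueeze_` modifies the tensor in place, re-running the same cell adds another dimension each time, which is one possible explanation. The following sketch uses a hypothetical helper `to_batch` to make the shape explicit before augmenting.

```python
import torch

def to_batch(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a 4D (B, C, H, W) tensor from a 3D or 4D image."""
    if img.ndim == 3:
        return img.unsqueeze(0)  # out-of-place: img itself keeps its shape
    if img.ndim == 4:
        return img
    raise ValueError(f"expected a 3D or 4D tensor, got shape {tuple(img.shape)}")

img_t = torch.rand(3, 320, 320)
batch = to_batch(img_t)
print(img_t.shape, batch.shape)

# In-place unsqueeze_ mutates the tensor, so calling it twice gives a 5D tensor.
tmp = torch.rand(3, 320, 320)
tmp.unsqueeze_(0)
tmp.unsqueeze_(0)
print(tmp.shape)  # torch.Size([1, 1, 3, 320, 320])
```

Printing `img_t.shape` right before the `augment_image` call would confirm whether this is what happens here.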
Thanks for the library!
Is it also possible to add custom transforms to the TTA pipeline, e.g. from albumentations?
Could anyone please post sample code for using ttach for binary segmentation in PyTorch? I can't seem to get it to work.
Thank you all :)
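`import os  # needed for os.path.join below
import torch
import ttach as tta
import timm
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import tifffile as tiff  # assumed source of tiff.imread
import segmentation_models_pytorch as smp  # assumed source of smp.encoders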
model = torch.load('E:/PhD_Projects/egmentation models/new model weights/UNet_mitb2_thresh0.3.pth')
tta_model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode="mean")
image_dir = 'E:/PhD_Projects/segmentation models/patches'
image_filename_2 = 'image__02_02.tif'
image_path = os.path.join(image_dir, image_filename_2)
image = tiff.imread(image_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
preprocessing_fn_inference = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
preprocessing_inference=get_preprocessing(preprocessing_fn_inference)
sample = preprocessing_inference(image=image)
image = sample['image']
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = tta_model.predict(x_tensor)
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
pr_mask = (pr_mask.astype('float') * 255.0/16)
#pr_mask = (pr_mask.astype('float') * 255.0/16).astype('uint8')
plt.imshow(pr_mask)
plt.show()`
Can anyone help me with this prediction problem? Thank you. @qubvel
tta_model = tta.ClassificationTTAWrapper(model, tta.aliases.five_crop_transform(384,384))
output = tta_model(images,meta)
Used like this, it doesn't work.
Hello,
Thank you for sharing your code.
I often use a sliding-window method to run inference on a large image.
When I saw your ttach code, I thought it would be good to use FiveCrops or TenCrops for segmentation.
However, I found that FiveCrops does not work for segmentation.
Since we know the shape of the original image and the crop size, de-augmentation should be possible.
Can you implement FiveCrops and TenCrops for segmentation?
Thank you very much.
In https://github.com/qubvel/ttach/blob/master/ttach/aliases.py#L13, should vlip be vflip?
Please clarify the TTA concept for segmentation for me.
Say I have one test image and apply flip-left and flip-right augmentations at test time.
I pass those three images [original, flip-left, flip-right] to the model for prediction.
I get three outputs. After that, should I average the predictions directly, or first reverse the augmentation (i.e. flip the predictions back to the original orientation) and then average?
Please clarify how the predictions should be merged.
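For segmentation the output is spatial, so each predicted mask has to be mapped back to the original orientation before averaging. Otherwise pixels from different locations get mixed. The toy sketch below illustrates this in plain PyTorch with a horizontal flip. The identity "model" is an assumption for illustration only, and this is not ttach's code.

```python
import torch

def hflip(t):
    """Flip a (B, C, H, W) tensor along the width axis."""
    return torch.flip(t, dims=(3,))

# Identity "segmentation model" (assumption): a 1x1 convolution with weight 1.
model = torch.nn.Conv2d(1, 1, kernel_size=1, bias=False)
with torch.no_grad():
    model.weight.fill_(1.0)

x = torch.zeros(1, 1, 2, 4)
x[..., 0] = 1.0  # the "object" sits in the left column

with torch.no_grad():
    pred_orig = model(x)
    pred_flip = model(hflip(x))  # prediction is in flipped coordinates

naive = (pred_orig + pred_flip) / 2          # averaged without de-augmentation
merged = (pred_orig + hflip(pred_flip)) / 2  # flipped back, then averaged

print(naive[0, 0, 0].tolist())   # [0.5, 0.0, 0.0, 0.5]
print(merged[0, 0, 0].tolist())  # [1.0, 0.0, 0.0, 0.0]
```

The naive average smears the object across both sides, while the de-augmented average recovers the original mask. For classification, where the output is a single label vector, the predictions can be averaged directly.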
Hi!
Is it possible to perform TTA on a multi-GPU system?
I have enough resources to compute complex TTA, but it uses only one GPU.
So I get an "out of memory" error on one of my GPUs, although the others are free.
Sorry, I don't know how to use TTA with my own model. Could you please give an example?
Why is the model's prediction speed so slow when running TTA?