tta_wrapper's Introduction

TTA wrapper

Test-time augmentation wrapper for Keras image segmentation and classification models.

Description

How does it work?

The wrapper adds augmentation layers to your Keras model like this:

          Input
            |           # input image; shape 1, H, W, C
       / / / \ \ \      # duplicate image for augmentation; shape N, H, W, C
      | | |   | | |     # apply augmentations (flips, rotation, shifts)
     your Keras model
      | | |   | | |     # reverse transformations
       \ \ \ / / /      # merge predictions (mean, max, gmean)
            |           # output mask; shape 1, H, W, C
          Output
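
The diagram above can be sketched in plain NumPy. This is only an illustration of the idea, not the wrapper's actual implementation (the wrapper builds Keras layers instead). `dummy_model` and `tta_predict` are hypothetical names introduced here, and only flips are shown.

```python
import numpy as np


def dummy_model(batch):
    # Hypothetical stand-in for model.predict: one "mask" per input image.
    return batch * 0.5


def tta_predict(predict_fn, image):
    """image has shape (1, H, W, C); the result has shape (1, H, W, C)."""
    # Each entry pairs a forward transform with the transform that undoes it.
    transforms = [
        (lambda x: x, lambda y: y),                                    # identity
        (lambda x: np.flip(x, axis=2), lambda y: np.flip(y, axis=2)),  # horizontal flip
        (lambda x: np.flip(x, axis=1), lambda y: np.flip(y, axis=1)),  # vertical flip
    ]
    # Duplicate the image and augment each copy: shape (N, H, W, C).
    batch = np.concatenate([forward(image) for forward, _ in transforms], axis=0)
    predictions = predict_fn(batch)
    # Reverse each transformation on its own prediction.
    restored = [
        inverse(predictions[i:i + 1])
        for i, (_, inverse) in enumerate(transforms)
    ]
    # Merge the aligned predictions (mean shown here).
    return np.mean(np.concatenate(restored, axis=0), axis=0, keepdims=True)


image = np.random.rand(1, 4, 6, 3)
mask = tta_predict(dummy_model, image)
print(mask.shape)
```

Because the predictions are flipped back before merging, every pixel of the merged mask lines up with the same pixel of the input.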

Arguments

  • h_flip - bool, horizontal flip augmentation
  • v_flip - bool, vertical flip augmentation
  • rotation - list, allowable angles - 90, 180, 270
  • h_shift - list of int, horizontal shift augmentation in pixels
  • v_shift - list of int, vertical shift augmentation in pixels
  • add - list of int/float, additive factor (aug_image = image + factor)
  • mul - list of int/float, multiplicative factor (aug_image = image * factor)
  • contrast - list of int/float, contrast adjustment factor (aug_image = (image - mean) * factor + mean)
  • merge - one of 'mean', 'gmean' and 'max' - mode of merging augmented predictions together
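
The pixel-level formulas and merge modes listed above can be written out as a minimal NumPy sketch. The function names are hypothetical, and two details are assumptions rather than facts about the library: the contrast mean is taken over the whole image, and 'gmean' is computed as the N-th root of the product of N predictions.

```python
import numpy as np


def add_aug(image, factor):
    return image + factor


def mul_aug(image, factor):
    return image * factor


def contrast_aug(image, factor):
    # Assumption: the mean is taken over the whole image.
    mean = image.mean()
    return (image - mean) * factor + mean


def merge_predictions(predictions, mode="mean"):
    """predictions has shape (N, H, W, C); the result has shape (1, H, W, C)."""
    if mode == "mean":
        return predictions.mean(axis=0, keepdims=True)
    if mode == "max":
        return predictions.max(axis=0, keepdims=True)
    if mode == "gmean":
        # Geometric mean; meaningful for non-negative predictions only.
        n = predictions.shape[0]
        return np.prod(predictions, axis=0, keepdims=True) ** (1.0 / n)
    raise ValueError("Wrong merge type {}".format(mode))


predictions = np.array([1.0, 4.0]).reshape(2, 1, 1, 1)
for mode in ("mean", "gmean", "max"):
    print(mode, merge_predictions(predictions, mode).ravel())
```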

Constraints

  1. model has to have 1 input and 1 output
  2. inference batch_size == 1
  3. image height == width if rotation augmentation is used
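
A short sketch of what constraints 2 and 3 mean in practice. `predict_one_by_one` is a hypothetical helper, and `predict_fn` stands in for something like `tta_model.predict`; neither is part of the package.

```python
import numpy as np


def predict_one_by_one(predict_fn, images):
    """Feed images of shape (B, H, W, C) one at a time as (1, H, W, C) batches."""
    outputs = [predict_fn(images[i:i + 1]) for i in range(len(images))]
    return np.concatenate(outputs, axis=0)


def rotated_shape(shape):
    """Shape of a (1, H, W, C) array after a 90-degree rotation in the H/W plane."""
    return np.rot90(np.zeros(shape), axes=(1, 2)).shape


# A 90-degree rotation swaps height and width, so a non-square image
# no longer matches a fixed (H, W) model input.
print(rotated_shape((1, 4, 6, 3)))
print(rotated_shape((1, 5, 5, 3)))
```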

Installation

  1. PyPI package:
$ pip install tta-wrapper
  2. Latest version:
$ pip install git+https://github.com/qubvel/tta_wrapper/

Example

from keras.models import load_model
from tta_wrapper import tta_segmentation

model = load_model('path/to/model.h5')
tta_model = tta_segmentation(model, h_flip=True, rotation=(90, 270), 
                             h_shift=(-5, 5), merge='mean')
y = tta_model.predict(x)

tta_wrapper's Issues

How to load a tta_wrapper model?

I want to save the model wrapped with tta_wrapper and reuse it. However, when I reload the model, loading fails on the Repeat layer. What is the solution?

I loaded the Repeat layer from tta_wrapper's layers, but the only response is that n is missing in __init__:

TypeError: __init__() missing 1 required positional argument: 'n'
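
One plausible cause, not confirmed against the library source: Keras rebuilds a layer by passing the dictionary from `get_config()` to the constructor, so a required argument that is absent from that dictionary is missing on reload. The plain-Python stand-ins below (hypothetical classes, no Keras involved) reproduce only that mechanism.

```python
class RepeatWithoutConfig:
    """Plain-Python stand-in for a layer with a required argument n."""

    def __init__(self, n, name="repeat"):
        self.n = n
        self.name = name

    def get_config(self):
        # n is not stored, so it cannot be restored later.
        return {"name": self.name}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class RepeatWithConfig(RepeatWithoutConfig):
    def get_config(self):
        config = super().get_config()
        config["n"] = self.n  # store the constructor argument
        return config


def round_trip(layer):
    """Rebuild a layer from its own config, as a save/load cycle would."""
    return type(layer).from_config(layer.get_config())


try:
    round_trip(RepeatWithoutConfig(n=4))
except TypeError as error:
    print("reload failed:", error)

print("reload ok, n =", round_trip(RepeatWithConfig(n=4)).n)
```

A workaround that avoids the question entirely is to save only the original model and call `tta_segmentation` again after `load_model`, as in the Example section.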

pytorch

Thanks for your work on TTA.
I have a question.

Is this available for models designed in PyTorch?

AttributeError. Asking for help.

AttributeError Traceback (most recent call last)
in ()
2 model = tta_segmentation(model,
3 h_flip=True,
----> 4 h_shift=(-5, 5),
5 #merge='mean'
6 )

C:\Anaconda3\lib\site-packages\tta_wrapper\wrappers.py in tta_segmentation(model, h_flip, v_flip, h_shift, v_shift, rotation, contrast, add, mul, merge)
52 )
53
---> 54 input_shape = (1, *model.input.shape.as_list()[1:])
55
56 inp = Input(batch_shape=input_shape)

C:\Anaconda3\lib\site-packages\keras\engine\base_layer.py in input(self)
782 if len(self._inbound_nodes) > 1:
783 raise AttributeError('Layer ' + self.name +
--> 784 ' has multiple inbound nodes, '
785 'hence the notion of "layer input" '
786 'is ill-defined. '

AttributeError: Layer sequential_1 has multiple inbound nodes, hence the notion of "layer input" is ill-defined. Use get_input_at(node_index) instead.

import tta_wrapper throwing 'Syntax error'

Below are the traceback and the libraries installed in my conda env.
I suspect Python 3.5 is the cause; let me know how to resolve this.
I believe I need to update Python 3.5 to 3.7 without disturbing the environment (I don't know how to do that). Suggestions are welcome.

Traceback (most recent call last):

  File "C:\Users\admin\Miniconda3\envs\machine_learning\lib\site-packages\IPython\core\interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)

  File "<ipython-input-25-168b6627f61f>", line 1, in <module>
    import tta_wrapper

  File "C:\Users\admin\Miniconda3\envs\machine_learning\lib\site-packages\tta_wrapper\__init__.py", line 1, in <module>
    from .wrappers import tta_segmentation

  File "C:\Users\admin\Miniconda3\envs\machine_learning\lib\site-packages\tta_wrapper\wrappers.py", line 4, in <module>
    from .layers import Repeat, TTA, Merge

  File "C:\Users\admin\Miniconda3\envs\machine_learning\lib\site-packages\tta_wrapper\layers.py", line 58
    raise ValueError(f'Wrong merge type {type}')
                                              ^
SyntaxError: invalid syntax
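
The caret points at an f-string, a syntax introduced in Python 3.6, so the Python 3.5 parser rejects the file before any code runs. For reference, a minimal sketch of an equivalent message built with `str.format`, which Python 3.5 accepts. `merge_type` is a hypothetical variable used only for this illustration.

```python
merge_type = "median"  # hypothetical unsupported value, for illustration only

# str.format works on Python 3.5, unlike f'Wrong merge type {merge_type}'.
message = "Wrong merge type {}".format(merge_type)
print(message)
```

Upgrading the interpreter to Python 3.6 or newer removes the need for any such change.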

List of libraries installed in my conda env :

Name Version Build
absl-py 0.2.0
absl-py 0.1.11
absl-py 0.1.10
absl-py 0.1.13
alabaster 0.7.10 py35h3a808de_0
albumentations 0.4.1
anaconda-client 1.6.14 py35_0
asn1crypto 0.24.0 py35_0
astor 0.6.2
babel 2.5.3 py35_0
beautifulsoup4 4.6.0
bleach 2.1.3 py35_0
bleach 1.5.0
bokeh 0.12.15 py35_0
ca-certificates 2018.03.07 0
certifi 2018.4.16 py35_0
cffi 1.11.5 py35h945400d_0
chardet 3.0.4 py35h177e1b7_1
click 6.7 py35h10df73f_0
cliff 2.8.1
cloudpickle 0.5.2 py35_1
clyent 1.2.2 py35h3cd9751_1
cmd2 0.8.5
colorama 0.3.9 py35h32a752f_0
configparser 3.5.0
cryptography 2.2.1 py35hfa6e2cd_0
cssselect 1.0.3
cycler 0.10.0 py35hcc71164_0
cytoolz 0.9.0.1 py35hfa6e2cd_0
dask 0.17.2 py35_0
dask-core 0.17.2 py35_0
decorator 4.2.1 py35_0
decorator 4.3.0
distributed 1.21.4 py35_0
docutils 0.14 py35h8ccb97f_0
efficientnet 1.0.0
entrypoints 0.2.3 py35hb91ced9_2
ffmpeg 4 hf48ec3a_0
freetype 2.8.1 vc14_0
gast 0.2.0
grpcio 1.11.0
grpcio 1.10.0
h5py 2.7.1 py35hb2c3add_0
hdf5 1.10.1 vc14hb361328_0
heapdict 1.0.0 py35_2
html5lib 1
html5lib 1.0.1 py35h047fa9f_0
icc_rt 2017.0.4 h97af966_0
icu 58.2 vc14_0
idna 2.6 py35h8dcb9ae_1
image-classifiers 1.0.0
imageio 2.1.2
imageio 2.3.0 py35_0
imagesize 1.0.0 py35_0
imgaug 0.2.6
intel-openmp 2018.0.0 8
ipykernel 4.8.2 py35_0
ipython 6.2.1 py35h4a2ac14_1
ipython_genutils 0.2.0 py35ha709e79_0
ipywidgets 7.1.2 py35_0
jedi 0.11.1 py35_1
jinja2 2.1 py35hdf652bb_0
jpeg 9b vc14_2
jsonschema 2.6.0 py35h27d56d3_0
jupyter 1.0.0 py35_4
jupyter_client 5.2.3 py35_0
jupyter_console 5.2.0 py35hf76c22e_1
jupyter_core 4.4.0 py35h629ba7f_0
kaggle 1.3.8
kaggle-cli 0.12.13
Keras 2.2.4
Keras-Applications 1.0.8
Keras-Preprocessing 1.1.0
kiwisolver 1.0.1 py35hc605aed_0
libpng 1.6.34 vc14_0
libtiff 4.0.9 vc14_0
locket 0.2.0 py35h0dfcdd0_1
lxml 4.0.0
m2w64-gcc-libgfortran 5.3.0 6
m2w64-gcc-libs 5.3.0 7
m2w64-gcc-libs-core 5.3.0 7
m2w64-gmp 6.1.0 2
m2w64-libwinpthread-git 5.0.0.4634.697f757 2
Markdown 2.6.11
markupsafe 1 py35hc253e08_1
matplotlib 2.2.2 py35h153e9ff_0
MechanicalSoup 0.8.0
mistune 0.8.3 py35_0
mkl 2018.0.2 1
mkl_fft 1.0.1 py35h452e1ab_0
mkl_random 1.0.1 py35h9258bd6_0
moviepy 0.2.3.2
msgpack-python 0.5.6 py35he980bc4_0
msys2-conda-epoch 20160418 1
nbconvert 5.3.1 py35h98d6c46_0
nbformat 4.4.0 py35h908c9d9_0
networkx 2.1 py35_0
nltk 3.2.5
notebook 5.4.1 py35_0
numpy 1.14.2 py35h5c71026_1
numpy 1.14.1
numpy 1.14.2
numpydoc 0.7.0 py35h72ac4f2_0
olefile 0.45.1 py35_0
opencv 3.3.1 py35h20b85fd_1
opencv-python-headless 4.1.1.26
opencv3 3.1.0 py35_0
openssl 1.0.2o h8ea7d77_0
packaging 17.1 py35_0
pandas 0.22.0 py35h6538335_0
pandoc 1.19.2.1 hb2460c7_1
pandocfilters 1.4.2 py35h978f723_1
parso 0.1.1 py35he39c48a_0
partd 0.3.8 py35h894d1e4_0
patsy 0.5.0 py35_0
pbr 4.0.3
pickleshare 0.7.4 py35h2f9f535_0
Pillow 5.1.0
pip 10.0.1
pip 9.0.3 py35_0
plotly 2.7.0
prettytable 0.7.2
progressbar2 3.34.3
prompt_toolkit 1.0.15 py35h89c7cb4_0
protobuf 3.5.2.post1
protobuf 3.5.1
protobuf 3.5.2
psutil 5.4.3 py35hfa6e2cd_0
pycparser 2.18 py35h15a15da_1
pygments 2.2.0 py35h24c0941_0
pyopenssl 17.5.0 py35h75c5b16_0
pyparsing 2.2.0 py35hcabcaab_1
pyperclip 1.6.0
pyqt 5.6.0 py35hd46907b_5
pyreadline 2.1
pysocks 1.6.8 py35_0
python 3.5.5 h0c2934d_1
python-dateutil 2.7.2 py35_0
python-utils 2.3.0
pytz 2018.3 py35_0
pywavelets 0.5.2 py35h7c47ace_0
pywinpty 0.5.1 py35_0
pyyaml 3.12 py35h4bf9689_1
pyzmq 17.0.0 py35hfa6e2cd_0
qt 5.6.2 vc14_1
qtconsole 4.3.1 py35hc47b0dd_0
requests 2.18.4 py35h54a615f_1
scikit-image 0.13.1
scikit-learn 0.19.1 py35h2037775_0
scipy 1.0.1 py35hce232c7_0
seaborn 0.8.1 py35hc73483e_0
segmentation-models 1.0.0
send2trash 1.5.0 py35_0
setuptools 39.0.1 py35_0
setuptools 39.0.1
setuptools 38.5.2
simplegeneric 0.8.1 py35_2
sip 4.18.1 py35h01cbaa7_2
six 1.11.0
six 1.11.0 py35hc1da2df_1
snowballstemmer 1.2.1 py35h4c55bfa_0
sortedcontainers 1.5.9 py35_0
sphinx 1.7.2 py35_0
sphinxcontrib 1 py35h45f5ca3_1
sphinxcontrib-websupport 1.0.1 py35ha3690eb_1
sqlite 3.22.0 vc14_0
statsmodels 0.8.0 py35hfa6034c_0
stevedore 1.28.0
tblib 1.3.2 py35hd2cf7e1_0
tensorboard 1.7.0
tensorflow 1.7.0
tensorflow-gpu 1.5.0
tensorflow-gpu 1.7.0
tensorflow-gpu 1.6.0
tensorflow-tensorboard 1.5.1
termcolor 1.1.0
terminado 0.8.1 py35_1
testpath 0.3.1 py35h06cf69e_0
tk 8.6.7 vc14_0
toolz 0.9.0 py35_0
tornado 5.0.1 py35_1
tqdm 4.11.2
traitlets 4.3.2 py35h09b975b_0
tta-wrapper 0.0.1
typing 3.6.4 py35_0
urllib3 1.22 py35h8cc84eb_0
vc 14 h0510ff6_3
vs2015_runtime 14.0.25123 3
wcwidth 0.1.7 py35h6e80d8a_0
webencodings 0.5.1 py35h5d527fb_1
Werkzeug 0.14.1
wheel 0.30.0
wheel 0.31.0
wheel 0.30.0 py35h38a90bc_1
widgetsnbextension 3.1.4 py35_0
win_inet_pton 1.0.1 py35hbef1270_1
win_unicode_console 0.5 py35h56988b5_0
wincertstore 0.2 py35hfebbdb8_0
winpty 0.4.3 vc14_2
wordcloud 1.4.1
yaml 0.1.7 vc14_0
zict 0.1.3 py35hf5542e0_0
zlib 1.2.11 vc14_0

Any plan to upgrade it for TF Keras?

This is a nice package to have. So are there any plans to upgrade it for TF Keras? I would like to contribute if you are planning to do so. Let me know your thoughts.
