
understanding-unets's Introduction

Learnlets


Learnlets are a way to learn a filter bank rather than designing one, as is done for example with curvelets.

This filter bank will be learned in a denoising setting with backpropagation and gradient descent.

Requirements

The requirements are listed in learning_wavelets/requirements.txt.

Use

The learnlets are defined in learning_wavelets/learnlet_model.py, via the class Learnlet.

You can use different types of thresholding listed in learning_wavelets/keras_utils/thresholding.py.
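
A hypothetical instantiation sketch follows; the constructor arguments shown are placeholders, not the actual signature, so check learning_wavelets/learnlet_model.py before using them.

```python
# Hypothetical usage sketch: the argument names below are assumptions,
# see learning_wavelets/learnlet_model.py for the actual Learnlet signature.
from learning_wavelets.learnlet_model import Learnlet

model = Learnlet(
    n_tiling=256,  # number of learned filters (assumed parameter name)
    n_scales=5,    # number of dyadic scales (assumed parameter name)
)
# Standard Keras training setup for a denoising objective.
model.compile(optimizer='adam', loss='mse')
```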

List of saved networks

Exact reconstruction notebook

| Model id | Description |
|----------|-------------|
| learnlet_dynamic_st_bsd500_0_55_1580806694 | the big classical network, with 256 filters + identity |
| learnlet_subclassing_st_bsd500_0_55_1582195807 | 64 filters, subclassed API, exact reconstruction forced |

No threshold notebook

| Model id | Description |
|----------|-------------|
| learnlet_dynamic_st_bsd500_0_55_1580806694 | the big classical network, with 256 filters + identity |

Different training noise standard deviations notebook

| Model id | Description |
|----------|-------------|
| learnlet_dynamic_st_bsd500_0_55_1580806694 | the big classical network, with 256 filters + identity |
| learnlet_dynamic_st_bsd500_20_40_1580492805 | same, with training on 20-40 noise std |
| learnlet_dynamic_st_bsd500_30_1580668579 | same, with training on 30 noise std |
| unet_dynamic_st_bsd500_0_55_1576668365 | big classical U-net with 64 base filters and batch norm |
| unet_dynamic_st_bsd500_20.0_40.0_1581002329 | same, with training on 20-40 noise std |
| unet_dynamic_st_bsd500_30.0_30.0_1581002329 | same, with training on 30 noise std |

General comparison

| Model id | Description |
|----------|-------------|
| learnlet_dynamic_st_bsd500_0_55_1580806694 | the big classical network, with 256 filters + identity |
| unet_dynamic_st_bsd500_0_55_1576668365 | big classical U-net with 64 base filters and batch norm |

understanding-unets's People

Contributors

aziz-ayed, kevinmicha, zaccharieramzi


understanding-unets's Issues

CheekyDynamicThresholding layer causes model save to fail

This is due to tensorflow/tensorflow#36650.

Therefore, as a workaround, I need to give explicit names to the weights in this custom layer (a sketch follows the traceback below).

Current error:

RuntimeError                              Traceback (most recent call last)
<timed eval> in <module>

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    395                       total_epochs=1)
    396                   cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
--> 397                                  prefix='val_')
    398 
    399     return model.history

/usr/lib/python3.6/contextlib.py in __exit__(self, type, value, traceback)
     86         if type is None:
     87             try:
---> 88                 next(self.gen)
     89             except StopIteration:
     90                 return False

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_v2.py in on_epoch(self, epoch, mode)
    769       if mode == ModeKeys.TRAIN:
    770         # Epochs only apply to `fit`.
--> 771         self.callbacks.on_epoch_end(epoch, epoch_logs)
    772       self.progbar.on_epoch_end(epoch, epoch_logs)
    773 

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    300     logs = logs or {}
    301     for callback in self.callbacks:
--> 302       callback.on_epoch_end(epoch, logs)
    303 
    304   def on_train_batch_begin(self, batch, logs=None):

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
    990           self._save_model(epoch=epoch, logs=logs)
    991       else:
--> 992         self._save_model(epoch=epoch, logs=logs)
    993     if self.model._in_multi_worker_mode():
    994       # For multi-worker training, back up the weights and current training

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/callbacks.py in _save_model(self, epoch, logs)
   1038             self.model.save_weights(filepath, overwrite=True)
   1039           else:
-> 1040             self.model.save(filepath, overwrite=True)
   1041 
   1042         self._maybe_remove_file()

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
   1006     """
   1007     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 1008                     signatures, options)
   1009 
   1010   def save_weights(self, filepath, overwrite=True, save_format=None):

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
    107     model_weights_group = f.create_group('model_weights')
    108     model_layers = model.layers
--> 109     save_weights_to_hdf5_group(model_weights_group, model_layers)
    110 
    111     # TODO(b/128683857): Add integration tests between tf.keras and external

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_weights_to_hdf5_group(f, layers)
    629     save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
    630     for name, val in zip(weight_names, weight_values):
--> 631       param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
    632       if not val.shape:
    633         # scalar

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/h5py/_hl/group.py in create_dataset(self, name, shape, dtype, data, **kwds)
    137             dset = dataset.Dataset(dsid)
    138             if name is not None:
--> 139                 self[name] = dset
    140             return dset
    141 

~/workspace/understanding-unets/venv/lib/python3.6/site-packages/h5py/_hl/group.py in __setitem__(self, name, obj)
    371 
    372             if isinstance(obj, HLObject):
--> 373                 h5o.link(obj.id, self.id, name, lcpl=lcpl, lapl=self._lapl)
    374 
    375             elif isinstance(obj, SoftLink):

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/_objects.pyx in h5py._objects.with_phil.wrapper()

h5py/h5o.pyx in h5py.h5o.link()

RuntimeError: Unable to create link (name already exists)
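
A minimal sketch of the workaround, with a hypothetical layer standing in for CheekyDynamicThresholding (the weight name, shape, and thresholding logic are illustrative, not the actual ones):

```python
import tensorflow as tf

class NamedThresholding(tf.keras.layers.Layer):
    """Hypothetical stand-in for the custom thresholding layer."""

    def build(self, input_shape):
        # Without an explicit `name=...`, several weights can end up with
        # colliding default names, which triggers the
        # "Unable to create link (name already exists)" error on HDF5 save.
        self.alpha = self.add_weight(
            name='alpha',  # explicit, unique name: this is the actual fix
            shape=(int(input_shape[-1]),),
            initializer='ones',
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # Illustrative soft thresholding using the named weight.
        return tf.sign(inputs) * tf.nn.relu(tf.abs(inputs) - self.alpha)
```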

Clear legacy functions, scripts and notebooks, rewrite everything with subclassed model

A lot of the clutter in the code is due to the fact that I wanted to handle multiple ways of writing my model:

  • old functional API with a recursive definition
  • new functional API with an iterative definition
  • subclassed API, with the possibility to do exact reconstruction and reweighting very conveniently

In particular, a lot of the if-else branches in the normalisation callback are due to that.
It also makes the code impossible to read for new users, because of all these different APIs.

Before moving forward to new noise types, undecimated transforms, and curvelet initialisation or complement, I need to clear everything up.

Problem: I will need to do that on a separate branch, to have the results ready to be re-generated with the saved models.

Doc before it's too late

Because some people are asking for access to the repo, and it might be used afterwards, I need to document this...

Include DnCNN in the comparison

We need to have DnCNN in our comparison, as it is widely used as an advanced denoising benchmark, in particular for blind Gaussian denoising.

Have figures notebooks restored

Because of #25, the figures notebooks have been reverted to an old state.
We should take advantage of the reweighting notebook to update the others.

Downsampling and upsampling with biorthogonal strategy

The up- and downsampling currently implemented use the following strategy:

  • down: pick one line/column out of 2.
  • up: bicubic interpolation

We want a potentially smarter way of doing this that stays close to the wavelets, since the classical NN strategy (average pooling, bilinear upsampling) worked a bit better than the current one.

For this, the biorthogonal strategy seems adapted:

  • down: same thing, take one line out of 2 (after application of the h filter, but this doesn't change).
  • up: interleave 0s on the lines, then apply a 1D convolution with the \tilde{h} filter along the lines; then interleave 0s on the columns and apply the same 1D convolution along the columns (see the sketch below).
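
A minimal NumPy sketch of this biorthogonal strategy; the placeholder filter values and the factor 2 compensating for the zero interleaving are assumptions based on standard wavelet conventions:

```python
import numpy as np
from scipy.ndimage import convolve1d

h_tilde = np.array([0.25, 0.5, 0.25])  # placeholder synthesis filter

def downsample(image):
    # Take one line/column out of 2 (the h filtering happens before this step).
    return image[::2, ::2]

def upsample(image):
    # Interleave 0s on the lines, then convolve with h_tilde along the lines.
    up = np.zeros((2 * image.shape[0], image.shape[1]))
    up[::2, :] = image
    up = convolve1d(up, 2 * h_tilde, axis=0, mode='mirror')
    # Same operation on the columns.
    out = np.zeros((up.shape[0], 2 * up.shape[1]))
    out[:, ::2] = up
    return convolve1d(out, 2 * h_tilde, axis=1, mode='mirror')
```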

Normalisation adjustment callback shouldn't work

I recently saw a weird error when running the normalisation adjustment callback.
Basically, when I use it with a model that is functional but uses the learnlet layers, I use the LearnletAnalysis layer to compute the coefficients just before normalisation. However, I don't select only the detail coefficients, which means I have an extra norm_input that hasn't caused any issues yet but should have.

I need to investigate this.

List needed models and logs

Right now, a lot of logs and models clutter my tensorboards.

They are overfitting runs, bug runs, failed runs, and unused runs. I should remove them entirely if they don't feature in my article. It might also be nice to have this table in the README for ease of reuse.

Implementing Reweighting

The best option for that is to subclass the Keras models.
This will give us the possibility to use the model's layers for different operations, and therefore to compute both the coefficients and the denoised image, as sketched below.
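
A hypothetical sketch of the subclassing idea; the class name and layer interfaces are placeholders:

```python
import tensorflow as tf

class LearnletDenoiser(tf.keras.Model):
    """Hypothetical subclassed model reusing its layers for several operations."""

    def __init__(self, analysis, thresholding, synthesis):
        super().__init__()
        # analysis/thresholding/synthesis stand for the existing learnlet layers.
        self.analysis = analysis
        self.thresholding = thresholding
        self.synthesis = synthesis

    def compute_coefficients(self, inputs):
        # The same layers can be reused to get the thresholded coefficients
        # alone, which is what makes reweighting convenient.
        return [self.thresholding(c) for c in self.analysis(inputs)]

    def call(self, inputs):
        # Full pass: analysis, thresholding, then synthesis of the denoised image.
        return self.synthesis(self.compute_coefficients(inputs))
```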

The reweighting will be given by the following scheme, adapted from modopt:

T_{j,k}^{n+1} = T_{j,k}^{n} / (1 + |\alpha_{j,k}^{n}| / T_{j,k}^{n})

where the \alpha_{j,k} are the previously computed (and thresholded) learnlet coefficients.

Note: this might be done only on the coefficients themselves without a need to go back to the denoised image.
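
In code, the update is a one-liner; a sketch assuming the thresholds and coefficients are arrays of matching shapes:

```python
import numpy as np

def reweight(thresholds, coefficients):
    # T^{n+1} = T^n / (1 + |alpha^n| / T^n): coefficients that survived the
    # previous thresholding (likely signal) get a lower threshold, while
    # small (likely noise) coefficients keep a high one.
    return thresholds / (1.0 + np.abs(coefficients) / thresholds)
```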

Adapt to TGCC lifestyle

Requirements:

  • write scripts for job submission (see the examples below)
  • write training scripts with adaptable locations (for data, checkpoints, logs)

Examples:

For a GPU run:

#!/bin/bash
#MSUB -r GPU_Job                   # Request name
#MSUB -n 1                         # Total number of tasks to use
#MSUB -T 1800                      # Elapsed time limit in seconds
#MSUB -o example_%I.o              # Standard output. %I is the job id
#MSUB -e example_%I.e              # Error output. %I is the job id
#MSUB -q hybrid                    # Hybrid partition of GPU nodes
#MSUB -A projxxxx                  # Project ID
set -x
cd ${BRIDGE_MSUB_PWD}
module load cuda
ccc_mprun ./a.out

For a multi run:

#!/bin/bash
#MSUB -r MyJob_Para                # Request name 
#MSUB -n 16                         # Number of tasks to use 
#MSUB -T 3600                      # Elapsed time limit in seconds 
#MSUB -o example_%I.o              # Standard output. %I is the job id 
#MSUB -e example_%I.e              # Error output. %I is the job id 
#MSUB -q <partition>               # Queue
 
set -x
 
ccc_mprun -E '--exclusive' -n 2 ./bin1  &
ccc_mprun -E '--exclusive' -n 1 ./bin2  &
ccc_mprun -E '--exclusive' -n 3 ./bin3  &
ccc_mprun -E '--exclusive' -n 1 ./bin4  &
ccc_mprun -E '--exclusive' -n 4 ./bin5  &
ccc_mprun -E '--exclusive' -n 8 ./bin6  &
ccc_mprun -E '--exclusive' -n 4 ./bin7  &
 
wait  # wait for all ccc_mprun(s) to complete.
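
For the second requirement, a minimal sketch of adaptable locations, assuming the paths are passed through environment variables (the variable names are placeholders):

```python
import os

# Defaults work locally; on TGCC, export these variables in the job script.
DATA_DIR = os.environ.get('DATA_DIR', './data')
CHECKPOINTS_DIR = os.environ.get('CHECKPOINTS_DIR', './checkpoints')
LOGS_DIR = os.environ.get('LOGS_DIR', './logs')
```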

Grouping normalisation

Right now the grouping can't be normalised properly, because there is a (voluntary) conflation of the unit-norm constraint and the noise standard deviation normalisation. This conflation should be removed, so that we can properly constrain the grouping convolution.

Have unit tests

For:

  • exact reconstruction
  • wavelet analysis -> verify that we have exact reconstruction when summing and upsampling
  • thresholding layers (see the sketch after this list)
  • normalisation layer
  • model instantiation/building/fitting without checking results
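
For instance, a hypothetical pytest sketch for the thresholding layers; the class name and constructor signature are assumptions, not the actual API of learning_wavelets/keras_utils/thresholding.py:

```python
import numpy as np
import tensorflow as tf

def soft_threshold(x, t):
    # Reference implementation to compare the layer against.
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)

def test_soft_thresholding_layer():
    # Assumed class name and signature; adapt to the actual thresholding module.
    from learning_wavelets.keras_utils.thresholding import SoftThresholding
    layer = SoftThresholding(threshold=0.5)
    x = np.random.normal(size=(1, 8, 8, 4)).astype('float32')
    out = layer(tf.constant(x)).numpy()
    np.testing.assert_allclose(out, soft_threshold(x, 0.5), atol=1e-6)
```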

Have the wavelet normalisation be done at init time

Right now the normalisation for the wavelets is done completely offline, and hard-saved in a file.
It could be done at init (or build) time of the WavAnalysis layer, by doing the same operation carried out in the notebook.
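
One standard way to compute such factors at build time is to pass unit-variance white noise through the analysis and record the per-scale standard deviations; a sketch, where the wav_analysis call convention (a list of per-scale tensors) is an assumption:

```python
import numpy as np

def compute_scale_norms(wav_analysis, image_shape=(256, 256), n_samples=10):
    # Unit-variance white Gaussian noise propagated through the wavelet
    # analysis; the empirical std of each scale gives the normalisation
    # factor currently hard-saved in a file.
    noise = np.random.normal(size=(n_samples, *image_shape, 1)).astype('float32')
    coefficients = wav_analysis(noise)  # assumed: list of per-scale tensors
    return [float(np.std(np.asarray(c))) for c in coefficients]
```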
