iResnet
Unofficial PyTorch implementation of i-ResNet (invertible residual networks).
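The core idea behind i-ResNet can be shown in a minimal sketch. This is not this repository's code. A residual block F(x) = x + g(x) is invertible when g is a contraction (Lipschitz constant below 1), and the inverse can be computed with the fixed-point iteration x ← y − g(x). Here g is a hypothetical single linear layer followed by tanh, and its weight is rescaled so that its spectral norm is 0.9.

```python
import torch

torch.manual_seed(0)

# Hypothetical residual branch g(x) = tanh(W x + b). W is rescaled to spectral
# norm 0.9 and tanh is 1-Lipschitz, so Lip(g) <= 0.9 < 1 (a contraction).
W = torch.randn(2, 2)
W = 0.9 * W / torch.linalg.matrix_norm(W, ord=2)
b = torch.randn(2)


def g(x):
    return torch.tanh(x @ W.T + b)


def forward(x):
    # i-ResNet block: F(x) = x + g(x)
    return x + g(x)


def inverse(y, n_iter=200):
    # Fixed-point iteration x <- y - g(x); converges because g is a contraction
    x = y.clone()
    for _ in range(n_iter):
        x = y - g(x)
    return x


x = torch.randn(5, 2)
x_rec = inverse(forward(x))
print(f"max reconstruction error: {(x - x_rec).abs().max().item():.2e}")
```

The error shrinks by a factor of Lip(g) per iteration, so a coefficient closer to 1 needs more iterations.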
Look at invertible-2dim-logdet for an example of how to use it with linear examples.
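i-ResNets compute the log-determinant with the power series ln det(I + J_g) = Σ_{k≥1} (−1)^{k+1} tr(J_g^k)/k. The series converges when Lip(g) < 1. In practice the series is truncated and each trace is estimated stochastically with Hutchinson's estimator. The hedged sketch below is not taken from the invertible-2dim-logdet notebook. It checks both the exact-trace and the Hutchinson versions of the truncated series against the exact value for a hypothetical 2×2 linear branch g(x) = Wx, whose Jacobian is simply W.

```python
import torch

torch.manual_seed(0)

# Hypothetical 2-D linear residual branch g(x) = W x, so J_g = W; ||W||_2 = 0.7 < 1
W = torch.randn(2, 2, dtype=torch.float64)
W = 0.7 * W / torch.linalg.matrix_norm(W, ord=2)


def logdet_power_series(J, n_terms=50):
    # ln det(I + J) = sum_{k>=1} (-1)^(k+1) tr(J^k) / k, valid when ||J||_2 < 1
    total = torch.zeros((), dtype=J.dtype)
    Jk = torch.eye(J.shape[0], dtype=J.dtype)
    for k in range(1, n_terms + 1):
        Jk = Jk @ J
        total = total + (-1) ** (k + 1) * torch.trace(Jk) / k
    return total


def logdet_hutchinson(J, n_terms=50, n_samples=10000):
    # Same truncated series, with each tr(J^k) replaced by the Hutchinson
    # estimate mean(v^T J^k v) over Gaussian probe vectors v ~ N(0, I)
    v = torch.randn(n_samples, J.shape[0], dtype=J.dtype)
    w = v
    total = torch.zeros((), dtype=J.dtype)
    for k in range(1, n_terms + 1):
        w = w @ J.T  # row i now holds J^k v_i
        total = total + (-1) ** (k + 1) * (w * v).sum(dim=1).mean() / k
    return total


exact = torch.linalg.slogdet(torch.eye(2, dtype=torch.float64) + W).logabsdet
print(exact.item(), logdet_power_series(W).item(), logdet_hutchinson(W).item())
```

In a real network J_g is never formed explicitly. The products J^k v are instead computed with vector-Jacobian products via autograd, which is why the Hutchinson form matters.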
License: MIT License
Hi Jarrel,
I'm trying to adapt your code to fit the mixture of 8 Gaussians from the experiment shown in the paper. Unlike your example, which fits the latent z with a mixture of 4 Gaussians, I'm fitting the latent z with a Gaussian prior. However, the loss, logdet, and pz all blow up (NaN) after some point.
Could you guess what the problem is and how to make training more stable? Thanks.
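Below is a hedged sketch of common safeguards. It is not a diagnosis of this particular failure:

- Evaluate the prior in log space with torch.distributions rather than computing pz directly, since pz underflows.
- Keep the Lipschitz coefficient strictly below 1, because the log-det power series converges more slowly as it approaches 1.
- Clip gradient norms.
- Stop at the first non-finite loss.

ToyFlow is a hypothetical stand-in that returns (z, logdet) like a flow does. It is an elementwise affine map, not an i-ResNet.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class ToyFlow(nn.Module):
    """Hypothetical stand-in for an i-ResNet: elementwise affine map returning (z, logdet)."""

    def __init__(self, dim):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        z = x * torch.exp(self.log_scale) + self.shift
        logdet = self.log_scale.sum().expand(x.shape[0])  # per-sample log|det J|
        return z, logdet


def nll_loss(z, logdet):
    # Standard Gaussian prior evaluated as a log-density, so p(z) itself never underflows
    log_pz = Normal(0.0, 1.0).log_prob(z).sum(dim=1)
    return -(log_pz + logdet).mean()


def train_step(model, optimizer, x, max_grad_norm=1.0):
    optimizer.zero_grad()
    z, logdet = model(x)
    loss = nll_loss(z, logdet)
    if not torch.isfinite(loss):
        # Fail fast instead of propagating NaN into the parameters
        raise FloatingPointError("non-finite loss: try a lower learning rate or Lipschitz coefficient")
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyFlow(2)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
data = 3.0 * torch.randn(512, 2) + 1.0
losses = [train_step(model, opt, data) for _ in range(200)]
print(losses[0], losses[-1])
```

For the real network, the same loop applies with the i-ResNet in place of ToyFlow. Lowering the learning rate or the spectral-norm coefficient is usually the first thing worth trying if the finiteness check fires.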
Modules with Gouk's version of spectral norm cannot be saved with torch.save. PyTorch's native spectral norm does not have this problem.
Reproduce:
import torch
import torch.nn as nn
from iResnet import SpectralNormGouk as sn
model = sn.spectral_norm(nn.Conv2d(3, 3, 3))
torch.save(model, 'gouk.pth')
Error:
'''
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/var/folders/r4/n7vnlt1528b_18nsjk7v_02w0000gn/T/ipykernel_81895/1404222307.py in <module>
1 model = sn.spectral_norm(nn.Conv2d(3, 3, 3))
2
----> 3 torch.save(model, 'gouk.pth')
/opt/homebrew/Caskroom/miniforge/base/envs/ML/lib/python3.8/site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
377 if _use_new_zipfile_serialization:
378 with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 379 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
380 return
381 _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
/opt/homebrew/Caskroom/miniforge/base/envs/ML/lib/python3.8/site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
482 pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
483 pickler.persistent_id = persistent_id
--> 484 pickler.dump(obj)
485 data_value = data_buf.getvalue()
486 zip_file.write_record('data.pkl', data_value, len(data_value))
AttributeError: Can't pickle local object 'SpectralNorm.apply.<locals>.<lambda>'
'''
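The error says that pickle cannot serialize a lambda defined inside SpectralNorm.apply. That is a closure stored on the module, for example as a hook. PyTorch's own spectral_norm registers instances of small hook classes rather than lambdas, which is why it pickles fine. Two workarounds generally apply, shown here on a hypothetical module rather than on the repository's SpectralNormGouk:

1. Save only model.state_dict(), since tensors pickle fine.
2. Replace the local lambda with a module-level function or a functools.partial of one, so the whole module becomes picklable.

```python
import functools
import io

import torch
import torch.nn as nn


def scale_output(module, inputs, output, factor):
    # Module-level function: pickled by reference, unlike a local lambda
    return output * factor


class WithLambdaHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3)
        # Hypothetical: a local lambda stored on the module, as in the traceback
        self.lin.register_forward_hook(lambda m, i, o: o * 0.5)


class WithPartialHook(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 3)
        # functools.partial of a module-level function can be pickled
        self.lin.register_forward_hook(functools.partial(scale_output, factor=0.5))


try:
    torch.save(WithLambdaHook(), io.BytesIO())
except Exception as exc:  # AttributeError or PicklingError, depending on the Python version
    print("full-module save failed:", exc)

# Workaround 1: save and restore only parameters and buffers
model = WithLambdaHook()
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)
restored = WithLambdaHook()
restored.load_state_dict(torch.load(buf))

# Workaround 2: a picklable hook lets the whole module be saved
torch.save(WithPartialHook(), io.BytesIO())
```

With the state_dict route, the model must be rebuilt with the same spectral-norm wrapping before load_state_dict. Otherwise the saved keys will not match.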
Your code is really helpful for understanding how i-ResNet works.
Thanks for writing it.
However, when I ran the CNN-version Jupyter notebook,
it gave wrong results in the evaluation phase (after switching to evaluation mode with net.eval()). After a few iterations the model cannot even reconstruct its inputs, and the latent standard deviation of the test data diverges. (I am using DataParallel; do you think the problem comes from this?)
Did you get correct results?
Thanks in advance for your reply.
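Here are two hedged things to check, not a confirmed cause:

- nn.DataParallel only parallelises forward(). Any other method, such as an inverse, must be called on model.module.
- Some layers behave differently in train() and eval(). Examples are batch-norm running statistics and spectral-norm power iterations; PyTorch's native spectral_norm, for instance, skips its power iteration in eval mode. Either can make eval-mode reconstructions differ from training-mode ones.

The sketch below measures reconstruction error in eval mode. It uses a hypothetical ToyInvertible module with an inverse() method, not the notebook's network.

```python
import torch
import torch.nn as nn


class ToyInvertible(nn.Module):
    """Hypothetical invertible block y = x + c * tanh(W x / ||W||_2) with c < 1."""

    def __init__(self, dim, coeff=0.5):
        super().__init__()
        self.lin = nn.Linear(dim, dim, bias=False)
        self.coeff = coeff

    def g(self, x):
        W = self.lin.weight
        sigma = torch.linalg.matrix_norm(W, ord=2)
        return self.coeff * torch.tanh(x @ (W / sigma).T)  # Lip(g) <= coeff

    def forward(self, x):
        return x + self.g(x)

    def inverse(self, y, n_iter=100):
        x = y.clone()
        for _ in range(n_iter):
            x = y - self.g(x)
        return x


def unwrap(model):
    # DataParallel only parallelises forward(); other methods live on .module
    return model.module if isinstance(model, nn.DataParallel) else model


@torch.no_grad()
def eval_reconstruction_error(model, x):
    model.eval()
    y = model(x)                    # forward may run through DataParallel
    x_rec = unwrap(model).inverse(y)  # inverse() must be called on the wrapped module
    return (x - x_rec).abs().max().item()


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(ToyInvertible(4).to(device))  # without GPUs this just wraps the module
err = eval_reconstruction_error(model, torch.randn(8, 4, device=device))
print(f"eval-mode max reconstruction error: {err:.2e}")
```

If the error stays small in train mode but grows after net.eval(), the culprit is likely a layer whose eval-mode behaviour differs, rather than DataParallel itself.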
Hello,
Thank you for this code!
In addition to the latent, I need the estimated density value for each data point.
How can I access it? At this stage of my work I don't want to spend time on the details of the paper.
Thank you in advance!
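By the change-of-variables formula, the per-sample log-density is log p(x) = log p_Z(F(x)) + log|det J_F(x)|. So if the model returns the latent z and a per-sample log-determinant, which the training loss needs anyway, the density is one line away. The sketch below assumes a standard Gaussian prior and a hypothetical model output (z, logdet), with logdet of shape [batch]. Because i-ResNet's log-det is a truncated stochastic estimate, the resulting density is an estimate too. The usage checks the function on a linear map F(x) = Ax, where the exact density of x is Gaussian with precision matrix AᵀA.

```python
import torch
from torch.distributions import MultivariateNormal, Normal


def log_density(z, logdet):
    # log p(x) = log N(z; 0, I) + log|det dF/dx|, one value per sample
    log_pz = Normal(0.0, 1.0).log_prob(z).flatten(1).sum(dim=1)
    return log_pz + logdet


torch.manual_seed(0)
A = torch.tensor([[2.0, 0.5], [0.0, 1.5]], dtype=torch.float64)
x = torch.randn(6, 2, dtype=torch.float64)
z = x @ A.T                                          # F(x) = A x
logdet = torch.linalg.slogdet(A).logabsdet.expand(6)  # log|det A| for every sample
log_px = log_density(z, logdet)
density = log_px.exp()                                # p(x) for each data point

# Exact reference: z = A x ~ N(0, I)  =>  x ~ N(0, (A^T A)^{-1})
reference = MultivariateNormal(torch.zeros(2, dtype=torch.float64), precision_matrix=A.T @ A).log_prob(x)
print(log_px, reference)
```

Working with log_px rather than density is usually safer for high-dimensional data such as images, where exp() underflows to zero.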
PyTorch has built-in spectral normalization, i.e., torch.nn.utils.spectral_norm(). Is there any reason you don't use it?
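Two differences plausibly explain a custom implementation:

- torch.nn.utils.spectral_norm divides the weight by its estimated largest singular value, fixing the norm at about 1. As I understand the i-ResNet construction, it instead needs a coefficient c < 1 and rescales only when σ(W) > c, i.e. W ← W / max(1, σ(W)/c).
- For convolutions, the native version computes σ of the weight reshaped to a matrix. That is not the operator norm of the convolution itself, which Gouk et al. estimate with power iteration through the convolution.

Below is a minimal sketch of the coefficient-scaled variant for a dense weight only, using power iteration. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def power_iteration_sigma(W, n_iter=100, seed=0):
    # Estimate the largest singular value of W by alternating u <- Wv, v <- W^T u
    gen = torch.Generator().manual_seed(seed)
    u = torch.randn(W.shape[0], generator=gen, dtype=W.dtype)
    for _ in range(n_iter):
        v = F.normalize(W.T @ u, dim=0)
        u = F.normalize(W @ v, dim=0)
    return torch.dot(u, W @ v)


def soft_spectral_normalize(W, coeff=0.9, n_iter=100):
    # W <- W / max(1, sigma / coeff): W is left unchanged when sigma <= coeff
    sigma = power_iteration_sigma(W, n_iter)
    return W / torch.clamp(sigma / coeff, min=1.0)


# Build a weight with known singular values (3, 1, 0.5) to check the result
torch.manual_seed(0)
U, _ = torch.linalg.qr(torch.randn(5, 3, dtype=torch.float64))
V, _ = torch.linalg.qr(torch.randn(3, 3, dtype=torch.float64))
W = U @ torch.diag(torch.tensor([3.0, 1.0, 0.5], dtype=torch.float64)) @ V.T
W_sn = soft_spectral_normalize(W, coeff=0.9)
print(torch.linalg.matrix_norm(W_sn, ord=2))  # ≈ 0.9
```

In training code, the vector u is typically kept as a buffer and updated with one power-iteration step per forward pass rather than re-run from scratch.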
Hi, I saw your implementation of the invertible ResNet, but I am still confused about the classification part of the paper. Do they have two separate models for classification and generative modelling, or is the classification model just the forward part of the invertible model?
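One way to picture the relationship, as a hedged sketch and not necessarily the paper's exact architecture: a stack of invertible residual blocks acts as a trunk. For classification, a non-invertible head (here a single linear layer) maps the trunk's output to logits. For density estimation, the trunk's output is treated as the latent and scored under a prior. All class names below are hypothetical.

```python
import torch
import torch.nn as nn


class ResidualBranch(nn.Module):
    """Hypothetical contractive branch g: W is rescaled to spectral norm coeff < 1."""

    def __init__(self, dim, coeff=0.9):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.coeff = coeff

    def forward(self, x):
        W = self.lin.weight * self.coeff / torch.linalg.matrix_norm(self.lin.weight, ord=2)
        return torch.tanh(x @ W.T + self.lin.bias)


class InvertibleTrunk(nn.Module):
    def __init__(self, dim, n_blocks=3):
        super().__init__()
        self.branches = nn.ModuleList(ResidualBranch(dim) for _ in range(n_blocks))

    def forward(self, x):
        for g in self.branches:
            x = x + g(x)  # each block x -> x + g(x) is invertible
        return x


class Classifier(nn.Module):
    def __init__(self, dim, n_classes):
        super().__init__()
        self.trunk = InvertibleTrunk(dim)
        self.head = nn.Linear(dim, n_classes)  # non-invertible, used only for classification

    def forward(self, x):
        return self.head(self.trunk(x))


logits = Classifier(4, 10)(torch.randn(2, 4))
print(logits.shape)  # torch.Size([2, 10])
```

Under this picture, "the forward part of the invertible model" is the trunk. Whether one trunk is trained for both tasks or two separate models are trained is a training choice that the paper itself should confirm.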