
nlut's Introduction

NLUT: Neural-based 3D Lookup Tables for Video Photorealistic Style Transfer

Overview

NLUT (see our paper and project page) is a very fast photorealistic style transfer method for video. We build a neural network that generates a stylized 3D LUT, with the goal of fast photorealistic style transfer for video. Specifically, we train the LUT-generating network on a large dataset and then fine-tune it with test-time training to produce a stylized 3D LUT for a specific style image and video content. Although our method requires fine-tuning at use time, it is more effective than other methods and is very fast at video processing: for example, it can process 8K video in less than 2 milliseconds. In the future, we will explore ways to generate 3D LUTs for arbitrary styles even more quickly.
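A 3D LUT stores an output color at every node of a D×D×D lattice spanning the RGB cube. An input color is mapped by trilinear interpolation between the 8 lattice nodes surrounding it, so once the LUT exists, stylizing a frame costs only a few lookups per pixel. The NumPy sketch below illustrates that mapping. It assumes a LUT of shape (D, D, D, 3) indexed lut[r, g, b], with colors in [0, 1]. It is not the repo's CUDA trilinear kernel, and NLUT's own tensor layout may differ. The function names are hypothetical.

```python
import numpy as np

def identity_lut(d):
    """LUT of shape (d, d, d, 3) that maps every color to itself."""
    grid = np.linspace(0.0, 1.0, d)
    r, g, b = np.meshgrid(grid, grid, grid, indexing="ij")
    return np.stack([r, g, b], axis=-1)

def apply_lut_trilinear(image, lut):
    """Map an (H, W, 3) float image in [0, 1] through a (D, D, D, 3) LUT (D >= 2)."""
    d = lut.shape[0]
    # Continuous lattice coordinates of each pixel's color.
    pos = np.clip(image, 0.0, 1.0) * (d - 1)
    # Lower corner of the enclosing cell; cap at d - 2 so lo + 1 stays in range.
    lo = np.minimum(np.floor(pos).astype(int), d - 2)
    frac = pos - lo  # position inside the cell, in [0, 1]
    r0, g0, b0 = lo[..., 0], lo[..., 1], lo[..., 2]
    fr, fg, fb = frac[..., 0:1], frac[..., 1:2], frac[..., 2:3]
    out = np.zeros(image.shape[:-1] + (3,), dtype=float)
    # Weighted sum over the 8 corners of the cell.
    for dr in (0, 1):
        wr = fr if dr else 1.0 - fr
        for dg in (0, 1):
            wg = fg if dg else 1.0 - fg
            for db in (0, 1):
                wb = fb if db else 1.0 - fb
                out += wr * wg * wb * lut[r0 + dr, g0 + dg, b0 + db]
    return out
```

Because trilinear interpolation reproduces affine color maps exactly, an identity LUT returns the input unchanged. A stylized LUT simply stores different colors at the lattice nodes.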

Preparation

Environment

Please make sure the following environment is configured correctly. You can install the required packages with the following command:

pip install -r requirements.txt
  • matplotlib==3.5.1
  • numpy==1.22.4
  • opencv_python==4.5.5.62
  • Pillow==9.4.0
  • plotly==5.13.0
  • scipy==1.7.3
  • setuptools==58.0.4
  • torch==1.10.1
  • torchvision==0.11.2
  • tqdm==4.62.3

Fast application of the 3D LUT relies on the CUDA implementation of trilinear interpolation from Image-Adaptive-3DLUT. To install their trilinear library:

cd trilinear_cpp
sh setup.sh
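setup.sh compiles the extension with the system's CUDA toolkit (nvcc). PyTorch's extension builder then compares that toolkit against torch.version.cuda, the CUDA version PyTorch itself was built with. One of the issues below fails exactly this check: the toolkit is 11.8, but the installed PyTorch was built for 10.2. The sketch below is a quick pre-check you could run before setup.sh. It only reports the two versions and flags a major-version difference, which is the case the build rejected in that report. The helper names are hypothetical and not part of the repo. Fixing a mismatch means installing a PyTorch build and a CUDA toolkit whose versions agree.

```python
import re
import shutil
import subprocess

def parse_nvcc_version(nvcc_output):
    """Extract 'X.Y' from `nvcc --version` output, e.g. '... release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None

def same_major(version_a, version_b):
    """True if two 'X.Y' CUDA version strings share the same major version."""
    return version_a.split(".")[0] == version_b.split(".")[0]

def main():
    try:
        import torch  # assumes PyTorch is installed in the same environment
    except ImportError:
        print("PyTorch is not installed")
        return
    if shutil.which("nvcc") is None:
        print("nvcc not found on PATH; the CUDA extension cannot be built")
        return
    out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    toolkit = parse_nvcc_version(out)
    built_with = torch.version.cuda  # None for CPU-only PyTorch builds
    print(f"nvcc CUDA: {toolkit}; PyTorch built with CUDA: {built_with}")
    if toolkit is None or built_with is None:
        print("Could not compare versions")
    elif not same_major(toolkit, built_with):
        print("Major CUDA version mismatch: building the trilinear extension will likely fail")

if __name__ == "__main__":
    main()
```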

Data

Training dataset

You can download the training dataset through the link below.

Pre-trained checkpoint: Baidu Netdisk https://pan.baidu.com/s/1VddHbq2cBy5RcKOp8S5eSg (extraction code: 1234) or Google Drive https://drive.google.com/drive/folders/1YqCKnfqzOPtmwdYAziGZMQ79iAI0_0ur

Training

Appropriate hyper-parameters are already set as defaults; only the content and style directories (--content_dir and --style_dir) need to be specified before training.

Train with the following command:

python train.py --content_dir <path> --style_dir <path>

Test

Appropriate hyper-parameters are already set as defaults; only the content and style paths (--content_path and --style_path) need to be specified before testing.

Generate a stylized image

python inference_finetuning_image.py --content_path <path> --style_path <path> --output_path <path>

Generate a stylized video

python inference_finetuning_video.py --content_path <path> --style_path <path> --src_video <path> --dst_video <path>

License

This algorithm is licensed under the MIT License. See the LICENSE file for details.

nlut's People

Contributors

semchan


nlut's Issues

What causes NaN problems?

iter 3700 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5781 loss_s: 2.8997 loss_mse: 0.0280 losses: 3.4793
iter 3800 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5784 loss_s: 2.9005 loss_mse: 0.0281 losses: 3.4805
iter 3900 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5783 loss_s: 2.8970 loss_mse: 0.0281 losses: 3.4768
iter 4000 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5776 loss_s: 2.8929 loss_mse: 0.0281 losses: 3.4721
iter 4100 time/iter: 0.49 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5773 loss_s: 2.8916 loss_mse: 0.0281 losses: 3.4705
iter 4200 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5775 loss_s: 2.8877 loss_mse: 0.0281 losses: 3.4668
iter 4300 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5779 loss_s: 2.8855 loss_mse: 0.0281 losses: 3.4649
iter 4400 time/iter: 0.49 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5782 loss_s: 2.8828 loss_mse: 0.0282 losses: 3.4626
iter 4500 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4600 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4700 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4800 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4900 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 5000 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 5100 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
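In this log, every loss term becomes NaN at the same reporting step, between iterations 4400 and 4500, while the learning rate stays at 0.0001. The thread does not identify the cause. A small helper like the sketch below can locate the last logged iteration whose losses were all finite, which tells you which checkpoint to resume from before trying common remedies such as a lower learning rate or gradient clipping. These remedies are general practice, not fixes confirmed by the authors. The helper is hypothetical and only assumes the log format shown above. The log prints every 100 iterations, so the actual divergence happened somewhere between the two reported steps.

```python
import math
import re

def parse_log_line(line):
    """Return (iteration, {loss_name: value}) for an 'iter N ...' line, else None."""
    match = re.match(r"\s*iter\s+(\d+)", line)
    if match is None:
        return None
    # Picks up loss_mn, loss_c, loss_s, loss_mse and losses; 'nan' parses as float nan.
    values = {name: float(raw) for name, raw in re.findall(r"(loss\w*):\s*(\S+)", line)}
    return int(match.group(1)), values

def first_divergence(lines):
    """Return (first iteration with a non-finite loss, last finite iteration) or None."""
    last_good = None
    for line in lines:
        parsed = parse_log_line(line)
        if parsed is None:
            continue
        iteration, values = parsed
        if any(not math.isfinite(v) for v in values.values()):
            return iteration, last_good
        last_good = iteration
    return None
```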

Inference error

FileNotFoundError: [Errno 2] No such file or directory: './finetuning_train/checkpoint/9_finetuning_style_lut.pth'

Unable to produce any kind of stylized image

Hi folks,

I'm trying to produce any kind of stylized image with this repo and, unfortunately, have been unable to do so even with the default image data in the data/ folder. I'm issuing the below command:

python3 inference_finetuning_image.py --pretrained ./experiments/336999_style_lut.pth --resume ./experiments/336999_style_lut.pth 

The above command should operate on the default data specified in parameter_finetuning.py. It seems the generated output, however, is pretty much the same as the input content image.

Am I using the stylization script incorrectly, or is something wrong with the pre-trained checkpoint that was previously available (336999_style_lut.pth)? As shown below, the losses during fine-tuning also seem quite high, but I'm more surprised that absolutely no stylization seems to be applied to the output.

now device is cuda:0
n=2048 s=32 w=32 ++
Total params: 22.01M
--------loading checkpoint----------
=> loading checkpoint './experiments/336999_style_lut.pth'
/home/akshaypa/.local/lib/python3.8/site-packages/torch/nn/functional.py:1795: UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
  warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")
iter 0   time/iter: 0.05  lr: 0.000100 loss_mn: 0.0009 loss_c: 0.0000   loss_s: 2.2449 losses: 2.2458 
iter 10   time/iter: 0.56  lr: 0.000100 loss_mn: 0.0693 loss_c: 0.0000   loss_s: 2.2448 losses: 2.3141 
iter 20   time/iter: 0.56  lr: 0.000100 loss_mn: 0.0405 loss_c: 0.0000   loss_s: 2.2447 losses: 2.2852 
iter 30   time/iter: 0.56  lr: 0.000100 loss_mn: 0.0283 loss_c: 0.0000   loss_s: 2.2446 losses: 2.2729 
iter 39   time/iter: 0.43  lr: 0.000100 loss_mn: 0.0223 loss_c: 0.0000   loss_s: 2.2444 losses: 2.2667 
n=2048 s=32 w=32 ++
Total params: 22.01M
--------loading checkpoint----------
=> loading checkpoint './experiments/336999_style_lut.pth'
save to: data/city2.jpg

cube file

Any chance I can save a cube file from the LUT computed by the model?
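The .cube text format (the Adobe/Resolve 3D LUT format) is simple enough to write by hand. It consists of a LUT_3D_SIZE N header followed by N³ lines of "R G B" values, with the red index varying fastest, then green, then blue. The sketch below assumes the LUT has already been extracted from the model as an N×N×N grid indexed lut[r][g][b], giving an (R, G, B) triple in [0, 1]. The actual tensor layout NLUT produces is not documented here, so you may need a transpose first. lut_to_cube is a hypothetical helper, not part of the repo.

```python
def lut_to_cube(lut, title="NLUT"):
    """Serialize a 3D LUT to .cube text.

    Assumes lut[r][g][b] -> (R, G, B) with values in [0, 1]. Works with nested
    lists or a NumPy array of shape (N, N, N, 3).
    """
    size = len(lut)
    lines = [f'TITLE "{title}"', f"LUT_3D_SIZE {size}"]
    # .cube ordering: red changes fastest, then green, then blue.
    for b in range(size):
        for g in range(size):
            for r in range(size):
                red, green, blue = lut[r][g][b]
                lines.append(f"{red:.6f} {green:.6f} {blue:.6f}")
    return "\n".join(lines) + "\n"
```

Usage, assuming `lut` is such a grid (for example a NumPy array converted from the model's output): `open("style.cube", "w").write(lut_to_cube(lut))`.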

checkpoint file

Hello, thanks for your awesome work! I've downloaded the pre-trained checkpoint file, could you tell me how to use it?

Unexpected key(s) in state_dict

Loading the downloaded pre-trained model raises an error: Unexpected key(s) in state_dict: "blurer.op.1.weight", "SB1.conv1.conv2d.weight", "SB1.conv1.conv2d.bias", "SB1.conv1.bn.weight", "SB1.conv1.bn.bias", "SB1.conv1.bn.running_mean", "SB1.conv1.bn.running_var", "SB1.conv1.bn.num_batches_tracked", "SB1.conv2.conv2d.weight", "SB1.conv2.conv2d.bias", "SB1.conv2.bn.weight", "SB1.conv2.bn.bias", "SB1.conv2.bn.running_mean", "SB1.conv2.bn.running_var", "SB1.conv2.bn.num_batches_tracked". On inspection, the NLUTNet model definition indeed contains nothing for blurer or SB1. Could the author please check the code and the parameter file? Thanks!

inference

How can I just run inference with the model after fine-tuning? I don't want to fine-tune again for every test.

Re-upload pre-trained checkpoint

Hi, I can't download the pre-trained checkpoint from Baidu because it requires an account. Could you please re-upload this file to another host that doesn't require a login?
Thanks

problem with checkpoint

Hi, I have this problem:

now device is cuda:0
n=2048 s=32 w=32 ++
Total params: 22.01M
--------loading checkpoint----------
=> loading checkpoint '336999_style_lut.pth'
Traceback (most recent call last):
File "W:\NLUT\inference_finetuning_video.py", line 377, in <module>
finetuning_train(opt, original, example)
File "W:\NLUT\inference_finetuning_video.py", line 156, in finetuning_train
model.load_state_dict(checkpoint['state_dict'])
File "C:\ProgramData\Anaconda3\envs\NLUT-styles\lib\site-packages\torch\nn\modules\module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for NLUTNet:
Unexpected key(s) in state_dict: "blurer.op.1.weight", "SB1.conv1.conv2d.weight", "SB1.conv1.conv2d.bias", "SB1.conv1.bn.weight", "SB1.conv1.bn.bias", "SB1.conv1.bn.running_mean", "SB1.conv1.bn.running_var", "SB1.conv1.bn.num_batches_tracked", "SB1.conv2.conv2d.weight", "SB1.conv2.conv2d.bias", "SB1.conv2.bn.weight", "SB1.conv2.bn.bias", "SB1.conv2.bn.running_mean", "SB1.conv2.bn.running_var", "SB1.conv2.bn.num_batches_tracked".
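The error lists only extra parameters (blurer.* and SB1.*) and no missing ones, which suggests the checkpoint was saved from a model version with extra modules that the repo's NLUTNet no longer defines. One possible workaround, sketched below under that assumption, is to drop those entries before loading. drop_unexpected is a hypothetical helper, not part of the repo. This only makes loading succeed: it does not guarantee that the remaining weights reproduce the paper's results.

```python
from collections import OrderedDict

# Prefixes reported as unexpected in the error above; adjust if your error lists others.
UNEXPECTED_PREFIXES = ("blurer.", "SB1.")

def drop_unexpected(state_dict, prefixes=UNEXPECTED_PREFIXES):
    """Return a copy of state_dict without entries whose names start with any prefix."""
    return OrderedDict(
        (name, tensor)
        for name, tensor in state_dict.items()
        if not name.startswith(prefixes)
    )

# Usage (requires PyTorch and the repo's NLUTNet; not executed here):
#   checkpoint = torch.load("336999_style_lut.pth", map_location="cpu")
#   model.load_state_dict(drop_unexpected(checkpoint["state_dict"]))
# PyTorch's built-in alternative ignores unmatched keys and reports them:
#   result = model.load_state_dict(checkpoint["state_dict"], strict=False)
#   print(result.unexpected_keys)
```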

Problem building trilinear

When I try to build the trilinear module, I get the following error:

RuntimeError:
The detected CUDA version (11.8) mismatches the version that was used to compile
PyTorch (10.2). Please make sure to use the same CUDA versions.

I am running the following lines in a Colab notebook:

!git clone https://github.com/semchan/NLUT
!pip install -r /content/NLUT/requirements.txt
%cd /content/NLUT/trilinear_cpp
!sh setup.sh

Full trace:

Including CUDA code.
running install
running bdist_egg
running egg_info
writing trilinear.egg-info/PKG-INFO
writing dependency_links to trilinear.egg-info/dependency_links.txt
writing top-level names to trilinear.egg-info/top_level.txt
/usr/local/lib/python3.9/dist-packages/torch/utils/cpp_extension.py:381: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'trilinear.egg-info/SOURCES.txt'
writing manifest file 'trilinear.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
Traceback (most recent call last):
File "/content/NLUT/trilinear_cpp/setup.py", line 7, in <module>
setup(
File "/usr/local/lib/python3.9/dist-packages/setuptools/__init__.py", line 153, in setup
return distutils.core.setup(**attrs)
File "/usr/lib/python3.9/distutils/core.py", line 148, in setup
dist.run_commands()
File "/usr/lib/python3.9/distutils/dist.py", line 966, in run_commands
self.run_command(cmd)
File "/usr/lib/python3.9/distutils/dist.py", line 985, in run_command
cmd_obj.run()
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/install.py", line 67, in run
self.do_egg_install()
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/install.py", line 109, in do_egg_install
self.run_command('bdist_egg')
File "/usr/lib/python3.9/distutils/cmd.py", line 313, in run_command
self.distribution.run_command(command)
File "/usr/lib/python3.9/distutils/dist.py", line 985, in run_command
cmd_obj.run()
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/bdist_egg.py", line 164, in run
cmd = self.call_command('install_lib', warn_dir=0)
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/bdist_egg.py", line 150, in call_command
self.run_command(cmdname)
File "/usr/lib/python3.9/distutils/cmd.py", line 313, in run_command
self.distribution.run_command(command)
File "/usr/lib/python3.9/distutils/dist.py", line 985, in run_command
cmd_obj.run()
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/install_lib.py", line 11, in run
self.build()
File "/usr/lib/python3.9/distutils/command/install_lib.py", line 109, in build
self.run_command('build_ext')
File "/usr/lib/python3.9/distutils/cmd.py", line 313, in run_command
self.distribution.run_command(command)
File "/usr/lib/python3.9/distutils/dist.py", line 985, in run_command
cmd_obj.run()
File "/usr/local/lib/python3.9/dist-packages/setuptools/command/build_ext.py", line 79, in run
_build_ext.run(self)
File "/usr/local/lib/python3.9/dist-packages/Cython/Distutils/old_build_ext.py", line 186, in run
_build_ext.build_ext.run(self)
File "/usr/lib/python3.9/distutils/command/build_ext.py", line 340, in run
self.build_extensions()
File "/usr/local/lib/python3.9/dist-packages/torch/utils/cpp_extension.py", line 404, in build_extensions
self._check_cuda_version()
File "/usr/local/lib/python3.9/dist-packages/torch/utils/cpp_extension.py", line 781, in _check_cuda_version
raise RuntimeError(CUDA_MISMATCH_MESSAGE.format(cuda_str_version, torch.version.cuda))
RuntimeError:
The detected CUDA version (11.8) mismatches the version that was used to compile
PyTorch (10.2). Please make sure to use the same CUDA versions.

libtorch_cuda_cu.so

ImportError: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory

image or video data

Hello, could you please provide some test data of images or videos presented in your paper? Thank you very much!

Pre-trained checkpoint cannot be downloaded

Hello,
Thank you for sharing the code and weights. Unfortunately, the pre-trained checkpoint cannot be downloaded: the site requires registration and is not available to users outside China. Could you please share it on another website, such as Google Drive?

nan

Why does the NaN problem occur?
iter 3500 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0015 loss_c: 0.5787 loss_s: 2.9061 loss_mse: 0.0281 losses: 3.4864
iter 3600 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0015 loss_c: 0.5785 loss_s: 2.9014 loss_mse: 0.0281 losses: 3.4815
iter 3700 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5781 loss_s: 2.8997 loss_mse: 0.0280 losses: 3.4793
iter 3800 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5784 loss_s: 2.9005 loss_mse: 0.0281 losses: 3.4805
iter 3900 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5783 loss_s: 2.8970 loss_mse: 0.0281 losses: 3.4768
iter 4000 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5776 loss_s: 2.8929 loss_mse: 0.0281 losses: 3.4721
iter 4100 time/iter: 0.49 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5773 loss_s: 2.8916 loss_mse: 0.0281 losses: 3.4705
iter 4200 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5775 loss_s: 2.8877 loss_mse: 0.0281 losses: 3.4668
iter 4300 time/iter: 0.50 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5779 loss_s: 2.8855 loss_mse: 0.0281 losses: 3.4649
iter 4400 time/iter: 0.49 lr: 0.000100 loss_mn: 0.0016 loss_c: 0.5782 loss_s: 2.8828 loss_mse: 0.0282 losses: 3.4626
iter 4500 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4600 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4700 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4800 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 4900 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 5000 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
iter 5100 time/iter: 0.49 lr: 0.000100 loss_mn: nan loss_c: nan loss_s: nan loss_mse: nan losses: nan
