
[toc]

MAI-VSR-Diggers

Team "Diggers" winner solution to Mobile AI 2021 Real-Time Video Super-Resolution Challenge

official report paper: https://arxiv.org/abs/2105.08826

Pipeline

(Pipeline diagram: see the figure in the repository.)

Usage

Install

  • Linux machine (you do not need to worry about the CUDA version; the NVIDIA graphics driver version just needs to be greater than 418 — a quick installation check is sketched after this list)
  • Python 3.7 virtual environment
  • pip install megengine -f https://megengine.org.cn/whl/mge.html
  • pip install -r requirements.txt
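A minimal sanity check (an assumption on my part, not part of the original instructions) that MegEngine imported correctly and can see the GPU:

import megengine

print(megengine.__version__)          # installed MegEngine version
print(megengine.is_cuda_available())  # True if the driver/GPU is usable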

Dataset preparation (REDS)

  • link: https://seungjunnah.github.io/Datasets/reds.html
  • after unzipping, merge the training and validation datasets (as mmediting does), giving 270 (240 + 30) training clips in total; the remaining 30 clips are used for testing. A merge sketch is given after this list.
  • after merging, your directory should look like this:
    • train
      • train_sharp
        • 000
        • ...
        • 240 (the first validation clip, i.e. clip 000 of the validation set)
        • ...
        • 269
      • train_sharp_bicubic
        • X4
          • 000
          • ...
          • 269
    • test
      • test_sharp_bicubic
        • X4
          • 000
          • ...
          • 029
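A minimal sketch of the merge step (directory names follow the official REDS zips; reds_root and the relative paths are assumptions — adjust them to your layout):

import os
import shutil

reds_root = '/path2yours/REDS'

# (training dir, validation dir) pairs: sharp GT and X4 bicubic LR.
pairs = [
    ('train/train_sharp', 'val/val_sharp'),
    ('train/train_sharp_bicubic/X4', 'val/val_sharp_bicubic/X4'),
]

for train_rel, val_rel in pairs:
    train_dir = os.path.join(reds_root, train_rel)
    val_dir = os.path.join(reds_root, val_rel)
    for i in range(30):
        # Validation clip 000..029 becomes training clip 240..269.
        src = os.path.join(val_dir, f'{i:03d}')
        dst = os.path.join(train_dir, f'{i + 240:03d}')
        shutil.move(src, dst)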

Training

  • find the config file: configs/restorers/BasicVSR/mai.py
  • change the first few lines according to your situation

  • start to run:
cd xxx/MAI-VSR-Diggers
python tools/train.py configs/restorers/BasicVSR/mai.py --gpuids 0,1,2,3 -d

Multi-GPU training is supported; change the IDs to match your machine, e.g. --gpuids 0, --gpuids 0,2, etc.

you can find output information and checkpoints in ./workdirs/...

Testing (currently only the REDS dataset is supported)

Our checkpoint

Use our trained model (generator_module.mge), already included in this repo at ./ckpt/epoch_62; it is only 92 KB.

It was trained for 62 epochs on the 240 training clips; its PSNR on the validation dataset (3,000 frames) is 27.98.
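For reference, a minimal sketch of the standard PSNR computation behind such scores (the repository's own evaluation code may differ in details such as border cropping or the color space it measures in):

import numpy as np

def psnr(pred: np.ndarray, gt: np.ndarray, max_val: float = 255.0) -> float:
    # Peak signal-to-noise ratio between a restored frame and its ground truth.
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(max_val ** 2 / mse)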

Test on the validation dataset

find the config file: configs/restorers/BasicVSR/mai_test_valid.py

change the first lines for your setup; in practice you only need to fix the dataroot (eval_part selects the 30 validation clips, 240–269):

dataroot = "/path2yours/REDS/train/train_sharp_bicubic"
load_path = './ckpt/epoch_62'
exp_name = 'mai_test_for_validation'
eval_part = tuple(map(str, range(240, 270)))

and then run it:

cd xxx/MAI-VSR-Diggers
python  tools/test.py  configs/restorers/BasicVSR/mai_test_valid.py --gpuids 0 -d

you can find the results in ./workdirs/...

Test on the test dataset

find the config file: configs/restorers/BasicVSR/mai_test_test.py

change the first lines for your setup; again you only need to fix the dataroot:

dataroot = "/path2yours/REDS/test/test_sharp_bicubic"
load_path = './ckpt/epoch_62'
exp_name = 'mai_test_for_test'
eval_part = None

and then run it:

python  tools/test.py  configs/restorers/BasicVSR/mai_test_test.py --gpuids 0 -d

you can find the results in ./workdirs/...

Note: only a single-GPU setting is currently supported for --gpuids.

Results on the test set

  • all output frames of the test dataset produced by our model can be found here (3,000 frames, trained only on the 240 training clips):

https://drive.google.com/file/d/1R0DDHmV8jZW_iYJQZksO2RWkZTrAYYPi/view?usp=sharing

Get the TFLite model

Overall pipeline

  • train the model with the MegEngine framework (similar to PyTorch, TensorFlow, ...)
  • define the same model in TensorFlow (same size, same processing pipeline, ...)
  • load xxx.mge -> numpy.ndarray -> tf.keras.Model
  • convert the tf.keras.Model to TFLite using TensorFlow (a weight-transfer sketch follows this list)
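A minimal sketch of the xxx.mge -> numpy.ndarray -> tf.keras.Model step (the checkpoint layout and the kernel transpose here are assumptions; tflite/main.py implements the real conversion):

import megengine
import numpy as np

# Load the MegEngine checkpoint. Depending on how it was saved, this may be a
# state dict directly or a wrapper dict; inspect the keys on your machine.
state = megengine.load('./ckpt/epoch_62/generator_module.mge')
np_weights = {name: np.asarray(tensor) for name, tensor in state.items()}

# MegEngine stores conv kernels as OIHW, while tf.keras Conv2D expects HWIO,
# so each kernel needs a transpose before layer.set_weights([...]).
def oihw_to_hwio(kernel: np.ndarray) -> np.ndarray:
    return kernel.transpose(2, 3, 1, 0)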

One command to get a .tflite file

model.tflite

cd xxx/tflite/
python main.py  --mgepath  /xxxxxx/ckpt/epoch_62/generator_module.mge

model_none.tflite

cd xxx/tflite/
python main.py  --mgepath  /xxxxxx/ckpt/epoch_62/generator_module.mge  -n

Note: use absolute paths.

You will get the TFLite files in the directory xxx/tflite/xxx.

We have also provided our pre-built model.tflite and model_none.tflite in the ckpt directory.

Testing on custom data using TFLite

You can refer to #2.


Issues

Vs ESRGAN?

How does this model compare with ESRGAN if we used it to upscale every single frame, in terms of runtime vs. output quality?

VSR model

Hi, is this the same VSR model used in the AI Benchmark app?

How do you use the TFLite model for camera video frames?

Hello,

First, I would like to say that I appreciate your work.

Actually, I have been trying to use this model in a custom project. I obtained a single-frame result from model_none.tflite (it takes 10 frames as input, which I provided), but the result does not look good. Can you tell me why?

(attached output image)

I have used the following:

  1. Stacking all frames along the color channel, which gives (1, 180, 320, 300):

import glob

import numpy as np
import tensorflow as tf

all_frames = sorted(glob.glob('/content/drive/MyDrive/Mobile_communication/val_sharp_bicubic/X4/000/*.png'))

# Read the first frame just to get the spatial dimensions.
read_1 = tf.io.read_file(all_frames[0])
read_1 = tf.io.decode_png(read_1, channels=3)

# Dummy 3-channel slab so tf.concat has something to append to; sliced off below.
stacked = tf.Variable(np.empty((1, read_1.shape[0], read_1.shape[1], read_1.shape[2]), dtype=np.float32))

for test_img_path in all_frames:
    lr = tf.io.read_file(test_img_path)
    lr = tf.io.decode_png(lr, channels=3)
    lr = tf.expand_dims(lr, axis=0)
    lr = tf.cast(lr, tf.float32)
    stacked = tf.concat([stacked, lr], axis=-1)

stacked = stacked[:, :, :, 3:]  # drop the dummy slab
print(stacked.shape)  # (1, 180, 320, 300) for 100 frames
  2. Feeding 10 frames to your TFLite VSR model, which gives (720, 1280, 30):

frames_10 = stacked[:, :, :, :30]  # first 10 frames x 3 channels
vsr_model_path = './MAI-VSR-Diggers/ckpt/model_none.tflite'
#vsr_model_path = './MAI-VSR-Diggers/ckpt/model.tflite'

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=vsr_model_path)

# Get input and output tensor details.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details, '\n', output_details)

# Run the model.
interpreter.resize_tensor_input(input_details[0]['index'], [1, 180, 320, 30], strict=False)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], frames_10.numpy())
interpreter.invoke()

# Extract the output and post-process it.
output_data = interpreter.get_tensor(output_details[0]['index'])
vsr = tf.squeeze(output_data, axis=0)
print(vsr.shape)  # (720, 1280, 30)
  3. Displaying one SR frame:

from sklearn.preprocessing import minmax_scale
import matplotlib.pyplot as plt

frame_1 = stacked[:, :, :, :3]
lr = tf.cast(tf.squeeze(frame_1, axis=0), tf.uint8)
print(lr.shape)
plt.figure(figsize=(5, 6))
plt.title('LR')
plt.imshow(lr.numpy())

tensor = vsr[:, :, :3]
shape = tensor.shape
image_scaled = minmax_scale(tf.reshape(tensor, [-1]).numpy(), feature_range=(0, 255)).reshape(tuple(shape))
tensor = tensor / 255
print(tensor.shape)
plt.figure(figsize=(25, 15))
plt.subplot(1, 2, 1)
plt.title('VSR (x4)')
plt.imshow(tensor.numpy())

bicubic = tf.image.resize(lr, [720, 1280], tf.image.ResizeMethod.BICUBIC)
bicubic = tf.cast(bicubic, tf.uint8)
plt.subplot(1, 2, 2)
plt.title('Bicubic')
plt.imshow(bicubic.numpy())
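One thing worth checking (an assumption on my part, not a confirmed answer from the maintainers) is the input range: if the network was trained on frames scaled to [0, 1], feeding raw 0–255 floats will produce washed-out output. A minimal sketch of the invocation with explicit scaling, reusing frames_10 from the snippet above:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='./MAI-VSR-Diggers/ckpt/model_none.tflite')
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.resize_tensor_input(input_details[0]['index'], [1, 180, 320, 30], strict=False)
interpreter.allocate_tensors()

# frames_10 holds raw pixel values in [0, 255]; scale to [0, 1] first
# (assumption: the model expects normalized input).
lr_input = (frames_10 / 255.0).numpy().astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], lr_input)
interpreter.invoke()

# Scale back to 8-bit for display and clip out-of-range values.
sr = interpreter.get_tensor(output_details[0]['index'])
sr_uint8 = np.clip(np.squeeze(sr, axis=0) * 255.0, 0, 255).astype(np.uint8)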

Training error goes NaN

2022-08-10 22:49:58,911 - edit - INFO - training gpus num: 2
2022-08-10 22:49:58,912 - edit - INFO - init distributed process group 0 / 2
2022-08-10 22:49:58,915 - edit - INFO - init distributed process group 1 / 2
2022-08-10 23:16:08,691 - rank0_edit - INFO - SRManyToManyDataset dataset load ok, mode: train len:24000
2022-08-10 23:16:08,691 - rank0_edit - INFO - use repeatdataset, repeat times: 1
2022-08-10 23:16:08,694 - rank0_edit - INFO - model: BasicVSR_v5 's total parameter nums: 23371
2022-08-10 23:16:08,698 - rank0_edit - INFO - syncing the model's parameters...
2022-08-10 23:16:08,991 - rank0_edit - INFO - SRManyToManyDataset dataset load ok, mode: eval len:3000
2022-08-10 23:16:08,992 - rank0_edit - INFO - 1500 iters for one epoch, trained iters: 0, total iters: 600000
2022-08-10 23:16:08,992 - rank0_edit - INFO - Start running, work_dir: ./workdirs/mai_training/20220810_224958, workflow: train, max epochs : 400
2022-08-10 23:16:08,992 - rank0_edit - INFO - registered hooks: [<edit.core.hook.logger.text.TextLoggerHook object at 0x7f89e5a0a290>, <edit.core.hook.checkpoint.checkpoint.CheckpointHook object at 0x7f89e5a0a2d0>, <edit.core.hook.evaluation.eval_hooks.EvalIterHook object at 0x7f89e7f49750>]
2022-08-10 23:16:25,548 - rank0_edit - INFO - epoch: 0, losses: [0.00301], losses_ma: [0.00301], iter: 4
2022-08-10 23:16:34,997 - rank0_edit - INFO - epoch: 0, losses: [0.00328], losses_ma: [0.00314], iter: 9
2022-08-10 23:16:44,411 - rank0_edit - INFO - epoch: 0, losses: [0.00399], losses_ma: [0.00343], iter: 14
2022-08-10 23:16:52,555 - rank0_edit - INFO - epoch: 0, losses: [0.00506], losses_ma: [0.00383], iter: 19
2022-08-10 23:17:02,273 - rank0_edit - INFO - epoch: 0, losses: [0.00280], losses_ma: [0.00363], iter: 24
2022-08-10 23:17:10,873 - rank0_edit - INFO - epoch: 0, losses: [0.00386], losses_ma: [0.00367], iter: 29
2022-08-10 23:17:19,736 - rank0_edit - INFO - epoch: 0, losses: [0.00313], losses_ma: [0.00359], iter: 34
2022-08-10 23:17:28,804 - rank0_edit - INFO - epoch: 0, losses: [0.00358], losses_ma: [0.00359], iter: 39
2022-08-10 23:17:38,345 - rank0_edit - INFO - epoch: 0, losses: [0.00376], losses_ma: [0.00361], iter: 44
2022-08-10 23:17:46,827 - rank0_edit - INFO - epoch: 0, losses: [0.00311], losses_ma: [0.00356], iter: 49
2022-08-10 23:17:55,623 - rank0_edit - INFO - epoch: 0, losses: [0.00406], losses_ma: [0.00360], iter: 54
2022-08-10 23:18:04,857 - rank0_edit - INFO - epoch: 0, losses: [0.00326], losses_ma: [0.00357], iter: 59
2022-08-10 23:18:14,276 - rank0_edit - INFO - epoch: 0, losses: [0.00368], losses_ma: [0.00358], iter: 64
2022-08-10 23:18:23,079 - rank0_edit - INFO - epoch: 0, losses: [0.00392], losses_ma: [0.00361], iter: 69
2022-08-10 23:18:31,822 - rank0_edit - INFO - epoch: 0, losses: [0.00393], losses_ma: [0.00363], iter: 74
2022-08-10 23:18:40,337 - rank0_edit - INFO - epoch: 0, losses: [0.00429], losses_ma: [0.00367], iter: 79
2022-08-10 23:18:50,203 - rank0_edit - INFO - epoch: 0, losses: [0.00378], losses_ma: [0.00368], iter: 84
2022-08-10 23:18:58,560 - rank0_edit - INFO - epoch: 0, losses: [0.00336], losses_ma: [0.00366], iter: 89
2022-08-10 23:19:08,135 - rank0_edit - INFO - epoch: 0, losses: [0.00406], losses_ma: [0.00368], iter: 94
2022-08-10 23:19:16,653 - rank0_edit - INFO - epoch: 0, losses: [0.00396], losses_ma: [0.00369], iter: 99
2022-08-10 23:19:26,492 - rank0_edit - INFO - epoch: 0, losses: [0.00315], losses_ma: [0.00367], iter: 104
2022-08-10 23:19:34,738 - rank0_edit - INFO - epoch: 0, losses: [0.00389], losses_ma: [0.00368], iter: 109
2022-08-10 23:19:43,368 - rank0_edit - INFO - epoch: 0, losses: [0.00297], losses_ma: [0.00365], iter: 114
2022-08-10 23:19:52,057 - rank0_edit - INFO - epoch: 0, losses: [0.00436], losses_ma: [0.00368], iter: 119
2022-08-10 23:20:02,484 - rank0_edit - INFO - epoch: 0, losses: [0.00338], losses_ma: [0.00367], iter: 124
2022-08-10 23:20:11,059 - rank0_edit - INFO - epoch: 0, losses: [0.00386], losses_ma: [0.00367], iter: 129
2022-08-10 23:20:20,270 - rank0_edit - INFO - epoch: 0, losses: [651753408617431171072.00000], losses_ma: [24139015133978931200.00000], iter: 134
2022-08-10 23:20:28,060 - rank0_edit - INFO - epoch: 0, losses: [nan], losses_ma: [nan], iter: 139

The losses seem to start low and then go to NaN.
For training I used:
python tools/train.py configs/restorers/BasicVSR/mai.py --gpuids 0,1 -d
