I tried to write the test code for this model.But I couldn't get the good result.Here is the code,I use the kitti set for this test.
``from collections import namedtuple
from skimage import io
import tensorflow as tf
import sys
import os
import argparse
import time
import datetime
from utils import *
from trinet import *
from monodepth_dataloader import *
import scipy.misc
import matplotlib.pyplot as plt
parameters = namedtuple('parameters',
'encoder, '
'height, width, '
'batch_size, '
'num_threads, '
'num_epochs '
)
-# forces tensorflow to run on CPU
-# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
parser = argparse.ArgumentParser(description='Argument parser')
""" Arguments related to network architecture"""
parser.add_argument('--width', dest='width', type=int, default=512, help='width of input images')
parser.add_argument('--height', dest='height', type=int, default=256, help='height of input images')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', type=str, default='checkpoint/3DV18/3net',
help='checkpoint directory')
parser.add_argument('--mode', dest='mode', type=int, default=0,
help='Select the demo mode [0: depth-from-mono, 1:view synthesis, 2:stereo]')
parser.add_argument('--filenames_file', type=str, default="new_test_files.txt",
help='path to the filenames text file')
parser.add_argument('--dataset', type=str, default='kitti')
parser.add_argument('--data_path', type=str, default="/media/liuzhu/000450A40005756C/data/")
Norm. factors for visualization
DEPTH_FACTOR = 10
DISP_FACTOR = 6
args = parser.parse_args()
def count_text_lines(file_path):
f = open(file_path, 'r')
lines = f.readlines()
f.close()
return len(lines)
def test(params):
"""Test function."""
height = params.height
width = params.width
placeholders = {'im0': tf.placeholder(tf.float32, [None, None, None, 3], name='im0')}
model = trinet(placeholders, net='resnet50')
loader = tf.train.Saver()
saver = tf.train.Saver()
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=config)
loader.restore(sess, args.checkpoint_dir)
-# SAVER
train_saver = tf.train.Saver()
-# INIT
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
coordinator = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)
num_test_samples = count_text_lines(args.filenames_file)
print('now testing {} files'.format(num_test_samples))
print('now testing {} files'.format(num_test_samples))
disparities = np.zeros((num_test_samples, params.height, params.width), dtype=np.float32)
with open(args.filenames_file,"r") as files:
str1 =files.readline()
i=0
while str1:
imagepath=args.data_path+str1.split(" ")[0]
image=io.imread(imagepath)
image = cv2.resize(image, (width, height)).astype(np.float32) / 255.
img_batch = np.expand_dims(image, 0)
disp_cr, disp_cl, synt_left, synt_right = sess.run(
[model.disparity_cr, model.disparity_cl, model.warp_left, model.warp_right],
feed_dict={placeholders['im0']: img_batch})
disp = build_disparity(disp_cr, disp_cl)
image = (image * 255).astype(np.uint8)
synt_left = (synt_left * 255).astype(np.uint8)
synt_right = (synt_right * 255).astype(np.uint8)
disp_color = (applyColorMap(disp/DISP_FACTOR, 'plasma')*255).astype(np.uint8)
toShow_C = np.concatenate((image, disp_color), 1)
plt.imsave(os.path.join("imgs/{}_disp.png".format("DIS_"+str(i))), disp/DISP_FACTOR, cmap='plasma')
disparities[i]=(disp/DISP_FACTOR).squeeze()
print(disp/DISP_FACTOR)
i+=1
print(str1)
str1 =files.readline()
print('done.')
print('writing disparities.')
np.save('disparities.npy', disparities)
def main(_):
params = parameters(
encoder="resnet",
height=256,
width=512,
batch_size=8,
num_threads=8,
num_epochs=50)
test(params)
if name == 'main':
tf.app.run()
``
@mattpoggi