Thanks for your code! I found that the estimated keypoints do not correspond to the ground truth when I visualize them directly from `preds`. What should I do to recover the estimated 3D coordinates back onto the 2D image? By the way, I'm a beginner, so please don't hesitate to teach me!
# Inference loop over the test set: runs the model, accumulates predictions,
# and saves a per-batch overlay of predicted vs. ground-truth 2D joints.
with torch.no_grad():
    for itr, (inputs, targets, meta_info) in enumerate(tqdm(tester.batch_generator, ncols=150)):
        # Forward pass (timed).
        start = time.time()
        out = tester.model(inputs, targets, meta_info, 'test')
        end = time.time()

        joint_coord_out = out['joint_coord'].cpu().numpy()  # predicted joints, heatmap space
        inv_trans = out['inv_trans'].cpu().numpy()  # affine transform back to the original image
        joint_valid = out['joint_valid'].cpu().numpy()  # fixed typo: was `joint_vaild`

        preds['joint_coord'].append(joint_coord_out)
        preds['inv_trans'].append(inv_trans)
        preds['joint_valid'].append(joint_valid)
        timer.append(end - start)

        # --- visualization ---------------------------------------------------
        # The network predicts joint x/y in *heatmap* space (cfg.output_hm_shape),
        # while inputs['img'] is the cropped/resized network input
        # (cfg.input_img_shape). To overlay predictions on inputs['img'], the
        # x/y coordinates must be rescaled from heatmap to input-image
        # resolution — plotting them raw is why they did not line up.
        # (To map back to the ORIGINAL image instead, use trans_point2d with
        # inv_trans as in the commented example below.)
        vis_coord = joint_coord_out[0].copy()
        vis_coord[:, 0] = vis_coord[:, 0] / cfg.output_hm_shape[2] * cfg.input_img_shape[1]
        vis_coord[:, 1] = vis_coord[:, 1] / cfg.output_hm_shape[1] * cfg.input_img_shape[0]

        # NOTE(review): assumes targets['joint_coord'] is stored in heatmap
        # space as well — confirm against the dataloader.
        gt = targets['joint_coord'][0]
        gt_x = gt[:21, 0] / cfg.output_hm_shape[2] * cfg.input_img_shape[1]
        gt_y = gt[:21, 1] / cfg.output_hm_shape[1] * cfg.input_img_shape[0]

        # Recovering full camera-space 3D coordinates (original-image pixels
        # plus metric depth) would instead look like:
        # focal = meta_info['focal'][0]
        # princpt = meta_info['princpt'][0]
        # for j in range(42):
        #     joint_coord_out[0][j, :2] = trans_point2d(joint_coord_out[0][j, :2], inv_trans[0])
        # joint_coord_out[0][:, 2] = (joint_coord_out[0][:, 2] / cfg.output_hm_shape[0] * 2 - 1) * (cfg.bbox_3d_size / 2)
        # joint_coord_out[0][:21, 2] += float(targets['rel_root_depth'][0])
        # joint_coord_out[0] = pixel2cam(joint_coord_out[0], focal, princpt)

        plt.imshow(inputs['img'][0].permute(1, 2, 0))
        plt.scatter(vis_coord[:21, 0], vis_coord[:21, 1])
        plt.scatter(gt_x, gt_y)
        plt.savefig('./visualization/result' + str(itr) + '.png')
        plt.close()