
elliottwu commented on May 21, 2024

Hi,
I have not cleaned the code for the keypoint depth evaluation yet, but I can share the snippets for cropping and evaluation.

Also, the currently released CelebA model crops too tightly, so some keypoints land outside the image. For the evaluation, we used another model trained on slightly larger faces (crops 8/7 the height and width of the provided crops), so you may need to redo the cropping using the provided bounding boxes. I hope this helps for now.
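The re-cropping mentioned above could look roughly like the sketch below. It assumes each bounding box is given as `(x1, y1, x2, y2)` in pixels (the format of the provided box files is an assumption), and the helper names `enlarge_box` and `crop_with_replicate` are hypothetical. The box is scaled by 8/7 about its centre, and edge pixels are replicated (NumPy's `mode='edge'`, the counterpart of `cv2.BORDER_REPLICATE` used in `crop_3DFAW.py`) wherever the enlarged box leaves the image.

```python
import numpy as np


def enlarge_box(box, factor=8 / 7):
    """Hypothetical helper: scale an (x1, y1, x2, y2) box about its centre by `factor`."""
    x1, y1, x2, y2 = box
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    half_w = (x2 - x1) * factor / 2
    half_h = (y2 - y1) * factor / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h


def crop_with_replicate(im, box):
    """Hypothetical helper: crop `box` from an HxW(xC) image, replicating edge pixels
    wherever the box extends past the image border."""
    x1, y1, x2, y2 = (int(round(v)) for v in box)
    h, w = im.shape[:2]
    # Pad symmetrically, just enough to contain the box on every side.
    pad = max(0, -x1, -y1, x2 - w, y2 - h)
    pad_width = [(pad, pad), (pad, pad)] + [(0, 0)] * (im.ndim - 2)
    padded = np.pad(im, pad_width, mode='edge')
    # Shift the box into padded-image coordinates before slicing.
    return padded[y1 + pad:y2 + pad, x1 + pad:x2 + pad]


# Usage: a 70x70 box near the bottom-right corner of a 100x100 image.
im = np.arange(100 * 100, dtype=np.float32).reshape(100, 100)
box = enlarge_box((40, 40, 110, 110))  # box == (35.0, 35.0, 115.0, 115.0)
crop = crop_with_replicate(im, box)    # crop.shape == (80, 80)
```

Keypoints then need the same shift and scale as the image, following the pattern already used at the end of `crop_3DFAW.py`.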

crop_3DFAW.py

import os
from glob import glob
import numpy as np
import shutil
import cv2


# NOTE: all paths below are specific to the author's machine; point them at your own 3DFAW copy.
split = np.loadtxt('/users/szwu/DepthNets/depthnet-pytorch/data/3DFAW/list_valid_test.txt', dtype=str, delimiter=',')

rootin_im = '/users/szwu/DepthNets/depthnet-pytorch/data/3DFAW/valid_img/'
rootin_kpt = '/users/szwu/DepthNets/depthnet-pytorch/data/3DFAW/valid_lm/'
rootout_ori_im = '/users/szwu/DepthNets/depthnet-pytorch/data/test_depthcorr/ori/img/'
rootout_ori_kpt = '/users/szwu/DepthNets/depthnet-pytorch/data/test_depthcorr/ori/kpts/'
rootout_crop_im = '/users/szwu/DepthNets/depthnet-pytorch/data/test_depthcorr/cropped/img/'
rootout_crop_kpt = '/users/szwu/DepthNets/depthnet-pytorch/data/test_depthcorr/cropped/kpts/'

oriens = ['left','center','right']
for ori in oriens:
    for rootout in [rootout_ori_im, rootout_ori_kpt, rootout_crop_im, rootout_crop_kpt]:
        dirout = os.path.join(rootout, ori)
        if not os.path.isdir(dirout):
            os.makedirs(dirout)

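# Each row of the split file: (file id, split, orientation); only the orientations listed above are kept.
for fid, sp, ori in split: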
    if ori in oriens:
        im_fname = fid + '.jpg'
        im_fpath = os.path.join(rootin_im, im_fname)
        shutil.copyfile(im_fpath, os.path.join(rootout_ori_im, ori, im_fname))

        kpt_fname = fid + '_lm.csv'
        kpt_fpath = os.path.join(rootin_kpt, kpt_fname)
        shutil.copyfile(kpt_fpath, os.path.join(rootout_ori_kpt, ori, kpt_fname))

        kpt = np.loadtxt(kpt_fpath, delimiter=',')

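        im = cv2.imread(im_fpath)
        # a2: mean of the width and height of the keypoints' bounding box, in pixels.
        # Pad the image by a2 on every side (edge replication) so the crop below can extend past the original border.
        a2 = int((kpt.max(0)-kpt.min(0))[:2].mean()*1)
        im = cv2.copyMakeBorder(im,a2,a2,a2,a2,cv2.BORDER_REPLICATE)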

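        # Eye centres (landmarks 36-41 and 42-47), mouth corners (48 and 54) and nose (30).
        leye = kpt[36:42,:2].mean(0)
        reye = kpt[42:48,:2].mean(0)
        lmouth = kpt[48,:2]
        rmouth = kpt[54,:2]
        nose = kpt[30,:2]

        # Face centre: reflect the nose through the eye/mouth centroid c1 (c2 = 2*c1 - nose),
        # then shift by half the vector from c1 to the midpoint between the eyes.
        c1 = (leye + reye + lmouth + rmouth)/4
        c2 = nose - 2*(nose - c1)
        xc, yc = (c2 + 0.5*((leye + reye)/2 - c1)).astype(int)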

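        # Square crop of side 2*a2, centred on xc horizontally and shifted down by about 0.1*a2;
        # the trailing "+ a2" converts original-image coordinates into padded-image coordinates.
        x1 = xc-a2 + a2
        x2 = x1+a2*2
        y1 = yc-a2 +int(0.1*a2) + a2
        y2 = y1+a2*2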

        crop = im[y1:y2, x1:x2, :]
        crop = cv2.resize(crop, (256,256))
        cv2.imwrite(os.path.join(rootout_crop_im, ori, im_fname), crop)

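        # Keypoints into the 256x256 crop: "+a2" moves them into padded coordinates, subtracting
        # (x1, y1) makes them crop-relative, and s rescales; depth (z) is scaled by the same s.
        s = 256/(a2*2)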
        kpt_xy = kpt[:,:2] - np.array([[x1, y1]]) +a2
        kpt_xy = kpt_xy *s
        kpt_z = kpt[:,2:] *s
        np.savetxt(os.path.join(rootout_crop_kpt, ori, kpt_fname), np.concatenate([kpt_xy, kpt_z], 1), delimiter=',')
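The keypoint bookkeeping in `crop_3DFAW.py` is an affine map: add the border width `a2`, subtract the crop corner `(x1, y1)` (already in padded coordinates), then scale by `s = 256 / (2 * a2)`, with depth scaled by the same `s`. The sketch below restates that map and its inverse, which can help check that keypoint locations survive the crop. The function names `to_crop` and `from_crop` are hypothetical, not from the repository.

```python
import numpy as np


def to_crop(kpt, x1, y1, a2, out_size=256):
    """Map (K, 3) keypoints from original-image pixels into the out_size crop,
    mirroring crop_3DFAW.py: (x, y) - (x1, y1) + a2, then x, y and z scaled by s."""
    s = out_size / (2 * a2)
    xy = (kpt[:, :2] - np.array([[x1, y1]]) + a2) * s
    z = kpt[:, 2:] * s
    return np.concatenate([xy, z], axis=1)


def from_crop(kpt_crop, x1, y1, a2, out_size=256):
    """Inverse of to_crop: recover original-image keypoints from crop keypoints."""
    s = out_size / (2 * a2)
    xy = kpt_crop[:, :2] / s + np.array([[x1, y1]]) - a2
    z = kpt_crop[:, 2:] / s
    return np.concatenate([xy, z], axis=1)


# Usage with made-up numbers: (x1, y1) are padded-image coordinates, as in the script,
# so the first keypoint sits exactly at the crop's top-left corner.
a2, x1, y1 = 100, 150, 160
kpt = np.array([[50.0, 60.0, 10.0], [150.0, 150.0, -5.0]])
kpt_crop = to_crop(kpt, x1, y1, a2)
recovered = from_crop(kpt_crop, x1, y1, a2)
```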

eval_kpt_depth.py

import cv2
import numpy as np
import torch
import os
from glob import glob


# with border (reported in paper)
result_dir = '/scratch/shared/nfs1/szwu/all_results/finetuned_celeba_test_depthcorr_080_db0.7_crop1.4/'
# result_dir = '/scratch/shared/nfs1/szwu/all_results/3dfaw_test_depthcorr_100_db0.7_crop1.4/'
source_dir = '/scratch/shared/nfs1/szwu/data/test_depthcorr'

kpts_fpaths = sorted(glob(os.path.join(source_dir, 'cropped/kpts/center/*.csv')))
kpts = np.array([np.loadtxt(kpts_fpath, delimiter=',') for kpts_fpath in kpts_fpaths])

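# Map keypoints from the 256x256 crop into the 64x64 frame of the network output.
# The arithmetic assumes the network saw the central 1/crop part of each crop, resized to 64x64:
# scale by 1/s, remove the margin p on each side, and subtract 4 px from y (offset kept from the original snippet).
crop = 1.4
kpts_xy = kpts[:,:,:2]
s = 256/(64*crop)
kpts_xy = kpts_xy / s
p = 64*(crop-1)/2
kpts_xy = kpts_xy - p
kpts_xy[:,:,1] -= 4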

ori_kpts_fpaths = sorted(glob(os.path.join(source_dir, 'ori/kpts/center/*.csv')))
ori_kpts = np.array([np.loadtxt(kpts_fpath, delimiter=',') for kpts_fpath in ori_kpts_fpaths])

im_fpaths = sorted(glob(os.path.join(result_dir, 'input/*.png')))[:75]
depth_fpaths = sorted(glob(os.path.join(result_dir, 'depth/*.png')))[:75]
depths = np.array([cv2.imread(depth_fpath, -1)/65535. for depth_fpath in depth_fpaths])
ori_im_fpaths = sorted(glob(os.path.join(source_dir, 'ori/img/center/*.jpg')))
# Keep the images in a list: they differ in size, and np.array() on ragged arrays fails in recent NumPy.
ori_ims = [cv2.imread(im_fpath)[:,:,::-1] for im_fpath in ori_im_fpaths]


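def compute_covar(preds, actuals, n_kps=66):
    # Despite the name, this is a correlation score: for each keypoint, the absolute Pearson
    # correlation between predicted and ground-truth depth across images, summed over keypoints.
    return np.sum(np.diag(np.abs(np.corrcoef(preds, actuals, rowvar=0)[:n_kps,n_kps:])))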

avg_szs = np.array([sum(im.shape[:2])/2 for im in ori_ims])
gt_z_rescaled = ori_kpts[:,:,2] / avg_szs[:,None]

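depths_t = torch.FloatTensor(depths).unsqueeze(1)                # (N, 1, H, W) depth maps
kpts_xy_t = torch.FloatTensor(kpts_xy).unsqueeze(1)/32-1         # pixels in [0, 64] -> [-1, 1], shape (N, 1, K, 2)

# Pad by 32 px (edge replication) so keypoints slightly outside the 64x64 map can still be sampled;
# in the padded 128x128 frame the normalized coordinates are halved.
# Recent PyTorch versions warn unless align_corners is passed explicitly to grid_sample.
depths_t = torch.nn.functional.pad(depths_t, (32,32,32,32), mode='replicate')
z_sampled = torch.nn.functional.grid_sample(depths_t, kpts_xy_t/2)[:,0,0,:]
z_sampled = z_sampled - z_sampled.mean(1,keepdim=True)           # remove the per-image depth offset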

score = compute_covar(z_sampled.numpy(), gt_z_rescaled)
print(score)
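`compute_covar` does not return a covariance: with `rowvar=0`, `np.corrcoef` treats each of the 66 keypoints as a variable observed over the N test images, and the diagonal of the off-diagonal block holds the Pearson correlation between predicted and ground-truth depth for each keypoint. The score is the sum of the absolute values, so it lies between 0 and 66. The self-contained sketch below uses made-up random data (not 3DFAW) to compare that formula with an explicit per-keypoint loop; `depth_corr_loop` is a hypothetical helper for illustration.

```python
import numpy as np


def compute_covar(preds, actuals, n_kps=66):
    # Same formula as in eval_kpt_depth.py.
    return np.sum(np.diag(np.abs(np.corrcoef(preds, actuals, rowvar=0)[:n_kps, n_kps:])))


def depth_corr_loop(preds, actuals):
    """Explicit equivalent: one Pearson correlation per keypoint column, absolute values summed."""
    total = 0.0
    for k in range(preds.shape[1]):
        r = np.corrcoef(preds[:, k], actuals[:, k])[0, 1]
        total += abs(r)
    return total


rng = np.random.default_rng(0)
n_images, n_kps = 75, 66
gt = rng.normal(size=(n_images, n_kps))                   # synthetic "ground-truth" depths
noisy = gt + rng.normal(scale=1.0, size=gt.shape)         # synthetic noisy predictions

perfect_score = compute_covar(3.0 * gt + 1.0, gt)         # every keypoint has r = 1
flipped_score = compute_covar(-gt, gt)                    # every keypoint has r = -1
noisy_score = compute_covar(noisy, gt)
loop_score = depth_corr_loop(noisy, gt)
```

Because of the absolute value, an inverted depth map (`-gt`) scores the same as a correct one, so the metric does not penalize a flipped depth sign.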

from unsup3d.

dafuny commented on May 21, 2024

Hi Wu,
Thanks for releasing your impressive work. I am trying to reproduce the results of Table 5 and have obtained the 3DFAW dataset. I would like to know how you processed the data and ran the evaluation. I guess you followed the DepthNets repo https://github.com/joelmoniz/DepthNets/tree/master/depthnet-pytorch, but it is still unclear to me how to crop the images, preserve the keypoint locations, and calculate the metric. Would it be possible to release this part of the code or share more details?

Hi @JesseZhang92, I'm also very interested in the 3DFAW evaluation experiment, but following Wu's instructions I got a bad result, and I wonder whether the crop size or some other step is wrong. If you have reproduced the result, could you share your code details with me? I would really appreciate a reply.


Heng14 commented on May 21, 2024

Hi, I am also working on the evaluation of the 3DFAW dataset. Did you guys figure it out?
Thank you so much!
