liguohao96 commented on June 14, 2024

@alchemi5t The PyTorch output is a UV position map, a special image that stores XYZ coordinates in its pixels instead of RGB values. The PyTorch version should produce the same result as the TF version, so you can use the official PRNet utils after converting the PyTorch tensor to a NumPy array.

BTW, don't forget to multiply the output by 256 * 1.1.
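A minimal NumPy sketch of that conversion, assuming the network output has already been moved to NumPy (for example with `.cpu().detach().numpy()`) and has shape (1, 3, 256, 256). `output_to_position_map`, `RESOLUTION` and `POS_SCALE` are illustrative names, not part of pytorch-prnet.

```python
import numpy as np

RESOLUTION = 256  # UV map size (resolution_op in the official PRNet code)
POS_SCALE = 1.1   # extra factor the thread says to apply to the raw output


def output_to_position_map(net_out: np.ndarray) -> np.ndarray:
    """Convert a (1, 3, H, W) network output into an (H, W, 3) position map.

    Each pixel of the result holds an (x, y, z) coordinate instead of an RGB colour.
    """
    if net_out.ndim != 4 or net_out.shape[:2] != (1, 3):
        raise ValueError(f"expected shape (1, 3, H, W), got {net_out.shape}")
    pos = np.transpose(net_out[0], (1, 2, 0))  # CHW -> HWC
    return pos * RESOLUTION * POS_SCALE


if __name__ == "__main__":
    fake_out = np.full((1, 3, RESOLUTION, RESOLUTION), 0.5, dtype=np.float32)
    pos = output_to_position_map(fake_out)
    print(pos.shape, pos[0, 0])
```

Each pixel `pos[row, col]` of the result is then an (x, y, z) point in the coordinate frame of the 256x256 crop.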

from pytorch-prnet.

LucienXian commented on June 14, 2024

@alchemi5t Yes, with the official PRNet utils. With the PyTorch version, the output is the same NumPy array.
You can try this:
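import numpy as np
import torch

torch_input = torch.from_numpy(img_bchw)  # img_bchw: (1, 3, 256, 256) array
torch_out = torch_model(torch_input).cpu().detach().numpy()
torch_out = np.transpose(torch_out, (0, 2, 3, 1)).squeeze()  # NCHW -> (256, 256, 3)
cropped_pos = torch_out * resolution_op * 1.1  # resolution_op is 256 in PRNet
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T  # (3, 256 * 256)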

liguohao96 commented on June 14, 2024

TL;DR: you can change the code in test/eval.py.

test/eval.py is designed to compare the outputs of TF and PyTorch on a random input. It just creates a random input and runs inference in both TF and PyTorch to see whether they produce the same output.

Replacing the random input with a real image gives you a fully functioning pytorch-prnet.

Please let me know if you need more detailed help.
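One way to sketch both halves of that idea in NumPy: turning a real image into the batch the PyTorch model expects, and measuring how far the two frameworks' outputs differ. This assumes the image is already RGB `uint8` and that the model takes values in [0, 1], as the official PRNet demo does; it also assumes the TF output is NHWC and the PyTorch output NCHW. Both helper names are hypothetical.

```python
import numpy as np


def image_to_bchw(image_hwc_uint8: np.ndarray) -> np.ndarray:
    """(H, W, 3) uint8 RGB image -> contiguous (1, 3, H, W) float32 batch in [0, 1]."""
    img = image_hwc_uint8.astype(np.float32) / 255.0
    # ascontiguousarray so torch.from_numpy() receives a plain C-ordered buffer
    return np.ascontiguousarray(np.transpose(img[np.newaxis], (0, 3, 1, 2)))


def max_abs_diff(tf_out_nhwc: np.ndarray, torch_out_nchw: np.ndarray) -> float:
    """Largest element-wise gap between a TF (NHWC) and a PyTorch (NCHW) output."""
    torch_nhwc = np.transpose(torch_out_nchw, (0, 2, 3, 1))
    if tf_out_nhwc.shape != torch_nhwc.shape:
        raise ValueError(f"shape mismatch: {tf_out_nhwc.shape} vs {torch_nhwc.shape}")
    return float(np.max(np.abs(tf_out_nhwc - torch_nhwc)))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
    batch = image_to_bchw(image)
    # Stand-ins for real outputs: identical data in the two layouts
    tf_like = np.transpose(batch, (0, 2, 3, 1))
    print(batch.shape, max_abs_diff(tf_like, batch))
```

With real models, `batch` would be fed to the PyTorch model through `torch.from_numpy`, and the same image (in NHWC) to the TF graph; a difference close to zero suggests the weight conversion is faithful.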

LucienXian commented on June 14, 2024

@liguohao96 Thank you for your quick reply! I have used your code to run inference and followed https://github.com/YadiraF/PRNet for pre-processing the image, but I still can't reconstruct a valid 3D object or get a valid .obj file like the original repo does. Have you tried reconstructing a face from the results of test/eval.py?

LucienXian commented on June 14, 2024

I have solved the problem. Thanks a lot!

alchemi5t commented on June 14, 2024

@LucienXian How did you write the .obj file out? From what I figured out, the PyTorch output is the UV coords. Did you use the official PRNet utils to write the .obj file?

@liguohao96 Any help is appreciated.

alchemi5t commented on June 14, 2024

@liguohao96 The TF PRNet utils require vertices, colors, triangles and textures along with uv_coords. In the original TF project, these come from the PRNet object. How would I get them in your PyTorch version?

FYI, all I have managed to get is torch_watched_out (the UV coords). I am trying to get textured models out.
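A sketch of the extra pieces, modeled from memory on the helpers the official PRNet `api.py` defines on its `PRN` object. The index data there is loaded from text files under `Data/uv-data/` (such as `face_ind.txt` and `triangles.txt`); treat those paths and the details below as assumptions to verify against https://github.com/YadiraF/PRNet. All function names here are illustrative.

```python
import numpy as np


def load_uv_data(face_ind_path: str, triangles_path: str):
    """Load face-region pixel indices and mesh triangles from whitespace-separated text."""
    face_ind = np.loadtxt(face_ind_path).astype(np.int32)
    triangles = np.loadtxt(triangles_path).astype(np.int32)
    return face_ind, triangles


def get_vertices(pos: np.ndarray, face_ind: np.ndarray) -> np.ndarray:
    """(H, W, 3) position map -> (N, 3) vertices, keeping only face-region pixels."""
    all_vertices = np.reshape(pos, [-1, 3])
    return all_vertices[face_ind, :]


def get_colors(image: np.ndarray, vertices: np.ndarray) -> np.ndarray:
    """Sample one colour per vertex from an (H, W, 3) image; returns (N, 3)."""
    h, w = image.shape[:2]
    xy = vertices[:, :2].copy()              # don't modify the caller's array
    xy[:, 0] = np.clip(xy[:, 0], 0, w - 1)   # x indexes columns
    xy[:, 1] = np.clip(xy[:, 1], 0, h - 1)   # y indexes rows
    ind = np.round(xy).astype(np.int32)
    return image[ind[:, 1], ind[:, 0], :]


def generate_uv_coords(face_ind: np.ndarray, resolution: int = 256) -> np.ndarray:
    """Per-vertex (u, v, 0) coordinates, in UV-map pixels, for the face-region pixels."""
    u, v = np.meshgrid(np.arange(resolution), np.arange(resolution))  # u = column, v = row
    uv = np.stack([u, v], axis=-1).reshape(-1, 2)[face_ind, :]
    return np.hstack([uv, np.zeros((uv.shape[0], 1))])
```

The `triangles` from `load_uv_data` and the output of `get_colors` are what `write_obj_with_colors` needs; a textured export additionally needs a texture image, which this sketch does not cover.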

alchemi5t commented on June 14, 2024
torch_out = torch_model(image_tensor)
torch_out = np.transpose(
    torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out
cropped_pos = torch_out * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T
vertices=get_vertices(cropped_pos)
colors=get_colors(random_image,vertices)
write_obj_with_colors("test_color", vertices, triangles, colors)

I created an .obj with this pipeline, but the model is totally botched. Any idea what I might have done wrong?

@LucienXian @liguohao96 Also, thank you for the prompt responses. This has been very helpful!

LucienXian commented on June 14, 2024

@alchemi5t Add the following code, which maps the vertices from the cropped 256x256 space back to the original image coordinates:

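# Depth is only scaled by the crop, so divide by the scale (no rotation in this crop)
z = cropped_vertices[2, :].copy() / tform.params[0, 0]
# Use homogeneous (x, y, 1) columns for the 2D similarity transform
cropped_vertices[2, :] = 1
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2, :], z))
vertices = np.reshape(vertices.T, [256, 256, 3])  # back to a position map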
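The same inverse mapping as a self-contained sketch. `similarity_matrix` is a hypothetical stand-in for `tform.params` (the 3x3 homogeneous matrix that scikit-image's `estimate_transform('similarity', ...)` produces), restricted here to scale plus translation; like the snippet, it assumes depth was scaled by the same factor as x and y during cropping.

```python
import numpy as np


def similarity_matrix(scale: float, tx: float, ty: float) -> np.ndarray:
    """3x3 homogeneous matrix for x' = scale * x + tx, y' = scale * y + ty (no rotation)."""
    return np.array([[scale, 0.0, tx],
                     [0.0, scale, ty],
                     [0.0, 0.0, 1.0]])


def uncrop_vertices(cropped_vertices: np.ndarray, params: np.ndarray) -> np.ndarray:
    """Map (3, N) vertices from crop space back to original-image space.

    x and y go through the inverse 2D transform; z is only rescaled, because
    the 2D crop transform does not act on depth.
    """
    pts = cropped_vertices.copy()          # leave the caller's array untouched
    z = pts[2, :] / params[0, 0]           # params[0, 0] is the scale when there is no rotation
    pts[2, :] = 1.0                        # homogeneous coordinate for (x, y)
    xy = np.linalg.inv(params) @ pts
    return np.vstack((xy[:2, :], z))


if __name__ == "__main__":
    params = similarity_matrix(scale=0.5, tx=-20.0, ty=-10.0)         # original -> crop
    original = np.array([[100.0, 200.0], [50.0, 80.0], [30.0, 60.0]])  # rows: x, y, z
    cropped = params @ np.vstack((original[:2], np.ones(2)))
    cropped[2] = original[2] * 0.5                                     # depth scaled like x, y
    print(uncrop_vertices(cropped, params))                            # recovers `original`
```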

alchemi5t commented on June 14, 2024

@LucienXian I did; it keeps throwing an index out-of-bounds error in get_colors.

Traceback (most recent call last):
  File "eval.py", line 202, in <module>
    main(parser.parse_args())
  File "eval.py", line 185, in main
    colors=get_colors(random_image,vertices)
  File "eval.py", line 71, in get_colors
    colors = image[0][ind[:,1], ind[:,0], :] # n x 3
IndexError: index -2147483648 is out of bounds for axis 0 with size 256

This is the makeshift script I am using.

random_image = [cv.resize(cv.imread("./test.jpg"),(256,256)).astype(np.float32)]
image_bchw = np.transpose(random_image, (0, 3, 1, 2))

image_tensor = torch.tensor(image_bchw)

torch_out = torch_model(image_tensor)
torch_out = np.transpose(
    torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out[0]

cropped_pos = torch_out * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T

left = 0; right = 0; top = 256; bottom = 256
old_size = (right - left + bottom - top)/2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.14])
size = int(old_size*1.58)

src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
DST_PTS = np.array([[0,0], [0,256 - 1], [256 - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)

z = cropped_vertices[2, :].copy() / tform.params[0, 0]
cropped_vertices[2, :] = 1
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2, :], z))
vertices = np.reshape(vertices.T, [256, 256, 3])
colors=get_colors(random_image,vertices)
write_obj_with_colors("test_Color", vertices, triangles, colors)
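Two things are worth checking before calling `get_colors`: the vertex array's shape, and whether it contains NaN. An index of exactly -2147483648 is what a NaN commonly becomes when cast to int32 (NumPy does not guarantee the result of that cast), and in this script `left == right` and `top == bottom`, so the box has zero size, which can make the estimated similarity transform degenerate. A small guard like the hypothetical `check_vertices_for_sampling` catches both problems with a readable message.

```python
import numpy as np


def check_vertices_for_sampling(vertices: np.ndarray) -> None:
    """Raise a clear error before get_colors() turns bad input into bad indices."""
    if vertices.ndim != 2 or vertices.shape[1] != 3:
        raise ValueError(
            f"expected vertices of shape (N, 3), got {vertices.shape}; "
            "flatten the (256, 256, 3) position map first")
    finite_rows = np.isfinite(vertices).all(axis=1)
    if not finite_rows.all():
        bad = int(np.count_nonzero(~finite_rows))
        raise ValueError(f"{bad} vertices contain NaN/inf; check the crop transform")
```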

liguohao96 commented on June 14, 2024

@alchemi5t get_colors actually requires vertices with shape (N, 3), but you reshaped them to (256, 256, 3).

alchemi5t commented on June 14, 2024

Ah, got it! The model is accurate now, but the color is mapped to a much larger image. Any hints on that?

global face_detector
torch_model = PRNet(3, 3)
torch_model.load_state_dict(torch.load('from_tf.pth'))
torch_model.eval()
torch_model=torch_model.double()
sys.path.append(args.prnet_dir)

bimage = [cv.resize(cv.imread("./image.jpg"),(256,256)).astype(np.float32)]
image_temp=cv.imread("./image.jpg")
imgc=cv.resize(cv.imread("./image.jpg"),(256,256)).astype(np.float32)/255
# imgc=image_temp.copy()/255
detected_faces = dlib_detect(image_temp)

d = detected_faces[0].rect ## only use the first detected face (assume that each input image only contains one face)
left = d.left(); right = d.right(); top = d.top(); bottom = d.bottom()
old_size = (right - left + bottom - top)/2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.14])
size = int(old_size*1.58)

# crop image
src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
DST_PTS = np.array([[0,0], [0,256 - 1], [256 - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)

image_temp = image_temp/255.
cropped_image = warp(image_temp, tform.inverse, output_shape=(256,256))

# run our net
#st = time()
bimage=np.array([cropped_image])
print(cropped_image.shape)
image_bchw = np.transpose(bimage, (0, 3, 1, 2))

image_tensor = torch.tensor(image_bchw)
# torch_watched_out = torch_model.input_conv(image_tensor)
# torch_watched_out = np.transpose(torch_watched_out.cpu().detach().numpy(), (0, 2, 3, 1))

torch_out = torch_model(image_tensor.double())
torch_out = np.transpose(
    torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out[0]

cropped_pos = torch_out[0] * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T
z = cropped_vertices[2,:].copy()/tform.params[0,0]
cropped_vertices[2,:] = 1
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2,:], z))
pos = np.reshape(vertices.T, [256, 256, 3]) 

vertices=get_vertices(pos)
# imgc=imgc/255
print(imgc.shape)
colors=get_colors(imgc, vertices)
write_obj_with_colors("./test.obj", vertices, triangles, colors)
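By the time `get_colors` runs, `vertices` are in the pixel frame of the original `image.jpg` (the inverse `tform` maps them out of the 256x256 crop), while `imgc` is the image squashed to 256x256, so most vertices land outside it and get clamped to its border. A hedged sketch of the alternative, assuming the full-size image from `cv.imread` (which returns BGR channel order) is passed in unchanged; sampling from the un-resized image appears to be what the official PRNet demo does. `sample_vertex_colors` is a hypothetical helper.

```python
import numpy as np


def sample_vertex_colors(original_bgr_uint8: np.ndarray, vertices: np.ndarray) -> np.ndarray:
    """Per-vertex RGB colours in [0, 1], sampled from the full-size original image.

    `vertices` must already be in original-image pixel coordinates (after undoing
    the crop transform), so the image must NOT be resized beforehand.
    """
    image = original_bgr_uint8[:, :, ::-1].astype(np.float64) / 255.0  # BGR -> RGB, [0, 1]
    h, w = image.shape[:2]
    x = np.clip(np.round(vertices[:, 0]), 0, w - 1).astype(np.int64)  # column index
    y = np.clip(np.round(vertices[:, 1]), 0, h - 1).astype(np.int64)  # row index
    return image[y, x, :]
```

In the script above, that would mean passing `cv.imread("./image.jpg")` rather than `imgc` when computing `colors`.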
