@alchemi5t The PyTorch output is a UV position map, a special image that stores XYZ coordinates in its pixels instead of RGB. The PyTorch version should produce the same result as the TF version, so you can use the official PRNet utils after converting the PyTorch tensor to NumPy.
BTW, don't forget to multiply the output by 256 * 1.1.
from pytorch-prnet.
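A minimal sketch of what the comment above describes. It assumes the network output has already been moved to NumPy in (1, 3, 256, 256) BCHW layout. A random array stands in for the real output, and the helper names `to_position_map` and `position_map_to_points` are hypothetical.

```python
import numpy as np

RESOLUTION = 256      # PRNet's UV map size (resolution_op in the official code)
POS_SCALE = 1.1       # extra factor mentioned in the comment above

def to_position_map(net_out):
    """Convert a (1, 3, H, W) network output into an (H, W, 3) position map.

    Each pixel then holds an (x, y, z) coordinate in the 256x256 crop frame
    instead of an RGB color.
    """
    pos = np.transpose(net_out, (0, 2, 3, 1))[0]   # BCHW -> HWC, drop batch axis
    return pos * RESOLUTION * POS_SCALE

def position_map_to_points(pos):
    """Flatten an (H, W, 3) position map into an (N, 3) array of XYZ points."""
    return pos.reshape(-1, 3)

# Stand-in for torch_out.cpu().detach().numpy(): values roughly in [0, 1]
fake_out = np.random.rand(1, 3, RESOLUTION, RESOLUTION).astype(np.float32)
pos = to_position_map(fake_out)
points = position_map_to_points(pos)
print(pos.shape, points.shape)   # (256, 256, 3) (65536, 3)
```

The official utils work on these plain NumPy arrays, which is why the PyTorch output can be passed to them once it has been transposed and scaled.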
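@alchemi5t Yes, with the official PRNet utils. The PyTorch version outputs the same NumPy array, so you can try:
import numpy as np
import torch

# img_bchw: (1, 3, 256, 256) float array; torch_model: the converted PRNet in eval mode
torch_input = torch.from_numpy(img_bchw)
torch_out = torch_model(torch_input).cpu().detach().numpy()
torch_out = np.transpose(torch_out, (0, 2, 3, 1)).squeeze()  # BCHW -> HWC position map
cropped_pos = torch_out * resolution_op * 1.1                # resolution_op = 256 in PRNet
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T        # (3, N) for the inverse transform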
TL;DR: you can change the code in test/eval.py.
test/eval.py is designed to compare the outputs of TF and PyTorch on a random input: it creates a random input, runs inference in both TF and PyTorch, and checks whether they produce the same output.
Replacing the random input with a real image gives you a fully functioning pytorch-prnet.
Please let me know if you need more detailed help.
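A minimal sketch of swapping the random input for a real image. It assumes the image has already been cropped to 256x256 around the face, and that the network expects inputs scaled to [0, 1], as in the official PRNet pre-processing. The helper name `image_to_bchw` is hypothetical.

```python
import numpy as np

def image_to_bchw(image_hwc):
    """Turn an already-cropped (256, 256, 3) uint8 image into a (1, 3, 256, 256)
    float32 batch scaled to [0, 1], ready for torch.from_numpy(...)."""
    if image_hwc.shape != (256, 256, 3):
        raise ValueError(f"expected a 256x256x3 crop, got {image_hwc.shape}")
    x = image_hwc.astype(np.float32) / 255.0          # assumed [0, 1] input range
    return np.transpose(x, (2, 0, 1))[np.newaxis]     # HWC -> CHW, add batch axis

# Stand-in for a cropped face (e.g. the output of the official crop step)
crop = np.full((256, 256, 3), 255, dtype=np.uint8)
batch = image_to_bchw(crop)
print(batch.shape, batch.dtype, batch.max())   # (1, 3, 256, 256) float32 1.0
```

The shape check fails early when an uncropped image is passed in by mistake, which is easy to do when adapting a script that previously used random input.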
@liguohao96 Thank you for your quick reply! I have used your code to run inference and followed https://github.com/YadiraF/PRNet for pre-processing the image, but I still can't reconstruct a valid 3D object or get a valid .obj file like the original repo does. Have you tried reconstructing the face from the results of test/eval.py?
I have solved the problem. Thanks a lot!
@LucienXian How did you write the .obj file out? From what I figured out, the PyTorch output is the UV coords. Did you use the official PRNet utils to write the .obj file?
@liguohao96 Any help is appreciated.
@liguohao96 The TF PRNet utils require vertices, colors, triangles and textures along with uv_coords. In the original TF project, these come from the PRNet object. How would I get them in your PyTorch version?
FYI, all I have managed to get is torch_watched_out (the UV coords). I am trying to get textured models out.
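As far as I can tell from the official PRNet repo, the mesh topology does not come from the network. It is loaded from files shipped in `Data/uv-data/` (`face_ind.txt` and `triangles.txt` are assumed paths here). `get_vertices` then keeps only the valid face pixels of the position map, so the same data should work unchanged with the PyTorch port. The following is a sketch under that assumption; the face indices in the demo are hypothetical.

```python
import os
import numpy as np

def load_uv_data(prnet_dir):
    """Load the mesh topology that ships with the official PRNet repo.

    Assumed paths: Data/uv-data/face_ind.txt and Data/uv-data/triangles.txt.
    They do not depend on the network, so they are framework-agnostic.
    """
    uv_dir = os.path.join(prnet_dir, "Data", "uv-data")
    face_ind = np.loadtxt(os.path.join(uv_dir, "face_ind.txt")).astype(np.int32)
    triangles = np.loadtxt(os.path.join(uv_dir, "triangles.txt")).astype(np.int32)
    return face_ind, triangles

def get_vertices(pos, face_ind):
    """Keep only the valid face pixels of an (H, W, 3) position map -> (N, 3)."""
    all_vertices = pos.reshape(-1, 3)
    return all_vertices[face_ind, :]

# Tiny synthetic example: a 4x4 "position map" and hypothetical face indices
pos = np.arange(4 * 4 * 3, dtype=np.float64).reshape(4, 4, 3)
face_ind = np.array([0, 5, 15])           # indices into the flattened 4*4 grid
vertices = get_vertices(pos, face_ind)
print(vertices.shape)    # (3, 3)
print(vertices[1])       # [15. 16. 17.]
```

In this setup the triangles index into the returned (N, 3) vertex array, not into the full 256 * 256 grid. That is why the vertex selection and the triangle list have to come from the same uv-data files.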
torch_out = torch_model(image_tensor)
torch_out = np.transpose(
torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out
cropped_pos = torch_out * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T
vertices=get_vertices(cropped_pos)
colors=get_colors(random_image,vertices)
write_obj_with_colors("test_color", vertices, triangles, colors)
I created an .obj with this pipeline, but the model is totally botched. Any idea what I might have done wrong?
@LucienXian @liguohao96 Also, thank you for the prompt responses. This has been very helpful!
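@alchemi5t
Add the following code:
# Map the cropped (3, N) vertices back to the original image frame
z = cropped_vertices[2, :].copy() / tform.params[0, 0]   # undo the crop scale on depth
cropped_vertices[2, :] = 1                                # homogeneous coordinate for x, y
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2, :], z))
vertices = np.reshape(vertices.T, [256, 256, 3])          # back to a (256, 256, 3) position map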
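A self-contained sketch of what the snippet above does, checked on a hypothetical transform. It assumes `tform.params` is the 3x3 similarity matrix that maps original-image pixels to crop pixels, with no rotation, so `params[0, 0]` is the scale. The function name `uncrop_vertices` is hypothetical.

```python
import numpy as np

def uncrop_vertices(cropped_vertices, tform_params):
    """Map (3, N) vertices from the 256x256 crop back to the original image.

    x and y go through the inverse similarity matrix; z is only rescaled,
    using params[0, 0] as the scale (valid when the transform has no rotation).
    """
    v = cropped_vertices.copy()                  # don't clobber the caller's array
    z = v[2, :] / tform_params[0, 0]             # undo the scale on depth
    v[2, :] = 1                                  # homogeneous coordinate for x, y
    xy = np.linalg.inv(tform_params) @ v         # inverse similarity transform
    return np.vstack((xy[:2, :], z))             # (3, N) in original-image pixels

# Hypothetical crop transform: scale 0.5, shift by (10, -20), no rotation
T = np.array([[0.5, 0.0, 10.0],
              [0.0, 0.5, -20.0],
              [0.0, 0.0, 1.0]])
original = np.array([[100.0, 300.0],    # x
                     [200.0, 50.0],     # y
                     [40.0, 80.0]])     # z
# Forward-map into the crop frame: x, y through T, z scaled by the same factor
cropped = np.vstack((T[:2, :2] @ original[:2] + T[:2, 2:], 0.5 * original[2]))
restored = uncrop_vertices(cropped, T)
print(np.allclose(restored, original))   # True
```

Unlike the inline snippet, this version works on a copy. As a result, `cropped_vertices` keeps its depth row and can still be inspected after the call.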
@LucienXian I did; it keeps throwing an index-out-of-bounds error in get_colors:
Traceback (most recent call last):
  File "eval.py", line 202, in <module>
    main(parser.parse_args())
  File "eval.py", line 185, in main
    colors=get_colors(random_image,vertices)
  File "eval.py", line 71, in get_colors
    colors = image[0][ind[:,1], ind[:,0], :] # n x 3
IndexError: index -2147483648 is out of bounds for axis 0 with size 256
This is the makeshift script I am using:
random_image = [cv.resize(cv.imread("./test.jpg"),(256,256)).astype(np.float32)]
image_bchw = np.transpose(random_image, (0, 3, 1, 2))
image_tensor = torch.tensor(image_bchw)
torch_out = torch_model(image_tensor)
torch_out = np.transpose(
torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out[0]
cropped_pos = torch_out * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T
left = 0; right = 0; top = 256; bottom = 256
old_size = (right - left + bottom - top)/2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.14])
size = int(old_size*1.58)
src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
DST_PTS = np.array([[0,0], [0,256 - 1], [256 - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)
z = cropped_vertices[2, :].copy() / tform.params[0, 0]
cropped_vertices[2, :] = 1
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2, :], z))
vertices = np.reshape(vertices.T, [256, 256, 3])
colors=get_colors(random_image,vertices)
write_obj_with_colors("test_Color", vertices, triangles, colors)
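A worked version of the crop-box arithmetic used in the script above, wrapped in a hypothetical helper. With the script's values (left = right = 0, top = bottom = 256), the box has zero width and height, so `size` is 0 and all three `src_pts` coincide. `estimate_transform` then has nothing to fit, which plausibly produces non-finite parameters. This is an assumption, not something confirmed in the thread. For comparison, the box is also computed for the whole 256x256 image; the later script gets it from a face detector instead.

```python
import numpy as np

def crop_source_points(left, right, top, bottom):
    """Reproduce the crop-box arithmetic from the script above.

    Returns the box side length and the three source corners that
    estimate_transform maps onto (0, 0), (0, 255) and (255, 0).
    """
    old_size = (right - left + bottom - top) / 2
    center = np.array([right - (right - left) / 2.0,
                       bottom - (bottom - top) / 2.0 + old_size * 0.14])
    size = int(old_size * 1.58)
    src_pts = np.array([[center[0] - size / 2, center[1] - size / 2],
                        [center[0] - size / 2, center[1] + size / 2],
                        [center[0] + size / 2, center[1] - size / 2]])
    return size, src_pts

# Values used in the script: an empty box, so all corners collapse to one point
size_bad, pts_bad = crop_source_points(left=0, right=0, top=256, bottom=256)
# A box covering the whole 256x256 image instead
size_ok, pts_ok = crop_source_points(left=0, right=256, top=0, bottom=256)
print(size_bad, size_ok)   # 0 404
```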
@alchemi5t get_colors actually requires vertices to have shape (N, 3), but you are passing an array of shape (256, 256, 3).
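The following is a sketch of a `get_colors` modeled on the official PRNet helper, as far as I can tell. It clamps x/y to the image bounds, rounds them, and uses them as (column, row) indices. The shape and NaN checks are hypothetical additions for debugging. A value like -2147483648 is the smallest int32. On common platforms, that is what a NaN becomes after `astype(np.int32)`, so it usually points to non-finite vertices as well.

```python
import numpy as np

def get_colors(image, vertices):
    """Nearest-pixel colors for (N, 3) vertices given in image pixel coordinates."""
    if vertices.ndim != 2 or vertices.shape[1] != 3:
        raise ValueError(f"expected (N, 3) vertices, got {vertices.shape}")
    if not np.all(np.isfinite(vertices[:, :2])):
        raise ValueError("vertices contain NaN/inf; check the crop transform")
    h, w = image.shape[:2]
    x = np.clip(vertices[:, 0], 0, w - 1)        # clamp columns to the image
    y = np.clip(vertices[:, 1], 0, h - 1)        # clamp rows to the image
    ind_x = np.round(x).astype(np.int32)
    ind_y = np.round(y).astype(np.int32)
    return image[ind_y, ind_x, :]                # (N, 3)

# A (256, 256, 3) position map must be flattened to (N, 3) first
pos = np.zeros((256, 256, 3))
pos[..., 0] = 10.0                               # every point at x = 10
pos[..., 1] = 20.0                               # and y = 20
image = np.zeros((256, 256, 3))
image[20, 10] = [0.2, 0.4, 0.6]
colors = get_colors(image, pos.reshape(-1, 3))
print(colors.shape, colors[0])    # (65536, 3) [0.2 0.4 0.6]
```

In the official flow, the (N, 3) array usually comes from `get_vertices(pos)`, which also drops non-face pixels, rather than from a plain reshape.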
Ah, got it! The model is accurate now, but the colors are mapped as if onto a much larger image. Any hints on that?
global face_detector
torch_model = PRNet(3, 3)
torch_model.load_state_dict(torch.load('from_tf.pth'))
torch_model.eval()
torch_model=torch_model.double()
sys.path.append(args.prnet_dir)
bimage = [cv.resize(cv.imread("./image.jpg"),(256,256)).astype(np.float32)]
image_temp=cv.imread("./image.jpg")
imgc=cv.resize(cv.imread("./image.jpg"),(256,256)).astype(np.float32)/255
# imgc=image_temp.copy()/255
detected_faces = dlib_detect(image_temp)
d = detected_faces[0].rect ## only use the first detected face (assume that each input image only contains one face)
left = d.left(); right = d.right(); top = d.top(); bottom = d.bottom()
old_size = (right - left + bottom - top)/2
center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0 + old_size*0.14])
size = int(old_size*1.58)
# crop image
src_pts = np.array([[center[0]-size/2, center[1]-size/2], [center[0] - size/2, center[1]+size/2], [center[0]+size/2, center[1]-size/2]])
DST_PTS = np.array([[0,0], [0,256 - 1], [256 - 1, 0]])
tform = estimate_transform('similarity', src_pts, DST_PTS)
image_temp = image_temp/255.
cropped_image = warp(image_temp, tform.inverse, output_shape=(256,256))
# run our net
#st = time()
bimage=np.array([cropped_image])
print(cropped_image.shape)
image_bchw = np.transpose(bimage, (0, 3, 1, 2))
image_tensor = torch.tensor(image_bchw)
# torch_watched_out = torch_model.input_conv(image_tensor)
# torch_watched_out = np.transpose(torch_watched_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_out = torch_model(image_tensor.double())
torch_out = np.transpose(
torch_out.cpu().detach().numpy(), (0, 2, 3, 1))
torch_watched_out = torch_out[0]
cropped_pos = torch_out[0] * 256 * 1.1
cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T
z = cropped_vertices[2,:].copy()/tform.params[0,0]
cropped_vertices[2,:] = 1
vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices)
vertices = np.vstack((vertices[:2,:], z))
pos = np.reshape(vertices.T, [256, 256, 3])
vertices=get_vertices(pos)
# imgc=imgc/255
print(imgc.shape)
colors=get_colors(imgc, vertices)
write_obj_with_colors("./test.obj", vertices, triangles, colors)
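One plausible cause, which is an assumption the thread does not confirm, concerns the coordinate frames. After the inverse transform, `vertices` are in the pixel frame of the original `image_temp`, but colors are sampled from `imgc`, a 256x256 resize. As far as I can tell, the official PRNet demo samples colors from the original image scaled to [0, 1]. The sketch below uses a hypothetical 512x512 image and helper names to show the mismatch. It also shows the two ways around it: sample the original, or rescale x/y to the resized frame.

```python
import numpy as np

def sample_colors(image, vertices):
    """Nearest-pixel lookup of (N, 3) vertices in the image's own pixel frame."""
    h, w = image.shape[:2]
    x = np.round(np.clip(vertices[:, 0], 0, w - 1)).astype(np.int64)
    y = np.round(np.clip(vertices[:, 1], 0, h - 1)).astype(np.int64)
    return image[y, x, :]

def rescale_to(vertices, src_hw, dst_hw):
    """Move x/y from a src_hw-sized image frame into a dst_hw-sized one."""
    out = vertices.copy()
    out[:, 0] *= dst_hw[1] / src_hw[1]            # x scales with width
    out[:, 1] *= dst_hw[0] / src_hw[0]            # y scales with height
    return out

original = np.zeros((512, 512, 3))                # hypothetical 512x512 photo
original[300:310, 400:410] = [1.0, 0.0, 0.0]      # a red patch
resized = original[::2, ::2]                      # crude 2x downscale to 256x256
vertices = np.array([[404.0, 304.0, 0.0]])        # vertex in original-image pixels

print(sample_colors(original, vertices))          # [[1. 0. 0.]] right pixel
print(sample_colors(resized, vertices))           # [[0. 0. 0.]] clamped to the edge
print(sample_colors(resized, rescale_to(vertices, (512, 512), (256, 256))))  # [[1. 0. 0.]]
```

Sampling the original image keeps full color resolution, so it is generally preferable to rescaling the vertices.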