
Comments (3)

Xiaoping777 commented on May 13, 2024

Finally, I worked it out. Here is the code for single-example inference, @tom-samsung:

```python
import tensorflow.compat.v1 as tf

# Run TF1-style graph/session code under TensorFlow 2.x.
tf.disable_eager_execution()


def extract_tensors(signature_def, graph):
    """Map each TensorInfo entry of a signature's inputs/outputs to its graph tensor."""
    output = dict()
    for key in signature_def:
        value = signature_def[key]
        if isinstance(value, tf.TensorInfo):
            output[key] = graph.get_tensor_by_name(value.name)
    return output


def extract_input_name(signature_def, graph):
    """Return the names of all input tensors of the 'serving_default' signature."""
    input_tensors = extract_tensors(signature_def['serving_default'].inputs, graph)
    return [tensor.name for tensor in input_tensors.values()]


def extract_output_name(signature_def, graph):
    """Return the names of all output tensors of the 'serving_default' signature."""
    output_tensors = extract_tensors(signature_def['serving_default'].outputs, graph)
    return [tensor.name for tensor in output_tensors.values()]


def ass_input_dict(tensor_input_sample):
    """Build a feed_dict for a single example.

    Assumes the model's input placeholders are named "1", "2", ..., one per
    feature, so feature i is fed to tensor "<i+1>:0" as a batch of size one.
    """
    return {str(i + 1) + ":0": [value] for i, value in enumerate(tensor_input_sample)}


# Export directory of the trained model (path from my run; adjust to yours).
checkpoint_path = "/tmp/run/tuner-1/160/saved_model/assets/"

with tf.Session(graph=tf.Graph()) as sess:
    # The TF1 loader restores the graph and variables into `sess` and
    # returns a MetaGraphDef protobuf.
    meta_graph_def = tf.saved_model.load(sess, tags=["serve"], export_dir=checkpoint_path)

    # input_tensor_name = extract_input_name(meta_graph_def.signature_def, sess.graph)
    output_tensor_name = extract_output_name(meta_graph_def.signature_def, sess.graph)

    # sen_vec is the caller's feature vector (a PyTorch tensor here, hence
    # .detach().numpy()); it must be defined before this point.
    input_dict = ass_input_dict(sen_vec.detach().numpy())

    prediction = sess.run(output_tensor_name, feed_dict=input_dict)

print(prediction)
```
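The `ass_input_dict` helper feeds one example at a time. If the model's placeholders accept a batch (an assumption; check the shape of each input tensor before relying on it), the same naming convention can be extended to several rows at once. `batched_input_dict` below is a hypothetical helper, not part of model_search, sketched in plain Python:

```python
def batched_input_dict(rows):
    """Build a feed_dict for a batch of examples (hypothetical helper).

    Assumes, like ass_input_dict, that feature i is fed to the placeholder
    named "<i+1>", whose output tensor name is "<i+1>:0". Each key maps to
    the column of values for that feature across all rows.
    """
    if not rows:
        raise ValueError("rows must contain at least one example")
    num_features = len(rows[0])
    if any(len(row) != num_features for row in rows):
        raise ValueError("all rows must have the same number of features")
    return {
        f"{i + 1}:0": [row[i] for row in rows]
        for i in range(num_features)
    }


# Two examples with three features each.
feed = batched_input_dict([[0.1, 0.2, 0.3], [1.0, 2.0, 3.0]])
print(feed)
# → {'1:0': [0.1, 1.0], '2:0': [0.2, 2.0], '3:0': [0.3, 3.0]}
```

With a single row this produces the same dictionary as `ass_input_dict`; with several rows, `sess.run` would return one prediction per row, provided the graph was exported with a variable batch dimension.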

from model_search.

Xiaoping777 commented on May 13, 2024

Hi Tom, I just downloaded and reinstalled the latest version. It now generates a new folder with a .pb file for each model, which I think should make things easier.
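A SavedModel directory is the one that contains `saved_model.pb` (alongside `variables/` and, optionally, `assets/`). One way to find the right `export_dir` without guessing is to search the run directory for that file. This is a minimal sketch; `find_saved_model_dirs` is a hypothetical helper, and the `/tmp/run` root is taken from the earlier snippet rather than from model_search documentation:

```python
import os


def find_saved_model_dirs(root):
    """Return every directory under `root` that contains a saved_model.pb file.

    Results are sorted by path so the output is deterministic. A root that
    does not exist simply yields an empty list.
    """
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        if "saved_model.pb" in filenames:
            found.append(dirpath)
    return sorted(found)


if __name__ == "__main__":
    # Hypothetical run directory; point this at your own model_search output.
    for path in find_saved_model_dirs("/tmp/run"):
        print(path)
```

Any path it returns can be passed as `export_dir` to `tf.saved_model.load(sess, tags=["serve"], export_dir=...)`.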


tom-samsung commented on May 13, 2024

Hey @Xiaoping777
Thanks for the code. Yes, I noticed that with the new version of the repo and its saved models, things are much easier now.
Unfortunately, I need to re-run everything, but that's OK.
I'll try to wrap this in a Keras Lambda layer, so that people with Keras pipelines have this extra option, and post it somewhere. Maybe the authors of this repo will update the README with all of this information to make people's lives easier before closing these issues:
#43
#39
Thanks again and happy model searching!

