
williamleif commented on August 20, 2024

Hi @doannam020293,

Thanks for the question! Could you describe your use case and constraints in a bit more detail? Right now, we simulate adding new nodes after training by marking some nodes with the "test" and "val" attributes in the input networkx graph. However, this assumes that the "new" nodes are already present in the graph, so it only simulates a truly inductive setting.
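As a rough illustration of that simulation, the sketch below builds two fixed-width neighbor tables for the same toy graph. In the training table, "test"/"val" nodes and every edge touching them are hidden. The evaluation table includes all nodes. It assumes nodes are already indexed 0..n-1, uses a plain dict of neighbor lists instead of a networkx graph, and pads unused slots with a dummy index n. `build_adj` is a hypothetical helper loosely modeled on a padded `adj_info` layout, not the repository's code.

```python
import numpy as np

def build_adj(neighbors, max_degree, excluded=frozenset(), seed=0):
    """Hypothetical helper: padded neighbor table of shape (n + 1, max_degree).

    neighbors: dict node index (0..n-1) -> list of neighbor indices
    excluded:  nodes hidden from training (e.g. "val"/"test"); their rows stay
               padding and edges touching them are dropped
    Padding entries and the extra last row use the dummy index n.
    """
    rng = np.random.default_rng(seed)
    n = len(neighbors)
    adj = np.full((n + 1, max_degree), n, dtype=np.int64)
    for node, nbrs in neighbors.items():
        if node in excluded:
            continue
        kept = [u for u in nbrs if u not in excluded]
        if not kept:
            continue
        # Without replacement if there are enough neighbors, else with
        # replacement so that every slot in the row is filled.
        replace = len(kept) < max_degree
        adj[node] = rng.choice(kept, size=max_degree, replace=replace)
    return adj

# Toy graph: node 3 plays the role of a held-out "test" node.
neighbors = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
train_adj = build_adj(neighbors, max_degree=4, excluded={3})
test_adj = build_adj(neighbors, max_degree=4)
```

Evaluating a held-out node then amounts to running the same weights against `test_adj` instead of `train_adj`. The training scripts appear to do this swap with a `tf.assign` on `adj_info` before validation, which only works because both tables have the same shape; a graph with a different number of nodes needs new tensors.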

If you want to use a trained model on a truly new graph, you will need to modify the adj_info variables/constants (e.g., here: https://github.com/williamleif/GraphSAGE/blob/master/graphsage/supervised_train.py#L259), and you will need to update the feature matrix as well (here: https://github.com/williamleif/GraphSAGE/blob/master/graphsage/supervised_train.py#L125). How exactly you do this will depend on your setting and constraints (e.g., you might want to make the adj_info and feature tensors placeholders that are fed via feed_dict, though this will hurt data I/O performance).
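Whether adj_info and the features end up as constants, variables, or placeholders, the arrays for a new graph must agree with each other and with the trained weights: the weight matrices fix the feature dimension, while the number of rows is free to change. A minimal NumPy sketch of that bookkeeping, assuming a padded layout in which index `num_nodes` is a dummy node with an all-zero feature vector (modeled on the zero-row padding of the training features). `prepare_new_graph` is a hypothetical helper.

```python
import numpy as np

def prepare_new_graph(features, adj, expected_feat_dim):
    """Hypothetical helper: validate a new graph's inputs and pad its features.

    features: (num_nodes, feat_dim) raw features of the new graph
    adj:      (num_nodes + 1, max_degree) padded neighbor table whose
              padding entries use the dummy index num_nodes
    expected_feat_dim: input feature size the model was trained with
    Returns a (num_nodes + 1, feat_dim) matrix whose last row is all zeros,
    so padded neighbor slots look up a zero vector.
    """
    num_nodes, feat_dim = features.shape
    if feat_dim != expected_feat_dim:
        raise ValueError(f"feature size {feat_dim} != trained size {expected_feat_dim}")
    if adj.shape[0] != num_nodes + 1:
        raise ValueError("adj needs one row per node plus one dummy row")
    if adj.min() < 0 or adj.max() > num_nodes:
        raise ValueError("adj refers to a node index outside the new graph")
    dummy = np.zeros((1, feat_dim), dtype=features.dtype)
    return np.vstack([features, dummy])

new_features = np.random.default_rng(0).normal(size=(6, 50))
new_adj = np.full((7, 10), 6)   # every slot padded, for illustration only
padded = prepare_new_graph(new_features, new_adj, expected_feat_dim=50)
```

The row count (6 here) is unconstrained; only the feature width has to match what the model saw during training.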

Cheers,
Will

erica-lee commented on August 20, 2024

Hi,
I want to use a trained model on a new node that was not in the original graph. The model was trained with unsupervised_train.py.

Could you check my process?
  1. Train the model.
  2. Add the new node to the graph.
  3. Retrain the model on the new graph.

If that is OK, I have some more detailed questions about the code.

  1. Should I modify this line of code?
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    I think it would be OK, because the minibatch gets adj_info for the new graph, right?
  2. When I load the model, can I just use the model's tensors, retrieved by name as below?
# Look up tensors/ops of the restored graph by name. Auto-generated names
# such as 'div_1:0' or 'Mean:0' depend on the order in which ops were
# created, so check them against the saved graph.
class MODEL():
    def __init__(self, graph):
        self.loss = graph.get_tensor_by_name('div_1:0')
        self.mrr = graph.get_tensor_by_name('Mean:0')
        self.outputs1 = graph.get_tensor_by_name('l2_normalize:0')
        self.opt_op = graph.get_operation_by_name('Adam')
        self.ranks = graph.get_tensor_by_name('TopKV2_1:1')
        self.aff_all = graph.get_tensor_by_name('concat:0')

def reconstruct_placeholders(graph):
    # Placeholder names as assigned when the placeholders were created.
    return {
        'batch1': graph.get_tensor_by_name("batch1:0"),
        'batch2': graph.get_tensor_by_name("batch2:0"),
        'neg_samples': graph.get_tensor_by_name("neg_sample_size:0"),
        'dropout': graph.get_tensor_by_name("dropout:0"),
        'batch_size': graph.get_tensor_by_name("batch_size:0")
    }
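For step 2 of the process above, adding a node means keeping the node-id map, the neighbor lists, and the feature matrix in sync. It also changes the number of rows, which is why anything built with `shape=minibatch.adj.shape` no longer fits the new graph. A sketch in plain Python/NumPy; `add_node` and its arguments are hypothetical, and the dict-of-lists graph stands in for a networkx graph.

```python
import numpy as np

def add_node(id_map, neighbors, features, new_id, new_feat, linked_ids):
    """Hypothetical helper: append one node to an undirected graph.

    id_map:     dict raw node id -> contiguous row index 0..n-1 (updated in place)
    neighbors:  dict row index -> list of neighbor row indices (updated in place)
    features:   array of shape (n, feat_dim)
    linked_ids: raw ids of existing nodes the new node connects to
    Returns the new (n + 1, feat_dim) feature matrix, since NumPy arrays
    cannot grow in place.
    """
    if new_id in id_map:
        raise KeyError(f"node {new_id!r} already exists")
    idx = len(id_map)                     # next free contiguous index
    id_map[new_id] = idx
    neighbors[idx] = []
    for other in linked_ids:
        j = id_map[other]
        neighbors[idx].append(j)
        neighbors[j].append(idx)          # keep the graph undirected
    row = np.asarray(new_feat, dtype=features.dtype).reshape(1, -1)
    return np.vstack([features, row])

id_map = {"a": 0, "b": 1}
neighbors = {0: [1], 1: [0]}
features = np.zeros((2, 3))
features = add_node(id_map, neighbors, features, "c", [1.0, 2.0, 3.0], ["a"])
```

After this, the padded neighbor table and feature matrix each grow by one row, so tensors whose shape was fixed from the old `minibatch.adj.shape` must be recreated for the new graph (or the model retrained, as in step 3).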

RexYing commented on August 20, 2024

Hi, I can't really tell whether your snippet makes sense without more context, but tf.get_variable seems to be the more common way of retrieving variables, in my opinion.
You brought up an important point, though: when training on datasets with multiple graphs, you have two options: 1. merge all graphs into one giant disjoint graph, which is easier but less elegant; 2. make sure that the relevant pieces of the adjacency matrices for a minibatch of graphs get fed into placeholders, which involves a bit more engineering effort.
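Option 1 is mostly index bookkeeping: each graph's nodes are shifted past the nodes of the graphs before it, so no edges cross between the parts and a single adjacency table can cover all of them. A plain-Python sketch with graphs as dicts of neighbor lists (for networkx graphs, `networkx.disjoint_union_all` performs the equivalent relabeling to consecutive integers); `disjoint_union` is a hypothetical helper.

```python
def disjoint_union(graphs):
    """Hypothetical helper: merge graphs into one graph with no edges between parts.

    Each graph is a dict: node index (0..n-1) -> list of neighbor indices.
    The nodes of the i-th graph are shifted by the total size of the graphs
    before it. Returns the merged graph and the offset applied to each input.
    """
    merged, offsets, offset = {}, [], 0
    for g in graphs:
        offsets.append(offset)
        for node, nbrs in g.items():
            merged[node + offset] = [u + offset for u in nbrs]
        offset += len(g)
    return merged, offsets

g1 = {0: [1], 1: [0]}
g2 = {0: [1, 2], 1: [0], 2: [0]}
merged, offsets = disjoint_union([g1, g2])
# merged == {0: [1], 1: [0], 2: [3, 4], 3: [2], 4: [2]}, offsets == [0, 2]
```

The feature matrices would be stacked in the same order, so that row i still describes node i of the merged graph, and the offsets map predictions back to each original graph.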

anny0316 commented on August 20, 2024

I would like to know this too: how do I embed a new node using a trained model?

shivam1702 commented on August 20, 2024

How do I embed a new subgraph?
For example, how do I create embeddings for several new nodes that are being added to the graph?
By providing all of them, along with their feature vectors, as a minibatch to the unsupervised aggregator function?
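Conceptually, embedding new nodes (or a whole new subgraph) needs only three things: their features, their neighbor lists, and the trained weights, because the weights' shapes depend on feature dimensions, not on the number of nodes. The NumPy sketch below shows this with a simplified two-layer, GraphSAGE-style mean aggregator that averages over every neighbor instead of sampling. It is not the repository's TensorFlow implementation: the random matrices stand in for trained weights, and `mean_layer`/`embed` are hypothetical names.

```python
import numpy as np

def mean_layer(h, neighbors, w_self, w_neigh, activate):
    """One mean-aggregator layer using every neighbor of every node."""
    neigh_mean = np.zeros_like(h)
    for v, nbrs in neighbors.items():
        if nbrs:                                  # isolated nodes keep zeros
            neigh_mean[v] = h[nbrs].mean(axis=0)
    # Transform the node itself and its neighbor mean separately, then concatenate.
    out = np.concatenate([h @ w_self, neigh_mean @ w_neigh], axis=1)
    return np.maximum(out, 0.0) if activate else out

def embed(features, neighbors, weights):
    """Embed every node of a (possibly unseen) graph with fixed weights.

    weights: list of (w_self, w_neigh) pairs, one per layer. No shape here
    depends on the number of nodes, so the same weights fit any graph with
    the same feature dimension. Output rows are L2-normalized.
    """
    h = features
    for i, (w_self, w_neigh) in enumerate(weights):
        last = i == len(weights) - 1
        h = mean_layer(h, neighbors, w_self, w_neigh, activate=not last)
    return h / np.maximum(np.linalg.norm(h, axis=1, keepdims=True), 1e-12)

# Random stand-ins for trained weights: 4 input features, 2 layers of width 8.
rng = np.random.default_rng(0)
weights = [(rng.normal(size=(4, 8)), rng.normal(size=(4, 8))),
           (rng.normal(size=(16, 8)), rng.normal(size=(16, 8)))]

# Two graphs of different sizes, embedded with the same weights.
small = embed(rng.normal(size=(3, 4)), {0: [1], 1: [0, 2], 2: [1]}, weights)
big = embed(rng.normal(size=(5, 4)),
            {0: [1, 2], 1: [0], 2: [0, 3], 3: [2, 4], 4: [3]}, weights)
```

Feeding a batch of new nodes is then a matter of supplying their rows and neighbor lists; nothing about the weights changes between `small` and `big`.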

sam-lev commented on August 20, 2024

I have a similar question. To preface: I would like to use GraphSAGE for inductive learning. Given a trained model with a learned graph embedding and aggregator functions, how would one take an unseen graph (one not already included as a disjoint component of the graph used for training) and embed this new, unlabeled graph in order to run inference over its nodes with the aggregators learned during training?

I have seen claims that the StellarGraph library does this, but looking through its code and documentation, it seems the adjacency neighborhoods used for inference require the graph being classified to be a subgraph of the original training graph, e.g., a disjoint graph embedded during the training phase. Or is it the case that, to perform inference on the new graph, one has to add it as a disjoint graph, run a forward pass with the learned model, and then classify the unlabeled nodes?

Any clarity or information would be greatly appreciated, thanks :)

shivam1702 commented on August 20, 2024

Hi @sam-lev, if by disjoint graph you mean a vertex-disjoint graph, i.e., none of the nodes in the new graph also appear in the previous graph, then I have two thoughts:

  1. The aggregator functions and the trained unsupervised model might work on it, but that depends on whether the new nodes live in the same feature space as the training graph, and the node connectivity should also be consistent with that of the training graph.
  2. To compute the embeddings for this new graph, we will have to manually create a data pipeline very similar to the Minibatch used in training (or the one used in the save_val_embeddings function).
    For this (see similar posts, e.g., #102), and with help from issues #65, #81, #22, #29, and #34,
    the basic steps are:
    a. Create a new adjacency matrix (the adj_info variable in the code) that contains information about the new nodes and their connectivity with the rest of the graph. (This requires changing the minibatch to accommodate the new nodes.)
    b. Modify the feature matrix accordingly.
    c. Run the save_val_embeddings function.

Also, #63 suggests that it is better to aggregate the entire neighbourhood, rather than randomly sampling with replacement as during training, in order to get more deterministic results.
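The difference #63 points at shows up on a toy node with four neighbors. A training-style sample of two neighbors gives an estimate of the neighborhood mean that depends on the random draw, whereas averaging over the whole neighborhood always gives the same vector. A NumPy sketch: both helpers are hypothetical, and the with/without-replacement rule is an assumption about how training-time sampling fills a fixed number of slots.

```python
import numpy as np

def sampled_mean(h, nbrs, num_samples, rng):
    """Training-style estimate: average over num_samples sampled neighbors.

    Samples without replacement when there are enough neighbors and with
    replacement otherwise, so the result can change from call to call.
    """
    replace = len(nbrs) < num_samples
    picked = rng.choice(nbrs, size=num_samples, replace=replace)
    return h[picked].mean(axis=0)

def full_mean(h, nbrs):
    """Deterministic alternative: average over the entire neighborhood."""
    return h[list(nbrs)].mean(axis=0)

h = np.arange(10.0).reshape(5, 2)   # toy features for nodes 0..4
nbrs = [1, 2, 3, 4]                 # neighborhood of some node
rng = np.random.default_rng(0)
estimates = [sampled_mean(h, nbrs, 2, rng) for _ in range(5)]
exact = full_mean(h, nbrs)          # always [5., 6.]
```

The full mean costs one pass over each node's neighbors, which is usually affordable at inference time even when it would be too slow for every training step.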

Will update once I have a clearer understanding.

sam-lev commented on August 20, 2024

Hi @shivam1702, your input was very helpful, thank you. After implementing your advice, I'm now running into some confusion. Currently, I either save the trained GraphSAGE model as TensorFlow checkpoints to be loaded later, or, right after training, use the trained model in the same TensorFlow session for inference over a new, unseen graph. I have implemented a conversion of new inference graphs into minibatches with NodeMinibatchIterator, using the new unseen graph and its feature information. I use this to define adj_info, and then feed the model and the new minibatch into incremental_evaluate.

My issue is that the trained model does not accommodate the new number of nodes and adjacencies; the error I get is
"Invalid argument: You must feed a value for placeholder tensor 'Placeholder' with dtype int32 and shape [11334,192]"
where the expected dimensionality is that of the training graph.

As a solution, I redefine the node sampler with UniformNeighborSampler() using the new inference graph, and similarly redefine 'layer_infos'. I then try to update the trained model with the new layer_infos so that it handles the new inference graph correctly. Unfortunately, I still run into placeholders and variables learned during training being uninitialized or inapplicable.

My question is: how, and what exactly, should I update within the trained model for the new, unseen inference graph while preserving the learned aggregators and weights? I had the impression that, given a trained model, I could simply feed a new minibatch for the unseen graph into incremental_evaluate, but things seem to need a little more finesse. Secondly, and this may amount to a TensorFlow question, how should I initialize the learned model and then update the layer information?

I appreciate any help you or any others can offer.

Best

shivam1702 commented on August 20, 2024

"Unfortunately, doing this I still run into issues of placeholders and variables learned during training being uninitialized or inapplicable."

This is exactly the problem I am facing: the new nodes cannot be accommodated in the already-created sampler placeholders.
Do let me know what you tried, what failed, and what worked.

sam-lev commented on August 20, 2024

Hi @shivam1702, it's good to hear from you, and I appreciate the opportunity to discuss this. I have created a new inference method that can be called after training, using either a saved graph model learned during training or the model from training in the same process. For inference, I reuse much of the code from the training method, but update certain attributes of the supervised model to accommodate the unseen graph, on which I hope to use the previously trained model for prediction. The main changes apply to:
- the adj_info and minibatch, as you previously noted.
- UniformNeighborSampler(adj_info).
- layer_infos, with the new neighborhood sampler.
- updating the supervised model with the new adjacency information for inference. As when constructing the supervised model for training, the inference model updates the samples and instance variables of 'supervised_models', which are set both when the supervised model is instantiated and in its build method. In short, these changes just let the supervised model graph match the batch size and shape of the inference nodes from the new graph, which I hope to infer on using the previously trained supervised model.
- changing how the adjacency matrix is initialized, by running:
sess.run(adj_info.initializer, feed_dict={adj_info_ph: minibatch.adj})
in place of:
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) as done in the train method. This prevents the trained model variables from being reinitialized; however, as a result, each placeholder needs to be initialized and given a value. This is where I am now having the most trouble.
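Redefining the sampler and layer_infos works because sampling reads only from the neighbor table. Pointing it at the new graph's table changes which node indices are gathered, while the aggregator weights are untouched. What does change is how many ids each hop produces, which is why batch-size-dependent pieces must be fed again. The NumPy sketch below is meant to imitate the column-shuffle-and-slice sampling of UniformNeighborSampler and the last-layer-first order in which layer_infos are consumed; treat that correspondence as an assumption. `LayerInfo` and `sample_support` are hypothetical stand-ins, not the repository's classes.

```python
import numpy as np
from collections import namedtuple

# Hypothetical per-layer config; only num_samples matters for sampling.
LayerInfo = namedtuple("LayerInfo", ["name", "num_samples"])

def sample_support(adj, batch, layer_infos, rng):
    """Collect multi-hop neighbor ids for a batch from a padded neighbor table.

    adj: (num_nodes + 1, max_degree) table for the graph being embedded.
    Hops are drawn from the last layer's config to the first. Each hop
    shuffles the table's columns once and keeps the first num_samples, so
    every node in the hop uses the same column positions.
    """
    samples = [np.asarray(batch)]
    for info in reversed(layer_infos):
        rows = adj[samples[-1]]                              # (len(prev), max_degree)
        cols = rng.permutation(adj.shape[1])[:info.num_samples]
        samples.append(rows[:, cols].reshape(-1))
    return samples

# A new graph with 6 nodes (index 6 is the dummy row) and max_degree 3.
new_adj = np.array([[1, 2, 2], [0, 3, 0], [0, 4, 5], [1, 1, 1],
                    [2, 2, 2], [2, 2, 2], [6, 6, 6]])
layer_infos = [LayerInfo("layer_1", 3), LayerInfo("layer_2", 2)]
samples = sample_support(new_adj, [0, 3], layer_infos, np.random.default_rng(0))
```

Here a batch of 2 nodes yields 2, 4 and 12 ids per hop; a different batch size changes these counts, but no trained parameter depends on them.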

I have also tried using the trained model and passing it a new adjacency matrix, but I run into errors with the session's graph because the batch size and adjacencies have a different shape. Similarly, I have tried saving the best model during training to load later for inference, but I run into similar problems with the model graph's dimensions.

Many of the errors I am running into are due to the placeholder variables needing to be initialized, along with knowing which graph variables to re-initialize while preserving the trained weights of the net. To address this, I wrote a new method, inference_feed_dict, which works similarly to batch_feed_dict in minibatch.py. inference_feed_dict updates the placeholder variables (batch_size, batch, and labels) and returns them to the inference method. The inference method then assigns the placeholders to variables, for example:
    inf_feed_dict, batch1, labels = minibatch.inference_feed_dict()
    b = tf.Variable(placeholders['batch'], trainable=False, name='b')
    bs = tf.Variable(placeholders['batch_size'], trainable=False, name='bs')
    batch_assign = tf.assign(b, batch1, name='batch')
    batch_size_assign = tf.assign(bs, FLAGS.batch_size, name='batch_size')
and before using incremental_evaluate to infer over the new graph I initialize the variables with:
    sess.run(batch_assign.op)
    sess.run(batch_size_assign.op)
My issue here is providing the right values and initializing correctly; despite my best efforts, I keep running into errors such as:
"tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'batch_size' with dtype int32
[[node batch_size (defined at /home/sam/Documents/PhD/Research/GeoMSCxML/topoml/topoml/graphsage/gnn.py:1282) ]]
[[uniformneighborsampler_1/transpose/_153]]
(1) Invalid argument: You must feed a value for placeholder tensor 'batch_size' with dtype int32
[[node batch_size (defined at /home/sam/Documents/PhD/Research/GeoMSCxML/topoml/topoml/graphsage/gnn.py:1282) ]]
0 successful operations.
0 derived errors ignored."

and similarly with the other variables in the placeholders dictionary.
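In TensorFlow 1.x a placeholder holds no value and cannot be initialized: every sess.run whose fetched ops depend on it, directly or indirectly, must receive that placeholder in feed_dict. Copying a placeholder into a new tf.Variable and assigning to it does not remove that dependency, because the ops that were already built (such as the sampling ops named in the error above) still read the original placeholder. A more direct route is to feed every placeholder on every inference call. The plain-Python sketch below assumes the supervised placeholder keys 'batch', 'batch_size', 'labels' and 'dropout'. `inference_feed_dict` here is a hypothetical helper, not the method described above, and it refuses to build a feed dict if any placeholder is left without a value.

```python
def inference_feed_dict(placeholders, batch_nodes, id2idx, num_classes):
    """Hypothetical helper: map every model placeholder to a value for one batch.

    placeholders: dict key -> placeholder tensor, e.g. the dict used to build
                  the model (plain strings are enough for testing).
    batch_nodes:  raw node ids from the new graph; id2idx maps them to rows
                  of the new graph's adjacency and feature matrices.
    """
    batch = [id2idx[n] for n in batch_nodes]
    values = {
        "batch": batch,
        "batch_size": len(batch),
        # Labels are unknown at inference time; zeros only satisfy ops such
        # as the loss, whose value should be ignored for unlabeled nodes.
        "labels": [[0.0] * num_classes for _ in batch],
        "dropout": 0.0,                   # disable dropout when predicting
    }
    missing = sorted(set(placeholders) - set(values))
    if missing:
        raise KeyError(f"no inference value for placeholders: {missing}")
    return {placeholders[key]: values[key] for key in placeholders}

# With a real TF1 model (assuming it exposes a preds tensor), usage might be:
#   preds = sess.run(model.preds, feed_dict=inference_feed_dict(...))
```

Because the trained variables are restored once and never touched, only the adjacency/feature tensors and this per-batch feed dict vary between the training graph and the new one.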

I look forward to hearing from you, and would be grateful for any response from others who may have some insight.
