
Comments (6)

JasonLLu commented on August 11, 2024

Thanks. It does seem that the issue might be that the model doesn't capture the connections. Then, what exactly does the model attempt to capture? Does this mean that this model is not able to support finding the closest relations given two entities? If my goal is to do that, what can be done?

from pytorch-biggraph.

lw commented on August 11, 2024

I'm closing this, as I feel we have given all the answers we can that pertain to "technical support" of the current PBG code. The remaining questions are research-related, and I don't think we're going to investigate them inside GitHub issues... ;)


lw commented on August 11, 2024

This idea (of querying for the most likely relation between two entities by doing a nearest neighbor search of the difference of the entities' embeddings in the space of relation parameters) has been floated around before (possibly by you, I don't remember...). We never tried it, so it's "unsupported" and I'm not sure it's supposed to work. Let me try to run the math on this specific instance.

If we look at the TransE model, which is probably what you have in mind, it tries to achieve $\theta_s + \theta_r \approx \theta_t$ (s = source, t = target, r = relation), which is equivalent to $\theta_r \approx \theta_t - \theta_s$. This is the property that enables the nearest-neighbor search you wanted. (Training towards making that difference zero doesn't mean it will always be zero, so there may anyway be some noise that interferes with your inference.)
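To make the idea concrete, here is a minimal sketch (illustrative only, not PBG code) of this nearest-neighbor relation search under the TransE assumption: since $\theta_s + \theta_r \approx \theta_t$, we look up $\theta_t - \theta_s$ among the relation embeddings. The array names are hypothetical.

```python
import numpy as np

# Sketch of relation retrieval under a TransE-style model, where
# theta_s + theta_r ~= theta_t, so theta_r ~= theta_t - theta_s.
rng = np.random.default_rng(0)
relations = rng.normal(size=(100, 8))  # one embedding per relation type
emb_s = rng.normal(size=8)
emb_t = emb_s + relations[42]          # target consistent with relation 42

# L2 nearest neighbor of the difference vector in relation space.
dists = np.linalg.norm(relations - (emb_t - emb_s), axis=1)
print(dists.argmin())  # 42 in this noiseless setting
```

With real trained embeddings the difference only approximates the relation vector, so the true relation would merely tend to rank near the top rather than win exactly.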

The config used to produce the Wikidata embeddings is here. We see that, like in TransE, we used the translation operator. However, we also used the dot product comparator and the softmax loss. This means that the model tries to make $\theta_s \cdot (\theta_t + \theta_r)$ as high as possible compared to other candidate target nodes. This is a quite different property than the one for TransE, because it uses the dot product and because it just tries to maximize the score, instead of trying to have it match a certain value.

Although intuitively embeddings that are close in the L2 norm will have a higher dot product, this is not an equivalence, so the property you're looking for in order to do the NN search does not necessarily hold when using the dot product. This may explain why you're not getting the results you're expecting.
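A toy example (unrelated to the PBG codebase) makes the non-equivalence concrete: rescaling a vector raises its dot product with a query while moving it further away in L2.

```python
import numpy as np

# Two candidates for a query q: `a` is nearest in L2, but `b` wins on
# dot product simply because it has a larger norm.
q = np.array([1.0, 0.0])
a = np.array([1.0, 0.1])   # close to q in L2, dot product 1.0
b = np.array([10.0, 0.0])  # far from q in L2, dot product 10.0

print(np.linalg.norm(q - a) < np.linalg.norm(q - b))  # True: a is closer
print(q @ a < q @ b)                                  # True: b scores higher
```

So a nearest-neighbor search in L2 and a maximum-dot-product search can return different winners over the same set of candidates.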


adamlerer commented on August 11, 2024

@lerks is right, and in the long run I think you will need to train with a different objective function that directly minimizes $d(\theta_r, f(\theta_s, \theta_d))$ in some distance metric. But I expect that the way you did it should work reasonably well. How do you know that the results are bad?
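As a hedged sketch of the objective described here, one could take $f$ to be the TransE-style difference $\theta_t - \theta_s$ and $d$ to be squared L2, then minimize by gradient descent. All names below are illustrative; this is not how PBG trains.

```python
import numpy as np

# Directly minimize d(theta_r, f(theta_s, theta_t)) with
# f(s, t) = t - s and d = squared L2 distance.
def loss_and_grads(theta_s, theta_r, theta_t):
    resid = theta_r - (theta_t - theta_s)  # theta_r - f(theta_s, theta_t)
    loss = float(resid @ resid)
    # Gradients of ||resid||^2 w.r.t. each embedding.
    return loss, {"s": 2 * resid, "r": 2 * resid, "t": -2 * resid}

rng = np.random.default_rng(0)
s, r, t = rng.normal(size=(3, 8))
lr = 0.1
for _ in range(200):
    loss, g = loss_and_grads(s, r, t)
    s, r, t = s - lr * g["s"], r - lr * g["r"], t - lr * g["t"]
print(round(loss, 6))  # driven toward 0
```

A real training run would of course also need negative samples so that the embeddings don't collapse; this only shows the positive-edge term of such an objective.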


lw commented on August 11, 2024

As an experiment I tried to see which (source, ?, target) triple had the highest score over all the possible ?, using the correct score calculation method. This is somewhat similar to using a flat index in FAISS, except that it works with dot product too. Here is my code (which I ran after copy-pasting your example):

In [3]: emb1 = Constants.ENTITIES[id_to_index[key1]]

In [4]: emb2 = Constants.ENTITIES[id_to_index[key2]]

In [5]: scores = emb1.reshape((1, 200)) @ (Constants.RELATIONS + emb2).transpose()

In [6]: np.argpartition(scores.flatten(), -5)[-5:]
Out[6]: array([4839, 4819, 5323, 7439, 7698])

In [7]: index_to_relation[7698]
Out[7]: '<http://www.wikidata.org/prop/direct/P2187>_reverse_relation'

In [8]: index_to_relation[7439]
Out[8]: '<http://www.wikidata.org/prop/direct/P5332>_reverse_relation'

In [9]: index_to_relation[5323]
Out[9]: '<http://www.wikidata.org/prop/direct/P1555>_reverse_relation'

In [10]: index_to_relation[4819]
Out[10]: '<http://www.wikidata.org/prop/direct/P1971>_reverse_relation'

In [11]: index_to_relation[4839]
Out[11]: '<http://www.wikidata.org/prop/direct/P1050>_reverse_relation'

This also doesn't seem to give meaningful results. Perhaps the issue is simply that the model isn't fine-tuned enough to capture this weak connection.

P.S.: It could also just be that my code is buggy...


lw commented on August 11, 2024

The Wikidata embeddings were trained with a ranking loss, so the training process tried to make sure that a given (source, rel_type, target) triple had a higher score than all other (source, rel_type, ?) triples. If we look into that for your example, this is what we get (where http://www.wikidata.org/prop/direct/P1365 is the "replaces" relation, which is the only edge that connects the two entities you used in your example):

In [4]: rel = id_to_index['<http://www.wikidata.org/prop/direct/P1365>']

In [5]: emb1 = Constants.ENTITIES[id_to_index[key1]]

In [6]: scores = emb1.reshape((1, 200)) @ (Constants.ENTITIES + Constants.RELATIONS[rel]).transpose()

In [7]: ranks = np.argsort(np.argsort(scores.flatten()))

In [8]: ranks[id_to_index[key2]]
Out[8]: 78404878

Since we have 78404883 entities in total, this means that the true target is ranked as the 5th most likely target among all (source, P1365, ?) candidates. That seems like a pretty good result to me.
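For readers unfamiliar with the double-argsort trick used in `In [7]` above, here is a standalone illustration (toy data, not the Wikidata embeddings): `argsort(argsort(x))` gives each element's rank in ascending order.

```python
import numpy as np

# argsort(x) gives the indices that would sort x; applying argsort again
# inverts that permutation, yielding each element's ascending rank.
x = np.array([0.3, 2.0, -1.0, 0.7])
ranks = np.argsort(np.argsort(x))
print(ranks)  # [1 3 0 2]: -1.0 has rank 0, 2.0 has rank 3
```

So a rank of 78404878 out of 78404883 scores means only four candidates scored higher.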

So, to answer your other question: to get good results on relation-type retrieval, you should probably implement and use a loss function that ranks relation types rather than target entities. To be honest, I'm not sure this can be done properly without other changes to the training and negative sampling process, because, as I mentioned, this isn't a use case we really had in mind when designing the system.
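One way such a loss could look, as a hedged sketch and not part of PBG: for an observed (s, r, t) edge, treat every other relation type as a negative and apply a softmax cross-entropy over relation scores. All names are hypothetical.

```python
import numpy as np

# Softmax ranking loss over relation types: score each candidate relation
# with the translation-then-dot-product form used in the thread,
# theta_s . (theta_t + theta_r), and penalize the true relation's
# negative log-probability.
def relation_softmax_loss(theta_s, relations, theta_t, true_rel):
    scores = theta_s @ (theta_t + relations).T  # shape: (num_relations,)
    scores = scores - scores.max()              # numerical stability
    log_probs = scores - np.log(np.exp(scores).sum())
    return -log_probs[true_rel]                 # cross-entropy

rng = np.random.default_rng(0)
relations = rng.normal(size=(50, 8))
s, t = rng.normal(size=(2, 8))
print(relation_softmax_loss(s, relations, t, true_rel=3) > 0)  # True
```

Minimizing this would push the true relation's score above the other 49 candidates for that (source, target) pair, which is exactly the ranking property the relation-retrieval query needs.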

