xiaoxinhe / g-retriever
Repository for G-Retriever
Home Page: https://arxiv.org/abs/2402.07630
License: MIT License
Thanks for your great work!
However, I have some problems reproducing the experiments on the WebQSP dataset. I use the same commands as in the README to process the data and train the model, but the test accuracy is only 0.605, which is much lower than the 0.6732 ± 0.0076 reported in the paper.
My command is as follows:
python train.py --dataset webqsp --model_name graph_llm
The results are as follows:
The val loss and train loss are as follows:
One possible cause could be the cached graphs. The paper states that the average number of tokens in WebQSP is 18, whereas in my cached graphs it is only 8.1. Could this difference be the problem?
original graphs
Hello, I am very interested in your work, and I am running your project on other knowledge graph datasets. May I ask how the ExplaGraphs, SceneGraphs, and WebQSP datasets you provide differ from regular knowledge graphs? Is my idea feasible? Looking forward to your reply.
I just found that there are only 74,942 graphs in train_sceneGraphs.json and 10,696 graphs in val_sceneGraphs.json, for a total of 85,638 graphs in scene_graphs rather than the 100,000 reported. Only train_sceneGraphs.json is used in src.dataset.preprocess.scene_graphs, but the split is for 100,000 samples.
Hello, thank you for the great work. I am attempting to fine-tune the G_retriever framework on a custom question-answering dataset using a GPU with 16GB of memory. However, I encounter out-of-memory errors during the training process.
Is there a recommended minimum GPU memory requirement for fine-tuning G_retriever on a 14.3 MB dataset?
Are there additional optimization strategies or code modifications that could enable effective fine-tuning on a 16GB GPU?
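Not an official answer, but a common way to fit fine-tuning into 16 GB is to shrink the micro-batch size and compensate with gradient accumulation (optionally combined with gradient checkpointing). A minimal PyTorch sketch of the accumulation pattern, using a toy model in place of the actual G-Retriever trainer:

```python
import torch
from torch import nn

# Toy stand-in for the real model; the point is the accumulation pattern,
# which keeps per-step activation memory at micro-batch size while the
# effective batch is micro_batch * accum_steps.
model = nn.Linear(16, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 4

data = [(torch.randn(2, 16), torch.randn(2, 1)) for _ in range(8)]
opt.zero_grad()
for step, (x, y) in enumerate(data):
    # Scale the loss so accumulated gradients average over the effective batch.
    loss = nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        opt.step()
        opt.zero_grad()
```

Lowering the batch size in the repo's config and applying this pattern (plus 8-bit optimizers or LoRA, if you are willing to modify the code) is usually enough to trade speed for memory.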
Hi. Thanks for an interesting repo.
I have a question regarding the usage of LLMs. Is the repo agnostic to any LLMs or only made for llama? More specifically, does the code work correctly if I use another LLM, such as Mixtral by replacing “llama2/llama2_7b_chat_hf” with “mistralai/Mixtral-8x7B-Instruct-v0.1”? If not, how can I change the code for using another LLM? Thanks.
Hello, I am working on something quite similar, and one thing confuses me: when building the text-attributed graph, does the choice of embedding model matter? I see that a projection is used to align the GNN output with the semantic space of the base model (llama-7B). My question is: if the embedding model used to embed the graph nodes has a semantic space very different from that of the later base model, would that badly hurt performance?
When I run python -m src.dataset.webqsp, I encounter the following error:
0%| | 0/4700 [00:00<?, ?it/s]
0%| | 1/4700 [00:00<56:06, 1.40it/s]
0%| | 2/4700 [00:00<32:35, 2.40it/s]
0%| | 3/4700 [00:01<21:58, 3.56it/s]
0%| | 3/4700 [00:01<28:47, 2.72it/s]
Traceback (most recent call last):
File "/hpc_stor03/sjtu_home/-/anaconda3/envs/gretriever/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/hpc_stor03/sjtu_home/-/anaconda3/envs/gretriever/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/hpc_stor/hpc_stor03/sjtu_home/-/remote/G-Retriever/src/dataset/webqsp.py", line 82, in <module>
preprocess()
File "/hpc_stor/hpc_stor03/sjtu_home/-/remote/G-Retriever/src/dataset/webqsp.py", line 75, in preprocess
subg, desc = retrieval_via_pcst(graph, q_emb, nodes, edges, topk=3, topk_e=5, cost_e=0.5)
File "/hpc_stor/hpc_stor03/sjtu_home/-/remote/G-Retriever/src/dataset/utils/retrieval.py", line 21, in retrieval_via_pcst
_, topk_n_indices = torch.topk(n_prizes, topk, largest=True)
RuntimeError: selected index k out of range
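For what it's worth, this error usually means torch.topk was asked for more elements than the prize tensor contains (e.g. a retrieved graph with fewer than topk nodes). A hedged sketch of the usual fix, clamping k to the tensor size; safe_topk is a hypothetical helper, not a function from the repo:

```python
import torch

def safe_topk(prizes: torch.Tensor, k: int):
    # torch.topk raises "selected index k out of range" when k exceeds the
    # number of elements, so clamp k to the tensor size first.
    k = min(k, prizes.numel())
    return torch.topk(prizes, k, largest=True)

n_prizes = torch.tensor([0.5, 2.0])       # only 2 nodes, but topk=3 requested
values, indices = safe_topk(n_prizes, 3)  # k is clamped to 2
```

Guarding the topk call sites in retrieval_via_pcst this way (or skipping graphs with fewer than topk nodes) should let the preprocessing run to completion.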
Hi, thanks for the great research.
Could I get a requirements.txt file that includes the versions of libraries such as accelerate and transformers?
When using the latest versions of these libraries, there is an issue where the value of shift_labels changes to random numbers at the line shift_labels = shift_labels.to(shift_logits.device) within the LlamaForCausalLM class.
I think this problem is caused by the device_map='auto' part.
In the main results table (6.2), there are baseline results for Zero-shot, Zero-CoT, CoT-BAG, KAPING, and GraphToken. It is unclear exactly where these results came from: the numbers in this table do not appear in the respective papers, and these methods do not seem to be implemented or invoked in the run.sh script. Can you explain how these results were obtained?
Dear authors,
I am encountering exactly the same issue as "RuntimeError: selected index k out of range #10". Could the authors double-check this error by running the preprocessing code from scratch instead of using the processed dataset?
Thanks~
Hi,
Am I misunderstanding something, or is there a bug in the highlighted code?
Line 111 in 565cd9d
Why is label_input_ids part of input_ids?
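Not speaking for the authors, but this is the standard layout for fine-tuning decoder-only (causal) LMs: the answer tokens are appended to the prompt so the model predicts them autoregressively, and the loss is masked out on the prompt portion via the ignore index. An illustrative sketch with made-up token IDs, not the repo's exact code:

```python
# Causal-LM fine-tuning layout (illustrative token IDs).
IGNORE_INDEX = -100  # positions with this label are skipped by cross-entropy

prompt_input_ids = [101, 2009, 2003]  # question / graph-description tokens
label_input_ids = [42, 43, 2]         # answer tokens followed by EOS

# The model must see the answer tokens as inputs in order to learn to
# generate them, which is why label_input_ids are concatenated into input_ids.
input_ids = prompt_input_ids + label_input_ids
labels = [IGNORE_INDEX] * len(prompt_input_ids) + label_input_ids
```

So including label_input_ids in input_ids is intentional: the -100 mask on the prompt positions ensures only the answer tokens contribute to the loss.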
Your work is amazing!
However, I am facing issues reproducing the experiments on the WebQSP data. I am trying to replicate the project by following the README in an AWS Jupyter notebook. The pre-processing of the data is taking a very long time: the estimated time for graph encoding is shown as 1330 hours. Is it supposed to take this long?
Hello, thank you very much for your work. I have some questions that I hope you can answer.
Your paper says the batch size used in the experiments is 4, but the batch size in the code is 8, and the results I get with 8 are closer to those shown in the paper, although there are still some gaps on the SceneGraphs dataset.