lavis-nlp / spert
PyTorch code for SpERT: Span-based Entity and Relation Transformer
License: MIT License
Dear @markus-eberts,
Thanks for your work, it is a great job.
I am having trouble interpreting the output of the Evaluator.eval_batch method and of the model, which are:
result = model(...)
entity_clf, rel_clf, rels = result
and
print(evaluator._pred_entities)
[[(1, 2, <spert.entities.EntityType at 0x15b817bf6d0>, 0.5499340891838074),
(2, 3, <spert.entities.EntityType at 0x15b817bf880>, 0.43182992935180664),
(3, 4, <spert.entities.EntityType at 0x15b817bf6d0>, 0.6427331566810608),
(4, 5, <spert.entities.EntityType at 0x15b817bf6d0>, 0.4554755389690399),..]
Could you explain what each dimension of the tensors entity_clf, rel_clf and rels corresponds to, and how to read the second output (in particular, whether its indices refer to raw token indices or BERT token indices)?
Thanks in advance!
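A minimal sketch for turning the printed evaluator._pred_entities entries into readable records, assuming each tuple is (start, end, entity_type, score) with an exclusive end (spans such as (1, 2, ...) suggest single-token entities). FakeEntityType is a hypothetical stand-in for spert.entities.EntityType, and its identifier attribute is an assumption about where the type name lives. Whether the indices are word-token or BERT sub-token positions is exactly the open question above, so the sketch only requires that the token list passed in uses the same index space as the spans.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class FakeEntityType:
    """Hypothetical stand-in for spert.entities.EntityType; only a name is kept."""
    identifier: str


def decode_entities(pred: Sequence[Tuple[int, int, FakeEntityType, float]],
                    tokens: Sequence[str],
                    min_score: float = 0.0) -> List[dict]:
    """Turn (start, end, type, score) tuples into readable records.

    Assumes `end` is exclusive and that `tokens` uses the same index space
    as the spans (word tokens or BERT sub-tokens, whichever applies).
    """
    records = []
    for start, end, etype, score in pred:
        if score < min_score:
            continue  # drop low-confidence spans
        records.append({
            "text": " ".join(tokens[start:end]),
            "start": start,
            "end": end,
            "type": etype.identifier,
            "score": round(score, 3),
        })
    return records


# Example with made-up tokens and scores
peop, loc = FakeEntityType("Peop"), FakeEntityType("Loc")
tokens = ["Yesterday", "John", "flew", "to", "Berlin", "."]
preds = [(1, 2, peop, 0.91), (4, 5, loc, 0.88), (2, 3, loc, 0.12)]
for rec in decode_entities(preds, tokens, min_score=0.5):
    print(rec)
```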
Dear authors,
First of all, thank you so much for your well-done work. The code is really well written, and I was able to reproduce the results of the paper and also work on my own data.
I was wondering whether it is possible to switch on some kind of early stopping mechanism, so that training stops only when the performance on the development set gets significantly worse than on the training set. Is there a particular input argument I should specify for this, or does it need to be implemented?
Thanks so much for your time, and best,
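The example config on this page shows no early-stopping option, so here is a generic, framework-independent sketch rather than anything SpERT provides. It uses the common criterion "stop once the dev score has not improved for `patience` evaluations" instead of a train/dev gap; the class name and the patience and min_delta parameters are hypothetical. Calling step() after each per-epoch dev evaluation, and keeping the checkpoint from best_epoch, is one way it could be wired into a training loop.

```python
class EarlyStopping:
    """Hypothetical helper (not part of SpERT): stop when the dev score has not
    improved by more than `min_delta` for `patience` consecutive evaluations."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch: int, dev_score: float) -> bool:
        """Record a dev score (higher is better); return True if training should stop."""
        if dev_score > self.best + self.min_delta:
            self.best = dev_score
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up dev F1 scores per epoch
stopper = EarlyStopping(patience=2)
for epoch, f1 in enumerate([60.1, 65.3, 64.9, 65.0, 63.2]):
    if stopper.step(epoch, f1):
        print(f"stop at epoch {epoch}; best epoch {stopper.best_epoch} ({stopper.best})")
        break
```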
Hi,
Thanks for sharing such wonderful work. I am studying the code and found what may be a problem with the relation classification.
The relation classification labels you construct have the following format:
[ [0,1,0,0], [0,0,1,0] ... ]
If a dataset contains four relation types in total, namely [A,B,C,D], then the first row [0,1,0,0] indicates that relation B holds between the corresponding entity pair, and the second row [0,0,1,0] indicates that relation C holds between its entity pair.
This seems fine, but the first and second rows may belong to the same entity pair. In that case, the label should have the following form:
[ [0,1,1,0], ... ]
In these two cases the relation classification loss differs, and the first one is wrong.
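To make the two labellings concrete, here is a small pure-Python sketch (not SpERT's code): multi_hot_labels merges (entity pair, relation type) annotations into one multi-hot row per pair, and bce computes the mean binary cross-entropy of a made-up probability vector against either labelling. Whether SpERT's sampler actually emits duplicate rows for the same pair is not verified here.

```python
import math
from typing import Dict, List, Sequence, Tuple

Pair = Tuple[int, int]  # (head entity index, tail entity index)


def multi_hot_labels(rels: Sequence[Tuple[Pair, int]], num_types: int) -> Dict[Pair, List[int]]:
    """Merge (pair, type_index) annotations into one multi-hot row per entity pair."""
    rows: Dict[Pair, List[int]] = {}
    for pair, t in rels:
        row = rows.setdefault(pair, [0] * num_types)
        row[t] = 1
    return rows


def bce(probs: Sequence[float], target: Sequence[int]) -> float:
    """Mean binary cross-entropy over relation types (probabilities strictly in (0, 1))."""
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for p, y in zip(probs, target)) / len(target)


rels = [((0, 1), 1), ((0, 1), 2)]   # pair (0, 1) holds both B and C
print(multi_hot_labels(rels, 4))    # {(0, 1): [0, 1, 1, 0]}

probs = [0.1, 0.8, 0.7, 0.1]        # made-up model output for pair (0, 1)
print(bce(probs, [0, 1, 0, 0]), bce(probs, [0, 1, 1, 0]))  # one-hot row vs merged row
```

With the one-hot row [0,1,0,0], a correct high score for C is penalized as if C were absent; the merged row does not do that.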
Thanks for adding a prediction mode.
I was just wondering where I can find the conll04_predictions.json file?
Thanks!
Chloe
The config file I use:
[1]
label = conll04_train
model_type = spert
#model_path = bert-base-cased
#tokenizer_path = bert-base-cased
model_path = data/models/conll04
tokenizer_path = data/models/conll04
train_path = data/datasets/conll04/conll04_train.json
valid_path = data/datasets/conll04/conll04_dev.json
types_path = data/datasets/conll04/conll04_types.json
train_batch_size = 2
eval_batch_size = 1
neg_entity_count = 100
neg_relation_count = 100
epochs = 20
lr = 5e-5
lr_warmup = 0.1
weight_decay = 0.01
max_grad_norm = 1.0
rel_filter_threshold = 0.4
size_embedding = 25
prop_drop = 0.1
max_span_size = 10
store_predictions = true
store_examples = true
sampling_processes = 4
sampling_limit = 100
max_pairs = 1000
final_eval = true
log_path = data/log/
save_path = data/save/
I only changed model_path and tokenizer_path. When I ran python spert.py train --config configs/example_train.conf, it got stuck at the data loader (see the comments in the code):
# in spert_trainer.py _train_epoch
for batch in tqdm(data_loader, total=total, desc='Train epoch %s' % epoch): ##### stuck at this line
    model.train() # never reached here
    batch = util.to_device(batch, self._device)
    # forward step
    entity_logits, rel_logits = model(encodings=batch['encodings'], context_masks=batch['context_masks'],
                                      entity_masks=batch['entity_masks'], entity_sizes=batch['entity_sizes'],
                                      relations=batch['rels'], rel_masks=batch['rel_masks'])
    # compute loss and optimize parameters
    batch_loss = compute_loss.compute(entity_logits=entity_logits, rel_logits=rel_logits,
                                      rel_types=batch['rel_types'], entity_types=batch['entity_types'],
                                      entity_sample_masks=batch['entity_sample_masks'],
                                      rel_sample_masks=batch['rel_sample_masks'])
But when I changed sampling_processes to 0, it worked, though slowly. Why did execution get stuck with sampling_processes=4? I am training the model on a CPU, if that matters.
Should I use the checkpoint folder as the "model_path" value in the config files?
Process SpawnProcess-1:
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "//spert/spert/sampling.py", line 253, in _produce_eval_batch
sample = _create_eval_sample(d, max_span_size, context_size)
File "/*/spert/spert/sampling.py", line 403, in _create_eval_sample
entity_masks = torch.stack(entity_masks)
RuntimeError: stack expects a non-empty TensorList
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/media/work/aads_nlp/spert/spert.py", line 12, in __train
types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)
File "//spert/spert/spert_trainer.py", line 115, in train
self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
File "//spert/spert/spert_trainer.py", line 231, in _eval
for batch in tqdm(sampler, total=total, desc='Evaluate epoch %s' % epoch):
File "//spert/env/lib/python3.6/site-packages/tqdm/std.py", line 1104, in __iter__
for obj in iterable:
File "/*/spert/spert/sampling.py", line 155, in __next__
batch, _ = self._results.next()
File "/usr/lib/python3.6/multiprocessing/pool.py", line 735, in next
raise value
RuntimeError: stack expects a non-empty TensorList
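The traceback shows torch.stack(entity_masks) failing in _create_eval_sample because the list of span masks is empty. One plausible cause, which is an assumption not confirmed in this thread, is a document without usable tokens, for which no candidate spans can be enumerated. The sketch below checks a dataset for such documents, assuming the file is a JSON list of documents with a "tokens" field (the layout of the sample document quoted later on this page); split_empty_documents and clean_dataset are hypothetical helpers.

```python
import json
from typing import List, Tuple


def split_empty_documents(docs: List[dict]) -> Tuple[List[dict], List[int]]:
    """Separate documents without usable tokens; return (kept_docs, dropped_indices).

    A document counts as empty if it has no tokens or only whitespace tokens
    such as "\r\n" (the whitespace rule is an extra assumption).
    """
    kept, dropped = [], []
    for i, doc in enumerate(docs):
        if any(tok.strip() for tok in doc.get("tokens", [])):
            kept.append(doc)
        else:
            dropped.append(i)
    return kept, dropped


def clean_dataset(in_path: str, out_path: str) -> List[int]:
    """Write a copy of the dataset without empty documents; return dropped indices."""
    with open(in_path, encoding="utf-8") as f:
        docs = json.load(f)
    kept, dropped = split_empty_documents(docs)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(kept, f, ensure_ascii=False)
    return dropped
```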
Thank you for all your work and for publishing it! Really nice work, I love it.
I just wanted to ask how I could run prediction with the trained models (specifically ADE and CoNLL04), and what the input data should look like (data format).
Thank you so much in advance :)
no need to reply.
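A minimal sketch for preparing input, assuming prediction accepts documents in the same {"tokens": [...]} shape as the training data (see the sample document quoted further down this page) with the entity and relation annotations left out; the exact formats accepted by the predict mode should be checked against the repository README. The regex tokenizer is a deliberately naive stand-in (a proper tokenizer such as spaCy would be preferable), and to_spert_docs / write_prediction_input are hypothetical names.

```python
import json
import re
from typing import List

# Naive tokenizer: runs of word characters, or single punctuation marks.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]")


def to_spert_docs(sentences: List[str]) -> List[dict]:
    """Wrap raw sentences as {"tokens": [...]} documents without annotations."""
    return [{"tokens": _TOKEN_RE.findall(s)} for s in sentences]


def write_prediction_input(sentences: List[str], path: str) -> None:
    """Write the documents as a JSON list, one dict per sentence."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(to_spert_docs(sentences), f, ensure_ascii=False)


print(to_spert_docs(["John lives in Berlin."]))
# [{'tokens': ['John', 'lives', 'in', 'Berlin', '.']}]
```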
The program reports two relation extraction scores, one with NER and one without NER. Could you tell me what they mean? Thanks
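The evaluation log further down this page spells out the two criteria: without NEC, a relation counts as correct if its type and the spans of both entities match; with NEC, the entity types must match as well. A pure-Python sketch of that matching logic, using made-up tuples rather than SpERT's internal data structures:

```python
from typing import Set, Tuple

Entity = Tuple[int, int, str]          # (start, end, entity_type)
Relation = Tuple[Entity, Entity, str]  # (head, tail, relation_type)


def _strip_entity_types(rels: Set[Relation]) -> Set[tuple]:
    """Drop entity types, keeping both spans and the relation type."""
    return {((h[0], h[1]), (t[0], t[1]), r) for h, t, r in rels}


def relation_hits(gt: Set[Relation], pred: Set[Relation], with_nec: bool) -> int:
    """Count predicted relations that match the ground truth.

    with_nec=True:  spans, entity types and relation type must all match.
    with_nec=False: only the spans and the relation type must match.
    """
    if with_nec:
        return len(gt & pred)
    return len(_strip_entity_types(gt) & _strip_entity_types(pred))


gt = {((0, 1, "Peop"), (3, 4, "Loc"), "Live")}
pred = {((0, 1, "Org"), (3, 4, "Loc"), "Live")}  # head span right, head type wrong
print(relation_hits(gt, pred, with_nec=False), relation_hits(gt, pred, with_nec=True))  # 1 0
```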
Hi,
I have a question about how to add weights to the loss function,
for example:
L = w1*Ls + w2*Lr
Thanks
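A minimal sketch of the weighting itself, with hypothetical names. It is written for plain floats, but the same expression works unchanged on scalar PyTorch loss tensors, so gradients flow through both terms; where exactly to apply it (for example at the joint-loss step in spert/loss.py) is an assumption. With both weights at 1.0 it reduces to the plain sum of the two losses.

```python
from dataclasses import dataclass


@dataclass
class WeightedJointLoss:
    """Hypothetical weighted combination L = w1 * Ls + w2 * Lr."""
    w_entity: float = 1.0  # w1, weight of the span/entity loss Ls
    w_rel: float = 1.0     # w2, weight of the relation loss Lr

    def __call__(self, entity_loss, rel_loss):
        # Works for floats and for scalar tensors alike
        return self.w_entity * entity_loss + self.w_rel * rel_loss


joint = WeightedJointLoss(w_entity=0.5, w_rel=2.0)
print(joint(2.0, 3.0))  # 7.0
```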
Hello, I'm reading the original code and I find that the length of the encodings differs between samples. The BERT model needs an input of constant length, so I want to know where the encodings are padded or truncated.
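A common approach, sketched here in pure Python as an assumption rather than a copy of SpERT's batching code, is to right-pad every encoding in a batch to the longest one and build a context mask (1 for real tokens, 0 for padding) so that attention ignores the padding. BERT only needs equal lengths within a batch; sequences beyond BERT's 512-position limit would additionally need truncation or splitting. pad_batch is a hypothetical helper.

```python
from typing import List, Tuple


def pad_batch(encodings: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id sequences to the longest one in the batch.

    Returns (padded_ids, context_masks): a mask entry is 1 for a real token
    and 0 for padding.
    """
    if not encodings:
        return [], []
    max_len = max(len(e) for e in encodings)
    padded, masks = [], []
    for e in encodings:
        n_pad = max_len - len(e)
        padded.append(e + [pad_id] * n_pad)
        masks.append([1] * len(e) + [0] * n_pad)
    return padded, masks


ids, masks = pad_batch([[101, 7, 102], [101, 5, 6, 8, 102]])
print(ids)    # [[101, 7, 102, 0, 0], [101, 5, 6, 8, 102]]
print(masks)  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```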
Hi,
May I ask why you use multiple sigmoids for the relation classifier? Why not use a softmax layer, as for the span classifier?
I have checked that none of the three datasets contains multiple relations between two given entities (there are only very few duplicated relations in SciERC).
Best,
Enwei
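A pure-Python sketch contrasting the two decoding schemes for a single entity pair. The sigmoid variant scores each real relation type independently against a threshold (0.4 mirrors rel_filter_threshold in the config on this page) and can return several types; the softmax variant includes an explicit "no relation" class at index 0 and returns at most one type. The index conventions and function names are assumptions for illustration.

```python
import math
from typing import List, Sequence


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def sigmoid_decode(logits: Sequence[float], threshold: float = 0.4) -> List[int]:
    """Multi-label: indices over real relation types whose sigmoid passes the threshold."""
    return [i for i, z in enumerate(logits) if _sigmoid(z) >= threshold]


def softmax_decode(logits: Sequence[float], none_index: int = 0) -> List[int]:
    """Single-label: the argmax class, or nothing if it is the 'no relation' class."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # shift by the max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return [] if best == none_index else [best]


# Made-up logits for one entity pair and three relation types A, B, C
print(sigmoid_decode([2.0, 1.5, -3.0]))       # [0, 1] -> A and B
print(softmax_decode([0.0, 2.0, 1.5, -3.0]))  # [1]    -> A (index 0 is 'none')
```

The sigmoid head can express several relations per pair at the cost of a threshold to tune; the softmax head enforces at most one relation per pair.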
Hello, I have been reading your code and trying to reproduce the results.
To my knowledge, we train on the training set and evaluate on the dev set to select hyperparameters.
After finding the hyperparameters with the best performance on the dev set, we re-train the model on the train+dev set and evaluate on the test set.
However, I am not sure when to stop the final training. Do we use the number of epochs at which the model performed best on the dev set, or do we directly choose the best score on the test set?
Thanks!
I had everything working on macOS and then moved the folder over to Ubuntu. I get the following error when I try to run python ./spert.py eval --config configs/example_eval.conf:
torch_shm_manager: error while loading shared libraries: libcudart.so.10.1: cannot open shared object file: No such file or directory
Process SpawnProcess-1:
Traceback (most recent call last):
File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/home/greg/spert/spert.py", line 23, in __eval
input_reader_cls=input_reader.JsonInputReader)
File "/home/greg/spert/spert/spert_trainer.py", line 165, in eval
self._eval(model, input_reader.get_dataset(dataset_label), input_reader)
File "/home/greg/spert/spert/spert_trainer.py", line 233, in _eval
for batch in tqdm(sampler, total=total, desc='Evaluate epoch %s' % epoch):
File "/home/greg/.local/lib/python3.6/site-packages/tqdm/std.py", line 1104, in __iter__
for obj in iterable:
File "/home/greg/spert/spert/sampling.py", line 155, in __next__
batch, _ = self._results.next()
File "/usr/lib/python3.6/multiprocessing/pool.py", line 735, in next
raise value
multiprocessing.pool.MaybeEncodingError: Error sending result: '(<spert.sampling.EvalTensorBatch object at 0x7f89949bfa20>, 0)'. Reason: 'RuntimeError('error executing torch_shm_manager at "/usr/local/lib/python3.6/dist-packages/torch/bin/torch_shm_manager" at /pytorch/torch/lib/libshm/core.cpp:99',)
When using store_examples = true, I get all the HTML output in the log directory. While this is very nice for visualization, it is not convenient for using the results in other tasks or in later stages of a custom pipeline. Is there currently an option to get the list of predicted entities and relations, with their start and end tokens (or character offsets), in a text file?
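The config on this page sets store_predictions = true. Assuming the stored predictions file is a JSON list of documents in the same layout as the training data (tokens, entities with type/start/end token indices, relations with type/head/tail entity indices), a short script can flatten it into a tab-separated text file. The file layout and the helper names below are assumptions; check them against an actual predictions file.

```python
import json
from typing import Iterable, List


def predictions_to_rows(docs: Iterable[dict]) -> List[str]:
    """Flatten documents into tab-separated lines:

    ENT <doc> <start> <end> <type> <text>
    REL <doc> <type> <head text> <tail text>
    """
    rows = []
    for d, doc in enumerate(docs):
        tokens = doc["tokens"]
        ents = doc.get("entities", [])
        for e in ents:
            text = " ".join(tokens[e["start"]:e["end"]])
            rows.append(f"ENT\t{d}\t{e['start']}\t{e['end']}\t{e['type']}\t{text}")
        for r in doc.get("relations", []):
            head, tail = ents[r["head"]], ents[r["tail"]]
            head_text = " ".join(tokens[head["start"]:head["end"]])
            tail_text = " ".join(tokens[tail["start"]:tail["end"]])
            rows.append(f"REL\t{d}\t{r['type']}\t{head_text}\t{tail_text}")
    return rows


def export_tsv(pred_path: str, out_path: str) -> None:
    """Read a predictions JSON file and write the flattened rows."""
    with open(pred_path, encoding="utf-8") as f:
        docs = json.load(f)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("\n".join(predictions_to_rows(docs)) + "\n")
```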
Dear @markus-eberts,
If possible, please provide your source code so that I can reproduce the experimental results. Thank you very much.
I have a dataset without relation labels, but I would like to test your approach anyway (would you recommend it?). However, an error related to 0-sized tensors occurs when training without relation data. Have you also encountered this issue?
Can a candidate pair have more than one relation?
I ran the program on my dataset and got the error below. According to the error, something goes wrong with the unzip call, so I printed all the random samples and the unzip result; however, I cannot find any error. Do you have any hints? Thanks
random samples: [((5, 8), 1), ((3, 9), 4), ((8, 9), 1), ((4, 5), 1), ((1, 9), 6), ((2, 5), 3), ((2, 4), 2), ((4, 9), 3), ((2, 8), 4), ((3, 4), 1), ((2, 9), 5), ((2, 3), 1), ((1, 8), 5), ((5, 9), 2), ((3, 5), 2), ((1, 3), 2), ((1, 5), 4), ((3, 8), 3), ((1, 2), 1), ((1, 4), 3)]
unzip result ((5, 8), (3, 9), (8, 9), (4, 5), (1, 9), (2, 5), (2, 4), (4, 9), (2, 8), (3, 4), (2, 9), (2, 3), (1, 8), (5, 9), (3, 5), (1, 3), (1, 5), (3, 8), (1, 2), (1, 4)) (1, 4, 1, 1, 6, 3, 2, 3, 4, 1, 5, 1, 5, 2, 2, 2, 4, 3, 1, 3)
Train epoch 0: 0%|▎ | 13/3980 [00:04<22:55, 2.88it/s]/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
len(cache))
/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
len(cache))
/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
len(cache))
/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
len(cache))
Process SpawnProcess-1:
/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
len(cache))
multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
File "/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/pool.py", line 119, in worker
result = (True, func(*args, **kwds))
File "/lrlhps/users/c272987/spert/spert/sampling.py", line 237, in _produce_train_batch
sample = _create_train_sample(d, neg_entity_count, neg_rel_count, max_span_size, context_size)
File "/lrlhps/users/c272987/spert/spert/sampling.py", line 296, in _create_train_sample
neg_entity_spans, neg_entity_sizes = zip(*random_samples)
ValueError: not enough values to unpack (expected 2, got 0)
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/process.py", line 93, in run
self._target(*self._args, **self._kwargs)
File "/lrlhps/users/c272987/spert/spert.py", line 12, in __train
types_path=run_args.types_path, input_reader_cls=input_reader.JsonInputReader)
File "/lrlhps/users/c272987/spert/spert/spert_trainer.py", line 111, in train
input_reader.context_size, input_reader.relation_type_count)
File "/lrlhps/users/c272987/spert/spert/spert_trainer.py", line 182, in _train_epoch
for batch in tqdm(sampler, total=total, desc='Train epoch %s' % epoch):
File "/lrlhps/users/c272987/spert/env/lib/python3.6/site-packages/tqdm/_tqdm.py", line 955, in __iter__
for obj in iterable:
File "/lrlhps/users/c272987/spert/spert/sampling.py", line 155, in __next__
batch, _ = self._results.next()
File "/lrlhps/apps/python/python-3.6.5/lib/python3.6/multiprocessing/pool.py", line 735, in next
raise value
ValueError: not enough values to unpack (expected 2, got 0)
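The failing line is neg_entity_spans, neg_entity_sizes = zip(*random_samples). The printed samples are non-empty, but the batch that fails may come from a different document: for an empty list, zip(*[]) yields nothing, and unpacking into two names raises exactly "not enough values to unpack (expected 2, got 0)". That would happen for a document with no negative span candidates, which is an inference from the traceback rather than something confirmed in the thread. A guarded unzip could look like this (unzip_samples is hypothetical, and code further downstream would also have to handle the empty case):

```python
from typing import List, Sequence, Tuple

Span = Tuple[int, int]


def unzip_samples(samples: Sequence[Tuple[Span, int]]) -> Tuple[List[Span], List[int]]:
    """Safe version of `spans, sizes = zip(*samples)` that tolerates an empty list."""
    if not samples:
        return [], []  # zip(*[]) would yield nothing and break the 2-way unpacking
    spans, sizes = zip(*samples)
    return list(spans), list(sizes)


print(unzip_samples([((5, 8), 1), ((3, 9), 4)]))  # ([(5, 8), (3, 9)], [1, 4])
print(unzip_samples([]))                          # ([], [])
```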
Hi, thanks for the amazing work
Is there a snippet somewhere explaining how to efficiently use the model for inference?
Like this:
s = "The sequence I wanna extract the relations from"
nlp(s)
>>> dict_of_relations(...)
Thanks a lot!
Hi,
What should I do to get the overlapping and non-overlapping results of the ADE dataset?
To be specific: given the sentence "The apple corporation" as input, is it possible that the span ["The apple corporation"] is predicted as 'ORG' while the span ["The apple"] is predicted as 'Other'? If so, what is the final prediction for the input sentence "The apple corporation"?
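As far as I understand SpERT's span-based design, each candidate span is classified independently, so overlapping spans can both receive entity types. If a downstream task needs non-overlapping output, one hypothetical post-processing step (not part of SpERT) is to keep spans greedily in order of score and drop anything overlapping a span already kept:

```python
from typing import List, Tuple

Span = Tuple[int, int, str, float]  # (start, end_exclusive, type, score)


def greedy_non_overlapping(spans: List[Span]) -> List[Span]:
    """Keep the highest-scoring spans, dropping any span that overlaps a kept one."""
    kept: List[Span] = []
    for span in sorted(spans, key=lambda s: s[3], reverse=True):
        start, end = span[0], span[1]
        # half-open intervals are disjoint iff one ends before the other starts
        if all(end <= k[0] or start >= k[1] for k in kept):
            kept.append(span)
    return sorted(kept, key=lambda s: s[0])


# "The apple corporation": made-up scores for two overlapping candidates
print(greedy_non_overlapping([(0, 3, "ORG", 0.9), (0, 2, "Other", 0.6)]))
# [(0, 3, 'ORG', 0.9)]
```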
By the way, during training the maximum number of randomly sampled negative spans is 100, which seems much larger than the number of labeled spans. For example, if a sentence consists of 27 words and contains 5 entities, there will be 5 labeled entity spans but 100 negative entity spans. Will this be harmful for training the model?
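The imbalance can be quantified. Assuming candidates are all spans of width 1 to max_span_size (10 in the config on this page) and that gold entity spans are excluded from the negative pool, a 27-token sentence has 27 + 26 + ... + 18 = 225 candidate spans, of which 220 are negatives and at most neg_entity_count = 100 are drawn. A sketch of the arithmetic with hypothetical helper names:

```python
def candidate_span_count(n_tokens: int, max_span_size: int) -> int:
    """Number of contiguous spans of width 1..max_span_size in n_tokens tokens."""
    return sum(n_tokens - w + 1 for w in range(1, min(max_span_size, n_tokens) + 1))


def negative_samples(n_tokens: int, n_entities: int, max_span_size: int, neg_count: int) -> int:
    """Negatives actually drawn: capped by neg_count and by the non-entity spans available.

    Assumes every gold entity is itself a candidate (width <= max_span_size).
    """
    available = candidate_span_count(n_tokens, max_span_size) - n_entities
    return min(neg_count, available)


print(candidate_span_count(27, 10))       # 225
print(negative_samples(27, 5, 10, 100))   # 100
```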
In the checkpoint there is no vocab file, which makes evaluation fail when loading the checkpoint. It seems a call to "tokenizer.save_pretrained" needs to be added to the model-saving functionality.
Thanks
Min
Hi there, thanks a lot for sharing this awesome repo. I just have one question regarding GPU usage. I have noticed that evaluation takes almost twice as much GPU memory as training, even when I use the same batch size for both. I understand that during evaluation the system additionally performs span selection, but is that the only thing responsible for such a large memory footprint? Any comments on this would be really helpful. Thanks.
Dear authors,
I was wondering if you could provide a license for your project (or point me to it if one already exists), to clarify whether and under which conditions the code can be used and integrated into other projects. (https://choosealicense.com/no-permission/)
Thanks and Best,
For the JSON files, I noticed that apparently all sentences have entities and relations. Do we need to keep sentences without relations or entities?
Thank you for publishing your work! I really enjoyed reading your paper and want to use your model for prediction on my dataset. I was wondering if you could provide a script that pre-processes the raw ADE data into the files you mention ("ade_split_1_train.json" / "ade_split_1_test.json") under "data/datasets/ade", which are fed to the ADE model.
It would be great if you could share the script you used to turn the original ADE dataset into the input data for the ADE model.
Thank you!
Hello, I want to perform experiments on my own dataset, but I cannot find the file 'vocab.txt'. How can I get this file?
Most papers carry out experiments on the ACE04/ACE05 datasets. Are there any corresponding experiments with SpERT on ACE04/ACE05? Thanks.
Hi @markus-eberts ,
I really appreciate your excellent work and dedicated effort on information extraction. It is really inspiring.
Recently, when I tried to run spert in my environment, training stopped after the last line below without throwing any error:
loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin from cache at /root/.cache/torch/transformers/35d8b9d36faaf46728a0192d82bf7d00137490cd6074e8500778afed552a67e5.3fadbea36527ae472139fe84cddaa65454d7429f12d543d80bfc3ad70de55ac2
Do you think this is because the model could not be downloaded from AWS? As a possible workaround, I downloaded the model file bert-base-cased-pytorch_model.bin and replaced "self.args.model_path" in
model = model_class.from_pretrained(
    self.args.model_path,
    .....)
But it did not work. Do you know what the cause is?
Hi. Thanks for the paper and code. I had a couple questions:
Is there a simple way to do inference on a single sentence or a list of sentences? I looked over the evaluator code but it seems to only read in a full dataset at a time. Also do you have any advice on refactoring the model for realtime inference on sentences?
My second question: it seems that you have already preprocessed the datasets. Do you have the code you originally used to put them into that format? I might have missed it, but I didn't see it in your repository.
Thanks again.
I have a question: where can I get the entity_mask if I want to predict the relations between entities in an unannotated sentence?
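A sketch of the mask construction, assuming (based on the entity_masks argument passed alongside encodings and context_masks in the training loop quoted on this page) that an entity mask is a 0/1 vector over the encoding positions with ones inside the span. For an unannotated sentence, masks would be built for every candidate span rather than for gold entities. The function names are hypothetical, and spans are taken as half-open [start, end) encoding positions.

```python
from typing import List, Tuple


def span_mask(start: int, end: int, context_size: int) -> List[int]:
    """0/1 mask over encoding positions, with ones inside the half-open span [start, end)."""
    if not 0 <= start < end <= context_size:
        raise ValueError(f"invalid span ({start}, {end}) for context size {context_size}")
    return [1 if start <= i < end else 0 for i in range(context_size)]


def masks_for_spans(spans: List[Tuple[int, int]], context_size: int) -> List[List[int]]:
    """One mask per candidate span (e.g. every span up to max_span_size)."""
    return [span_mask(s, e, context_size) for s, e in spans]


print(masks_for_spans([(1, 2), (2, 4)], 5))  # [[0, 1, 0, 0, 0], [0, 0, 1, 1, 0]]
```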
First of all, thanks for sharing this cleaned and object-oriented code! I have learned a lot from this repo. I even want to say Wow, you can really code!
^_^
I trained the model on the CoNLL04 dataset with the default configuration, following the README, and got the following test results:
--- Entities (NER) ---
type precision recall f1-score support
Org 79.43 83.84 81.57 198
Loc 91.51 90.87 91.19 427
Other 76.61 71.43 73.93 133
Peop 92.17 95.33 93.72 321
micro 87.70 88.51 88.10 1079
macro 84.93 85.37 85.10 1079
--- Relations ---
Without NER
type precision recall f1-score support
Kill 84.78 82.98 83.87 47
OrgBI 73.86 61.90 67.36 105
Work 61.54 63.16 62.34 76
LocIn 74.36 61.70 67.44 94
Live 74.04 77.00 75.49 100
micro 72.84 68.01 70.34 422
macro 73.72 69.35 71.30 422
With NER
type precision recall f1-score support
Kill 84.78 82.98 83.87 47
OrgBI 73.86 61.90 67.36 105
Work 61.54 63.16 62.34 76
LocIn 73.08 60.64 66.28 94
Live 74.04 77.00 75.49 100
micro 72.59 67.77 70.10 422
macro 73.46 69.14 71.07 422
The test results are worse than those in the original paper, especially for the macro-averaged metrics.
Is it possible that the random seed is different? I just set seed=42 in example_train.conf.
Thanks!
Hi, I have run your code on SciERC and can only achieve a micro-F1 of 67.6. Could you show me how to tune your model to reproduce the result reported in the paper on SciERC (micro-F1 70.33)?
I've read your paper. One thing I don't really understand is how you select spans from a sentence (i.e., which algorithm you use) as input to the span classifier.
And for the relation classifier, is it possible to convert it into open relation extraction, where relations are not predefined but simply exist in the sentence?
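As described in the SpERT paper, span selection is an exhaustive search: every token subsequence up to a maximum width (max_span_size in the config on this page) is a candidate, and the span classifier then filters out spans it assigns to the "none" class. The enumeration step can be sketched as follows, over word tokens with half-open (start, end) spans; enumerate_spans is a hypothetical name.

```python
from typing import List, Tuple


def enumerate_spans(tokens: List[str], max_span_size: int) -> List[Tuple[int, int]]:
    """All (start, end_exclusive) spans of width 1..max_span_size, shortest first."""
    spans = []
    for width in range(1, max_span_size + 1):
        for start in range(0, len(tokens) - width + 1):
            spans.append((start, start + width))
    return spans


print(enumerate_spans(["The", "apple", "corporation"], 2))
# [(0, 1), (1, 2), (2, 3), (0, 2), (1, 3)]
```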
Hi @markus-eberts !
I have some questions regarding the _score method from evaluator.py:
for (sample_gt, sample_pred) in zip(gt, pred):
    union = set()
    union.update(sample_gt)
    union.update(sample_pred)
    for s in union:
        if s in sample_gt:
            t = s[2]
            gt_flat.append(t.index)
            types.add(t)
        else:
            gt_flat.append(0)
        if s in sample_pred:
            t = s[2]
            pred_flat.append(t.index)
            types.add(t)
        else:
            pred_flat.append(0)
Why exactly do you add the prediction and the ground truth to the flat arrays as two separate entries when a relation isn't classified correctly? Since you end up with arrays larger than the actual number of relations in the evaluated dataset, wouldn't this penalize the computed score, since you would have one correct element for every hit but two wrong elements for every miss?
Thanks!
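A small pure-Python check of the concern. It assumes (not shown in the quoted snippet) that the flat arrays are afterwards scored with micro precision/recall restricted to the real type indices, so label 0 is never scored, as with scikit-learn's precision_recall_fscore_support when labels= lists only the real types. Under that assumption a misclassified relation contributes one false positive and one false negative, the same as set-based scoring, and the extra 0 entries add nothing. The helpers flatten and micro_prf are hypothetical re-implementations for illustration.

```python
from typing import Dict, List, Set, Tuple

Item = Tuple[int, int, str]  # (start, end, type); the type is the last element


def flatten(gt: Set[Item], pred: Set[Item], type_ids: Dict[str, int]) -> Tuple[List[int], List[int]]:
    """Mimic the union-based flattening quoted above: 0 stands for 'absent'."""
    gt_flat, pred_flat = [], []
    for s in gt | pred:
        gt_flat.append(type_ids[s[2]] if s in gt else 0)
        pred_flat.append(type_ids[s[2]] if s in pred else 0)
    return gt_flat, pred_flat


def micro_prf(gt_flat: List[int], pred_flat: List[int]) -> Tuple[float, float, float]:
    """Micro precision/recall/F1 over non-zero labels only (label 0 is never scored)."""
    tp = sum(1 for g, p in zip(gt_flat, pred_flat) if g != 0 and g == p)
    n_pred = sum(1 for p in pred_flat if p != 0)
    n_gold = sum(1 for g in gt_flat if g != 0)
    prec = tp / n_pred if n_pred else 0.0
    rec = tp / n_gold if n_gold else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return prec, rec, f1


ids = {"Live": 1, "Work": 2}
gt = {(0, 1, "Live"), (2, 3, "Work")}
pred = {(0, 1, "Work"), (2, 3, "Work")}  # one type error, one hit
g, p = flatten(gt, pred, ids)
print(len(g), micro_prf(g, p))  # 3 (0.5, 0.5, 0.5): same as 1/2 predicted, 1/2 gold
```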
Hi @markus-eberts ,
thanks for sharing your great work.
I was playing around with a variation of spERT, where the relations were extracted using a softmax instead of a sigmoid.
To verify the overall system, I trained it on the version of CoNLL04 that you provide with the model, and everything seemed fine.
The issues arose when I tried to train it on a different dataset, converted to a format compatible with spERT.
Training went smoothly, but the model didn't make any predictions at all, neither entities nor relations. I am surely missing something; I was wondering if you could point me in a direction to start from.
Here is a single sample from the training dataset:
{"tokens": ["The", "role", "of", "p27(Kip1", ")", "in", "dasatinib-enhanced", "paclitaxel", "cytotoxicity", "in", "human", "ovarian", "cancer", "cells", ".", "\r\n"], "entities": [{"type": "drug", "start": 6, "end": 7}, {"type": "drug", "start": 7, "end": 8}], "relations": [{"type": "effect", "head": 0, "tail": 1}], "orig_id": "DDI-MedLine.d194.s0"}
On this dataset a softmax is recommended, since all relations are symmetric and at most one relation exists between two entities.
Here is the log of the training run:
Config:
{'label': 'softmax_ddi', 'model_type': 'spert', 'model_path': 'bert-base-cased', 'tokenizer_path': 'bert-base-cased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times
Iteration 0
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-29 09:45:30,631 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:09<00:00, 622.16it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:01<00:00, 609.67it/s]
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relation type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entity type count: 5
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Entities:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Entity=0
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug name=1
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Drug=2
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Group=3
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Brand=4
2021-05-29 09:45:41,928 [MainThread ] [INFO ] Relations:
2021-05-29 09:45:41,928 [MainThread ] [INFO ] No Relation=0
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Effect=1
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Int=2
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Mechanism=3
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Advise=4
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: train
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 6045
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 3378
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 12549
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Dataset: valid
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Document count: 931
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Relation count: 642
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Entity count: 2216
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-29 09:45:41,929 [MainThread ] [INFO ] Updates total: 15110[...]
Evaluation
--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly
type precision recall f1-score support
drug_n 0.00 0.00 0.00 101.0
drug 0.00 0.00 0.00 1396.0
group 0.00 0.00 0.00 538.0
brand 0.00 0.00 0.00 169.0
micro 0.00 0.00 0.00 2204.0
macro 0.00 0.00 0.00 2204.0
--- Relations ---
Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)
type precision recall f1-score support
advise 0.00 0.00 0.00 130.0
int 0.00 0.00 0.00 8.0
effect 0.00 0.00 0.00 250.0
mechanism 0.00 0.00 0.00 253.0
micro 0.00 0.00 0.00 641.0
macro 0.00 0.00 0.00 641.0
With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)
type precision recall f1-score support
advise 0.00 0.00 0.00 130.0
int 0.00 0.00 0.00 8.0
effect 0.00 0.00 0.00 250.0
mechanism 0.00 0.00 0.00 253.0
micro 0.00 0.00 0.00 641.0
macro 0.00 0.00 0.00 641.0
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Logged in: data/log/softmax_ddi/2021-05-29_09-45-29.875587
2021-05-29 12:21:19,887 [MainThread ] [INFO ] Saved in: data/save/softmax_ddi/2021-05-29_09-45-29.875587
The following are the major changes that I applied to the original model:
spert/spert_trainer.py
class SpERTTrainer(BaseTrainer):
config=config,
# SpERT model parameters
cls_token=self._tokenizer.convert_tokens_to_ids('[CLS]'),
- relation_types=input_reader.relation_type_count - 1,
+ relation_types=input_reader.relation_type_count,
entity_types=input_reader.entity_type_count,
max_pairs=self._args.max_pairs,
prop_drop=self._args.prop_drop,
class SpERTTrainer(BaseTrainer):
num_warmup_steps=args.lr_warmup * updates_total,
num_training_steps=updates_total)
# create loss function
- rel_criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
+ rel_criterion = torch.nn.CrossEntropyLoss(reduction='none')
entity_criterion = torch.nn.CrossEntropyLoss(reduction='none')
spert/loss.py
class SpERTLoss(Loss):
if rel_count.item() != 0:
rel_logits = rel_logits.view(-1, rel_logits.shape[-1])
- rel_types = rel_types.view(-1, rel_types.shape[-1])
+ rel_types = rel_types.view(-1)
rel_loss = self._rel_criterion(rel_logits, rel_types)
- rel_loss = rel_loss.sum(-1) / rel_loss.shape[-1]
rel_loss = (rel_loss * rel_sample_masks).sum() / rel_count
# joint loss
spert/sampling.py
def create_train_sample(doc, neg_entity_count: int, neg_rel_count: int, max_span
rel_sample_masks = torch.zeros([1], dtype=torch.bool)
# relation types to one-hot encoding
- rel_types_onehot = torch.zeros([rel_types.shape[0], rel_type_count], dtype=torch.float32)
- rel_types_onehot.scatter_(1, rel_types.unsqueeze(1), 1)
- rel_types_onehot = rel_types_onehot[:, 1:] # all zeros for 'none' relation
return dict(encodings=encodings, context_masks=context_masks, entity_masks=entity_masks,
entity_sizes=entity_sizes, entity_types=entity_types,
- rels=rels, rel_masks=rel_masks, rel_types=rel_types_onehot,
+ rels=rels, rel_masks=rel_masks, rel_types=rel_types,
entity_sample_masks=entity_sample_masks, rel_sample_masks=rel_sample_masks)
spert/models.py
class SpERT(BertPreTrainedModel):
chunk_rel_logits = self._classify_relations(entity_spans_pool, size_embeddings,
relations, rel_masks, h_large, i)
# apply sigmoid
- chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
- rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf
+ rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits
- rel_clf = rel_clf * rel_sample_masks # mask
# apply softmax
entity_clf = torch.softmax(entity_clf, dim=2)
+ rel_clf = torch.softmax(rel_clf, dim=2)
+ rel_clf *= rel_sample_masks
return entity_clf, rel_clf, relations
spert/predictions.py
def convert_predictions(batch_entity_clf: torch.tensor, batch_rel_clf: torch.ten
batch_entity_types *= batch['entity_sample_masks'].long()
# apply threshold to relations
- batch_rel_clf[batch_rel_clf < rel_filter_threshold] = 0
batch_pred_entities = []
batch_pred_relations = []
spert/predictions.py
def _convert_pred_relations(rel_clf: torch.tensor, rels: torch.tensor,
entity_types: torch.tensor, entity_spans: torch.tensor, input_reader: BaseInputReader):
- rel_class_count = rel_clf.shape[1]
- rel_clf = rel_clf.view(-1)
# get predicted relation labels and corresponding entity pairs
- rel_nonzero = rel_clf.nonzero().view(-1)
- pred_rel_scores = rel_clf[rel_nonzero]
-
- pred_rel_types = (rel_nonzero % rel_class_count) + 1 # model does not predict None class (+1)
- valid_rel_indices = rel_nonzero // rel_class_count
+ valid_rel_indices = torch.nonzero(torch.sum(rel_clf, dim=-1)).view(-1)
+ valid_rel_indices = valid_rel_indices.view(-1)
+
+ pred_rel_types = rel_clf[valid_rel_indices]
+ if pred_rel_types.shape[0] != 0:
+ pred_rel_types = pred_rel_types.argmax(dim=-1)
+ valid_rel_indices = torch.nonzero(pred_rel_types).view(-1)
+
+ pred_rel_types = pred_rel_types[valid_rel_indices]
+
+ pred_rel_scores = rel_clf[valid_rel_indices]
+ if pred_rel_scores.shape[0] != 0:
+ pred_rel_scores = pred_rel_scores.max(dim=-1)[0]
valid_rels = rels[valid_rel_indices]
Not related to the previous topic, but I thought I'd add it here since the same dataset is involved.
While experimenting with the original spERT, I replaced BERT with SciBERT. With 1 epoch of training I had no issues whatsoever; when I increased it to 5 epochs, the procedure that stores the predictions started to pick up relations that should have been filtered out by the earlier processing (if I interpreted everything correctly).
Here is the log:
Config:
{'label': 'scibert_ddi', 'model_type': 'spert', 'model_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'tokenizer_path': '/home/deeplearning/Salvalai/scibert_scivocab_uncased', 'train_path': 'data/datasets/unibs/train/all.json', 'valid_path': 'data/datasets/unibs/dev/all.json', 'types_path': 'data/datasets/unibs/types.json', 'train_batch_size': '2', 'eval_batch_size': '1', 'neg_entity_count': '100', 'neg_relation_count': '100', 'epochs': '5', 'lr': '5e-5', 'lr_warmup': '0.1', 'weight_decay': '0.01', 'max_grad_norm': '1.0', 'rel_filter_threshold': '0.4', 'size_embedding': '25', 'prop_drop': '0.1', 'max_span_size': '10', 'store_predictions': 'true', 'store_examples': 'true', 'sampling_processes': '4', 'max_pairs': '1000', 'final_eval': 'true', 'log_path': 'data/log/', 'save_path': 'data/save/'}
Repeat 1 times
Iteration 0
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Datasets: data/datasets/unibs/train/all.json, data/datasets/unibs/dev/all.json
2021-05-28 10:54:28,162 [MainThread ] [INFO ] Model type: spert
Parse dataset 'train': 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6045/6045 [00:12<00:00, 466.41it/s]
Parse dataset 'valid': 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 931/931 [00:02<00:00, 359.59it/s]
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Relation type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entity type count: 5
2021-05-28 10:54:43,771 [MainThread ] [INFO ] Entities:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Entity=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug name=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Drug=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Group=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Brand=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relations:
2021-05-28 10:54:43,772 [MainThread ] [INFO ] No Relation=0
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Effect=1
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Int=2
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Mechanism=3
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Advise=4
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: train
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 6045
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Relation count: 3378
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Entity count: 12549
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Dataset: valid
2021-05-28 10:54:43,772 [MainThread ] [INFO ] Document count: 931
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Relation count: 642
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Entity count: 2216
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates per epoch: 3022
2021-05-28 10:54:43,773 [MainThread ] [INFO ] Updates total: 15110[...]
Evaluation
--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly
type precision recall f1-score support
brand 0.00 0.00 0.00 169.0
drug 0.00 0.00 0.00 1396.0
drug_n 0.00 0.00 0.00 101.0
group 0.00 0.00 0.00 538.0
micro 0.00 0.00 0.00 2204.0
macro 0.00 0.00 0.00 2204.0
--- Relations ---
Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)
type precision recall f1-score support
advise 0.00 0.00 0.00 130.0
mechanism 0.00 0.00 0.00 253.0
effect 0.00 0.00 0.00 250.0
int 0.00 0.00 0.00 8.0
micro 0.00 0.00 0.00 641.0
macro 0.00 0.00 0.00 641.0
With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)
type precision recall f1-score support
advise 0.00 0.00 0.00 130.0
mechanism 0.00 0.00 0.00 253.0
effect 0.00 0.00 0.00 250.0
int 0.00 0.00 0.00 8.0
micro 0.00 0.00 0.00 641.0
macro 0.00 0.00 0.00 641.0
Process SpawnProcess-1:
Traceback (most recent call last):
File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/deeplearning/.conda/envs/salvalai/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/deeplearning/Salvalai/spert/spert.py", line 16, in __train
trainer.train(train_path=run_args.train_path, valid_path=run_args.valid_path,
File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 97, in train
self._eval(model, validation_dataset, input_reader, epoch + 1, updates_epoch)
File "/home/deeplearning/Salvalai/spert/spert/spert_trainer.py", line 253, in _eval
evaluator.store_predictions()
File "/home/deeplearning/Salvalai/spert/spert/evaluator.py", line 87, in store_predictions
prediction.store_predictions(self._dataset.documents, self._pred_entities,
File "/home/deeplearning/Salvalai/spert/spert/prediction.py", line 196, in store_predictions
head_idx = converted_entities.index(converted_head)
ValueError: {'type': 'None', 'start': 0, 'end': 1} is not in list
Best regards
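The traceback shows store_predictions in prediction.py calling converted_entities.index(converted_head), which raises ValueError when a relation's head entity ({'type': 'None', 'start': 0, 'end': 1}, i.e. a span classified as no entity) is not among the stored entities. As a stopgap, not a fix of the underlying filtering, the lookup could be made defensive so that such relations are skipped. The sketch below is hypothetical (link_relations is not a SpERT function) and assumes each converted entity is a dict with exactly the keys type, start and end, as in the error message.

```python
from typing import Any, Dict, List, Tuple

Entity = Dict[str, Any]


def link_relations(
    converted_entities: List[Entity],
    relations: List[Tuple[Entity, Entity, str]],
) -> List[Dict[str, Any]]:
    """Map (head, tail, type) relations to indices into converted_entities,
    skipping relations whose head or tail entity was not stored."""
    # Keep the first index per entity, mirroring what list.index() would return
    index: Dict[Tuple[Any, Any, Any], int] = {}
    for i, entity in enumerate(converted_entities):
        index.setdefault((entity["type"], entity["start"], entity["end"]), i)

    linked = []
    for head, tail, rel_type in relations:
        head_idx = index.get((head["type"], head["start"], head["end"]))
        tail_idx = index.get((tail["type"], tail["start"], tail["end"]))
        if head_idx is None or tail_idx is None:
            # This is the case where list.index() raises ValueError
            continue
        linked.append({"type": rel_type, "head": head_idx, "tail": tail_idx})
    return linked
```

Skipping only hides the symptom; why relations with a 'None'-typed argument survive until this point after 5 epochs would still need to be traced in the relation filtering step.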
Hi there!
I'm trying to evaluate SpERT on the TACRED dataset, but I'm getting extremely low results. I wrote a script to convert the original TACRED JSON to the format you use for training. Here's an example:
TACRED original:
{'id': "e7798fb926b9403cfcd2",
'docid': "APW_ENG_20101103.0539",
'relation': "per:title",
'token': ["At", "the", "same", "time", ",", "Chief", "Financial", "Officer", "Douglas", "Flint", "will",
"become", "chairman", ",", "succeeding", "Stephen", "Green", "who", "is", "leaving", "to", "take",
"a", "government", "job", "." ],
'subj_start': 8,
'subj_end': 9,
'obj_start': 12,
'obj_end': 12,
'subj_type': "PERSON",
'obj_type': "TITLE"}
TACRED converted:
{
"tokens":[ "At", "the", "same", "time", ",", "Chief", "Financial", "Officer", "Douglas", "Flint", "will", "become", "chairman", ",", "succeeding", "Stephen", "Green", "who", "is", "leaving", "to", "take", "a", "government", "job", "." ],
"entities":[
{
"type":"PERSON",
"start":8,
"end":10
},
{
"type":"TITLE",
"start":12,
"end":13
}
],
"relations":[
{
"type":"per:title",
"head":0,
"tail":1
}
],
"orig_id":"e7798fb926b9403cfcd2"
}
I'm getting the results below using the same config as the CoNLL04 example, changing only the dataset. Any idea why the results are so bad? Should I adapt the code somehow?
Thanks!
--- Entities (named entity recognition (NER)) ---
An entity is considered correct if the entity type and span is predicted correctly
type precision recall f1-score support
Loc 18.80 13.44 15.68 677
Tit 17.92 11.93 14.33 1701
url 71.03 79.17 74.88 96
Dat 24.85 28.95 26.74 3064
Crim 22.87 22.05 22.45 195
SoP 22.65 28.54 25.26 431
Cntr 23.39 27.13 25.12 1434
Cit 19.15 21.45 20.24 951
Per 44.15 57.61 49.99 20644
Misc 16.45 14.67 15.51 600
Relig 15.54 19.11 17.14 157
Ideo 11.54 6.12 8.00 49
Dur 25.63 22.56 24.00 359
Org 49.10 50.64 49.86 12272
Num 23.73 29.22 26.19 1742
Nation 19.68 12.53 15.31 495
CoD 20.16 26.33 22.83 395
micro 40.08 46.40 43.01 45262
macro 26.27 27.73 26.68 45262
--- Relations ---
Without named entity classification (NEC)
A relation is considered correct if the relation type and the spans of the two related entities are predicted correctly (entity type is not considered)
/home/pedro/anaconda3/envs/fast-bert/lib/python3.7/site-packages/sklearn/metrics/classification.py:1437: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.
'precision', 'predicted', average, warn_for)
type precision recall f1-score support
Cause of Death 6.86 17.86 9.92 168
Country of Headquarters 10.05 11.30 10.64 177
Website 56.25 62.79 59.34 86
Date of Death 13.13 6.31 8.52 206
Charges 5.00 4.76 4.88 105
Countries of Residence 6.96 3.54 4.69 226
Country of Birth 7.50 15.00 10.00 20
Origin 5.80 1.90 2.87 210
Member of 0.00 0.00 0.00 31
Top Members / Employees 14.23 32.21 19.74 534
Person Alternate Names 15.38 5.26 7.84 38
State or Province of Birth 0.00 0.00 0.00 26
State or Provinces of Residence 7.03 12.50 9.00 72
Country of Death 0.00 0.00 0.00 46
State or Province of Headquarters 11.86 30.00 17.00 70
Other Family 1.87 2.50 2.14 80
Person Parents 4.55 5.36 4.92 56
Founded by 25.00 9.21 13.46 76
Age 13.14 18.93 15.51 243
Religion 4.11 5.66 4.76 53
Children 6.12 9.09 7.32 99
Title 7.63 7.40 7.51 919
Dissolved 0.00 0.00 0.00 8
Organization Parents 2.70 7.29 3.94 96
Political/Religious affiliation 4.55 20.00 7.41 10
No Relation 9.53 19.72 12.85 17195
Subsidiaries 4.90 4.42 4.65 113
Employee Of 2.07 1.87 1.96 375
State or Province of Death 19.05 9.76 12.90 41
Siblings 6.02 16.67 8.85 30
Shareholders 6.06 3.64 4.55 55
Cities of Residence 7.94 5.59 6.56 179
City of Death 4.11 2.54 3.14 118
City of Headquarters 4.64 6.42 5.38 109
Schools Attended 12.82 10.00 11.24 50
Date of Birth 12.24 19.35 15.00 31
Founded 11.63 13.16 12.35 38
Members 0.00 0.00 0.00 85
Spouse 7.76 10.69 8.99 159
City of Birth 7.69 9.09 8.33 33
Organization Alternate Names 22.30 18.93 20.48 338
Number of Employees/Members 13.21 25.93 17.50 27
micro 9.53 17.80 12.42 22631
macro 9.09 11.11 9.19 22631
With named entity classification (NEC)
A relation is considered correct if the relation type and the two related entities are predicted correctly (in span and entity type)
type precision recall f1-score support
Cause of Death 6.86 17.86 9.92 168
Country of Headquarters 10.05 11.30 10.64 177
Website 56.25 62.79 59.34 86
Date of Death 13.13 6.31 8.52 206
Charges 5.00 4.76 4.88 105
Countries of Residence 5.22 2.65 3.52 226
Country of Birth 7.50 15.00 10.00 20
Origin 5.80 1.90 2.87 210
Member of 0.00 0.00 0.00 31
Top Members / Employees 13.98 31.65 19.39 534
Person Alternate Names 15.38 5.26 7.84 38
State or Province of Birth 0.00 0.00 0.00 26
State or Provinces of Residence 7.03 12.50 9.00 72
Country of Death 0.00 0.00 0.00 46
State or Province of Headquarters 11.86 30.00 17.00 70
Other Family 1.87 2.50 2.14 80
Person Parents 4.55 5.36 4.92 56
Founded by 25.00 9.21 13.46 76
Age 13.14 18.93 15.51 243
Religion 4.11 5.66 4.76 53
Children 6.12 9.09 7.32 99
Title 7.63 7.40 7.51 919
Dissolved 0.00 0.00 0.00 8
Organization Parents 2.70 7.29 3.94 96
Political/Religious affiliation 4.55 20.00 7.41 10
No Relation 9.01 18.64 12.15 17195
Subsidiaries 3.92 3.54 3.72 113
Employee Of 2.07 1.87 1.96 375
State or Province of Death 19.05 9.76 12.90 41
Siblings 6.02 16.67 8.85 30
Shareholders 6.06 3.64 4.55 55
Cities of Residence 7.14 5.03 5.90 179
City of Death 4.11 2.54 3.14 118
City of Headquarters 4.64 6.42 5.38 109
Schools Attended 12.82 10.00 11.24 50
Date of Birth 12.24 19.35 15.00 31
Founded 11.63 13.16 12.35 38
Members 0.00 0.00 0.00 85
Spouse 7.76 10.69 8.99 159
City of Birth 5.13 6.06 5.56 33
Organization Alternate Names 22.30 18.93 20.48 338
Number of Employees/Members 13.21 25.93 17.50 27
micro 9.07 16.95 11.82 22631
macro 8.93 10.94 9.04 22631
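Two details of the conversion may be worth checking. TACRED's subj_end/obj_end are inclusive while the converted example uses exclusive end indices (subject 8..9 becomes start 8, end 10), which the script appears to handle. However, the evaluation lists "No Relation" as a relation type with a support of 17195 out of 22631, which suggests that TACRED's no_relation label was converted into a positive relation type instead of simply being omitted. Also, each TACRED example only marks the subject and the object, so other mentions in the sentence stay unlabeled and are sampled as negative spans during training, which may depress the NER scores as well. A minimal sketch of a converter that skips no_relation, assuming the field names of the TACRED example above; tacred_to_spert is a hypothetical helper, not part of SpERT:

```python
from typing import Any, Dict, List


def tacred_to_spert(ex: Dict[str, Any], skip_label: str = "no_relation") -> Dict[str, Any]:
    """Convert one TACRED example into the SpERT-style JSON document."""
    # TACRED end indices are inclusive; the SpERT format uses exclusive ends
    entities = [
        {"type": ex["subj_type"], "start": ex["subj_start"], "end": ex["subj_end"] + 1},
        {"type": ex["obj_type"], "start": ex["obj_start"], "end": ex["obj_end"] + 1},
    ]
    relations: List[Dict[str, Any]] = []
    # Assumption: "no relation" must not become a positive relation type,
    # because SpERT already models the absence of a relation on its own
    if ex["relation"] != skip_label:
        relations.append({"type": ex["relation"], "head": 0, "tail": 1})
    return {
        "tokens": list(ex["token"]),
        "entities": entities,
        "relations": relations,
        "orig_id": ex["id"],
    }
```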
Hi, thanks for sharing such excellent work. After reading the paper and some issues (#2 and #14), I still have some doubts and look forward to your answers!
Based on issues #2 and #14, my understanding of the training and testing process in this work is as follows:
If this understanding matches what the authors did (True), the following concerns apply. If not (False), please explain the whole process (training and testing) in detail.
The above process seems unfair and incorrect, and the validation dataset does not fulfill its role. Of course, if the validation dataset (dev.json) is added to the training dataset (train.json) and the model is trained on both (as I understand this work does), the final model should be better, especially when the training set (train.json) is relatively small. After all, deep learning is data-driven.
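For contrast, the protocol this comment argues for keeps the roles separate: the dev set is only used to pick a checkpoint (optionally with early stopping), and the test set is evaluated once with that checkpoint. A minimal sketch, assuming one dev F1 score is recorded per epoch; select_checkpoint and its patience parameter are hypothetical and not part of SpERT:

```python
from typing import Sequence


def select_checkpoint(dev_f1: Sequence[float], patience: int = 2) -> int:
    """Return the index of the epoch with the best dev F1, stopping once
    `patience` consecutive epochs have not improved on the best score."""
    best_idx, best_score, epochs_without_gain = 0, float("-inf"), 0
    for epoch, score in enumerate(dev_f1):
        if score > best_score:
            best_idx, best_score, epochs_without_gain = epoch, score, 0
        else:
            epochs_without_gain += 1
            if epochs_without_gain >= patience:
                break  # early stop: later epochs are never considered
    return best_idx
```

For example, dev scores of 60.1, 65.3, 64.0, 66.0 select epoch index 3 with patience=2, while patience=1 stops after index 2 and keeps index 1. Only the selected checkpoint would then be run on the test set.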
First of all, thanks for the great implementation!
I've been trying to use different transformers (e.g. RoBERTa) with SpERT, but I ran into some problems. Simply changing model_path and tokenizer_path in the config to the name of a different transformer from https://huggingface.co/models does not work, since the code currently uses BERT-specific classes such as BertTokenizer and BertModel instead of AutoTokenizer and AutoModel.
But even if I change those, I still have problems, possibly because the SpERT class itself is derived from BertPreTrainedModel. Using SciBERT by following your instructions definitely works, but I think that's because SciBERT has exactly the same architecture, layer names, etc.
Do you know whether other transformers can be used with the current implementation (I might have misunderstood some parts)? If not, do you know what modifications would be needed to make it work with other transformers, or whether there are any workarounds I could use? Thanks in advance!
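One possible direction, sketched under assumptions rather than taken from SpERT: hold a generic encoder created through AutoModel as an attribute of a plain nn.Module, instead of deriving the whole model from BertPreTrainedModel, and keep the task heads next to it. SpanEncoder and its toy max-pooled entity head are hypothetical. The tiny config built with AutoConfig.for_model only lets the example run without downloading weights; with a real checkpoint you would use AutoModel.from_pretrained(model_path). Special tokens also differ between models (RoBERTa uses <s> and </s> where BERT uses [CLS] and [SEP]), so token ids are safer to take from tokenizer.cls_token_id and tokenizer.sep_token_id than from hard-coded strings.

```python
import torch
from torch import nn
from transformers import AutoConfig, AutoModel


class SpanEncoder(nn.Module):
    """Hypothetical sketch: any Hugging Face encoder plus a toy span classifier."""

    def __init__(self, config, num_entity_types: int):
        super().__init__()
        # AutoModel chooses the architecture (BERT, RoBERTa, ...) from config.model_type
        self.encoder = AutoModel.from_config(config)
        self.entity_classifier = nn.Linear(config.hidden_size, num_entity_types)

    def forward(self, input_ids, attention_mask, span_masks):
        # h: (batch, seq_len, hidden)
        h = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # span_masks: (batch, num_spans, seq_len), 1 for tokens inside a span.
        # Push tokens outside each span to a large negative value, then max-pool.
        outside = (1.0 - span_masks.unsqueeze(-1).float()) * -1e30
        pooled = (h.unsqueeze(1) + outside).max(dim=2).values  # (batch, num_spans, hidden)
        return self.entity_classifier(pooled)  # (batch, num_spans, num_entity_types)
```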
@markus-eberts Hi Markus,
After training, I checked the path data/save/conll04_train/2021-01-07_09:31:34.998954/final_model/pytorch_model.bin and realized that the model is still the original pretrained BERT model.
I am not a PyTorch expert: does the code below mean that only the pretrained BERT model is saved?
# save model
if isinstance(model, DataParallel):
    model.module.save_pretrained(dir_path)
else:
    model.save_pretrained(dir_path)
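save_pretrained stores the state dict of the whole module it is called on, not only the BERT encoder, so if SpERT's classifier heads are attributes of the model, their weights should be in pytorch_model.bin as well. Note that a fine-tuned encoder keeps the same key names and tensor shapes as the pretrained one; only the values change. One way to check is to list the saved keys and separate the encoder keys from the rest. A minimal sketch: split_keys is hypothetical, and the "bert." prefix assumes the encoder is stored under a bert attribute, as is usual for BertPreTrainedModel subclasses.

```python
from typing import Iterable, List, Tuple


def split_keys(keys: Iterable[str], encoder_prefix: str = "bert.") -> Tuple[List[str], List[str]]:
    """Split state-dict keys into encoder keys and all remaining keys (e.g. task heads)."""
    encoder_keys: List[str] = []
    other_keys: List[str] = []
    for key in keys:
        (encoder_keys if key.startswith(encoder_prefix) else other_keys).append(key)
    return encoder_keys, other_keys


# Usage with the real checkpoint (requires PyTorch):
# import torch
# state_dict = torch.load("data/save/conll04_train/2021-01-07_09:31:34.998954/final_model/pytorch_model.bin", map_location="cpu")
# encoder_keys, other_keys = split_keys(state_dict.keys())
# print(other_keys)  # non-empty if more than the encoder was saved
```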
Thanks for your good work.
I would like to know the input format for predict. Is it like this (1):
Jack is born in London.
or like this (2):
where O ['N']
the O ['N']
government O ['N']
soldiers O ['N']
with O ['N']
the O ['N']
former O ['N']
Northern B-Org ['N']
Improvement I-Org ['N']
Brigade I-Org
( O ['N']
Brigada B-Org ['N']
de I-Org ['N']
Melhoramentos I-Org ['N']
do I-Org ['N']
Norte I-Org
) O ['N']
, O ['N']
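The training data used with SpERT (see the converted TACRED example earlier on this page) is JSON with a "tokens" list and entities given as token spans with exclusive "end" indices. If your data is in the column form (2), a minimal sketch for turning it into that layout follows. bio_to_spert is a hypothetical helper; it assumes the token is in the first column and the BIO tag in the second, and it ignores the third column (['N']). Which input formats predict accepts may depend on the SpERT version, so the README remains the place to check.

```python
from typing import Any, Dict, List, Optional, Tuple


def bio_to_spert(rows: List[Tuple[str, str]]) -> Dict[str, Any]:
    """Convert (token, BIO tag) rows into a document with exclusive entity ends."""
    tokens: List[str] = []
    entities: List[Dict[str, Any]] = []
    current: Optional[Dict[str, Any]] = None  # open entity: {"type", "start"}
    for i, (token, tag) in enumerate(rows):
        tokens.append(token)
        # A B- tag, or an I- tag that does not continue the open entity, starts a new span
        starts_new = tag.startswith("B-") or (
            tag.startswith("I-") and (current is None or current["type"] != tag[2:])
        )
        # Close the open span on O or when a new span starts at this token
        if current is not None and (tag == "O" or starts_new):
            entities.append({"type": current["type"], "start": current["start"], "end": i})
            current = None
        if starts_new:
            current = {"type": tag[2:], "start": i}
    if current is not None:  # span running to the end of the sentence
        entities.append({"type": current["type"], "start": current["start"], "end": len(tokens)})
    return {"tokens": tokens, "entities": entities, "relations": []}
```

Reading the column format could be as simple as rows = [tuple(line.split()[:2]) for line in text.splitlines() if line.strip()], which keeps only the token and tag columns.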
I can't access the website; could you please give me another way to get it?
Hi @markus-eberts,
I was just wondering where I can get the original CoNLL04 dataset, since the link is not mentioned in the original paper.
When I run the code using
python ./spert.py eval --config configs/example_eval.conf
it shows:
Traceback (most recent call last):
File "./spert.py", line 39, in <module>
_eval()
File "./spert.py", line 28, in _eval
process_configs(target=__eval, arg_parser=arg_parser)
File "D:\season\MyCode\spert\config_reader.py", line 7, in process_configs
ctx = mp.get_context('fork')
File "D:\install\conda\envs\season\lib\multiprocessing\context.py", line 238, in get_context
return super().get_context(method)
File "D:\install\conda\envs\season\lib\multiprocessing\context.py", line 192, in get_context
raise ValueError('cannot find context for %r' % method) from None
ValueError: cannot find context for 'fork'
I have already downloaded the datasets and the model into /data.
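The traceback shows config_reader.py requesting the 'fork' start method, which does not exist on Windows (only 'spawn' is available there). A possible local workaround is to fall back to 'spawn' when 'fork' is unavailable. get_mp_context is a hypothetical helper meant to replace the mp.get_context('fork') call in process_configs. With 'spawn', the process target and its arguments must be picklable and the entry script must be guarded by if __name__ == '__main__':, so further changes may still be needed.

```python
import multiprocessing as mp


def get_mp_context(preferred: str = "fork"):
    """Return a multiprocessing context, falling back to 'spawn' when the
    preferred start method is unavailable (e.g. 'fork' on Windows)."""
    if preferred in mp.get_all_start_methods():
        return mp.get_context(preferred)
    return mp.get_context("spawn")


# In config_reader.py (hypothetical edit): ctx = get_mp_context()
```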