punctuation-restoration's Introduction

Punctuation Restoration using Transformer Models

This repository contains the official implementation of the paper "Punctuation Restoration using Transformer Models for High-and Low-Resource Languages", accepted at the EMNLP workshop W-NUT 2020.

Data

English

English datasets are provided in the data/en directory. They were collected from here.

Bangla

Bangla datasets are provided in the data/bn directory.

Model Architecture

We fine-tune a Transformer-based language model (e.g., BERT) for the punctuation restoration task. The Transformer encoder is followed by a bidirectional LSTM and a linear layer that predicts the target punctuation token at each sequence position.
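The encoder, BiLSTM, and linear stack described above can be sketched in PyTorch. This is a minimal illustration under stated assumptions and is not the repository's model.py. A tiny, randomly initialized nn.TransformerEncoder stands in for the pretrained BERT-style encoder. The class name PunctuationTagger, the layer sizes, and the four output labels (no punctuation, Comma, Period, Question) were chosen only for this example.

```python
import torch
import torch.nn as nn


class PunctuationTagger(nn.Module):
    """Hypothetical sketch: Transformer encoder -> BiLSTM -> per-token linear layer."""

    def __init__(self, vocab_size=1000, hidden=64, lstm_dim=32, num_labels=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        # Small stand-in for a pretrained encoder such as BERT or RoBERTa.
        layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        # Bidirectional LSTM over the contextual embeddings.
        self.lstm = nn.LSTM(hidden, lstm_dim, batch_first=True, bidirectional=True)
        # Linear layer producing one score per punctuation label at each position.
        self.classifier = nn.Linear(2 * lstm_dim, num_labels)

    def forward(self, input_ids, attn_mask):
        # attn_mask uses 1 for real tokens and 0 for padding; PyTorch's
        # src_key_padding_mask expects True at positions to ignore.
        x = self.encoder(self.embed(input_ids), src_key_padding_mask=attn_mask == 0)
        x, _ = self.lstm(x)
        return self.classifier(x)  # shape: (batch, seq_len, num_labels)


model = PunctuationTagger()
ids = torch.randint(0, 1000, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
logits = model(ids, mask)
predicted_labels = logits.argmax(dim=-1)  # one label index per token
print(logits.shape)
```

In a real setup, a pretrained model loaded with transformers would replace the embedding and encoder layers. Its hidden size would then set the LSTM input size.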

Dependencies

Install PyTorch by following the instructions on the PyTorch website. The remaining dependencies can be installed with the following command:

pip install -r requirements.txt

Training

To train the punctuation restoration model for English with the optimal parameter settings, run the following command:

python src/train.py --cuda=True --pretrained-model=roberta-large --freeze-bert=False --lstm-dim=-1 \
--language=english --seed=1 --lr=5e-6 --epoch=10 --use-crf=False --augment-type=all --augment-rate=0.15 \
--alpha-sub=0.4 --alpha-del=0.4 --data-path=data --save-path=out

The corresponding command for Bangla is:

python src/train.py --cuda=True --pretrained-model=xlm-roberta-large --freeze-bert=False --lstm-dim=-1 \
--language=bangla --seed=1 --lr=5e-6 --epoch=10 --use-crf=False --augment-type=all --augment-rate=0.15 \
--alpha-sub=0.4 --alpha-del=0.4 --data-path=data --save-path=out
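The --augment-* flags in these commands are not documented here. The sketch below is only a rough illustration and assumes token-level augmentation. Each token is perturbed with probability --augment-rate. A perturbation is one of three operations: substitution by an unknown token (probability --alpha-sub), deletion (--alpha-del), or insertion of an unknown token (the remaining probability). The function name, the `<unk>` placeholder, and this reading of the flags are all assumptions, so check src/dataset.py for the actual behavior.

```python
import random


def augment(tokens, labels, rate=0.15, alpha_sub=0.4, alpha_del=0.4,
            unk="<unk>", rng=None):
    """Hypothetical token-level augmentation: substitute, delete, or insert."""
    rng = rng or random.Random()
    out_tokens, out_labels = [], []
    for tok, lab in zip(tokens, labels):
        if rng.random() >= rate:          # most tokens are left untouched
            out_tokens.append(tok)
            out_labels.append(lab)
            continue
        r = rng.random()
        if r < alpha_sub:                 # substitution: hide the token, keep its label
            out_tokens.append(unk)
            out_labels.append(lab)
        elif r < alpha_sub + alpha_del:   # deletion: drop the token and its label
            continue
        else:                             # insertion: unlabeled unk before the token
            out_tokens.extend([unk, tok])
            out_labels.extend(["O", lab])
    return out_tokens, out_labels


tokens = "tolkien drew on a wide array of influences".split()
labels = ["O", "O", "O", "O", "O", "O", "O", "PERIOD"]
aug_tokens, aug_labels = augment(tokens, labels, rng=random.Random(1))
```

Inserting the unknown token before the original word keeps each word's label attached to the punctuation that follows it.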

Supported models for English

bert-base-uncased
bert-large-uncased
bert-base-multilingual-cased
bert-base-multilingual-uncased
xlm-mlm-en-2048
xlm-mlm-100-1280
roberta-base
roberta-large
distilbert-base-uncased
distilbert-base-multilingual-cased
xlm-roberta-base
xlm-roberta-large
albert-base-v1
albert-base-v2
albert-large-v2

Supported models for Bangla

bert-base-multilingual-cased
bert-base-multilingual-uncased
xlm-mlm-100-1280
distilbert-base-multilingual-cased
xlm-roberta-base
xlm-roberta-large

Pretrained Models

Pretrained RoBERTa-large weights for English, trained with augmentation, can be found here.
The XLM-RoBERTa-large model for Bangla, trained with augmentation, can be found here.

Inference

You can run the inference module on an unprocessed text file to produce punctuated text. Note that any punctuation already present in the text is removed before inference.
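Conceptually, inference turns a bare word sequence back into punctuated text. The pure-Python sketch below covers only the text handling around the model: it strips existing punctuation from the input, then attaches the mark for each predicted label. Several things are assumptions here: the function names, the label strings, and the predictions being passed in as a plain list. The real module also works on subword tokens rather than whitespace-separated words. The sketch does not restore capitalization, which is consistent with the lowercase "his" after "philology." in the English sample output.

```python
import string

# Characters removed from the input: string.punctuation covers ASCII marks, and
# "\u0964" is the Bangla full stop (danda) that appears in the Bangla sample output.
PUNCT_CHARS = string.punctuation + "\u0964"


def strip_punctuation(text):
    """Split on whitespace and trim punctuation from both ends of each word."""
    words = (w.strip(PUNCT_CHARS) for w in text.split())
    return [w for w in words if w]


def restore(words, labels, period="."):
    """Attach the mark for each predicted label; 'O' means no punctuation."""
    marks = {"COMMA": ",", "PERIOD": period, "QUESTION": "?"}
    return " ".join(w + marks.get(lab, "") for w, lab in zip(words, labels))


words = strip_punctuation("He was inspired, primarily by his profession: philology.")
# Hypothetical model predictions, one per word.
labels = ["O", "O", "O", "O", "O", "O", "COMMA", "PERIOD"]
print(restore(words, labels))
# He was inspired primarily by his profession, philology.
```

For Bangla, passing period="\u0964" would emit the danda, as in the Bangla sample output.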

Example command for English:

python inference.py --pretrained-model=roberta-large --weight-path=roberta-large-en.pt --language=en \
--in-file=data/test_en.txt --out-file=data/test_en_out.txt

This should create a text file with the following output:

Tolkien drew on a wide array of influences including language, Christianity, mythology, including the Norse Völsunga saga, archaeology, especially at the Temple of Nodens, ancient and modern literature and personal experience. He was inspired primarily by his profession, philology. his work centred on the study of Old English literature, especially Beowulf, and he acknowledged its importance to his writings. 

Similarly, for Bangla:

python inference.py --pretrained-model=xlm-roberta-large --weight-path=xlm-roberta-large-bn.pt --language=bn \
--in-file=data/test_bn.txt --out-file=data/test_bn_out.txt

The expected output is:

বিংশ শতাব্দীর বাংলা মননে কাজী নজরুল ইসলামের মর্যাদা ও গুরুত্ব অপরিসীম। একাধারে কবি, সাহিত্যিক, সংগীতজ্ঞ, সাংবাদিক, সম্পাদক, রাজনীতিবিদ এবং সৈনিক হিসেবে অন্যায় ও অবিচারের বিরুদ্ধে নজরুল সর্বদাই ছিলেন সোচ্চার। তার কবিতা ও গানে এই মনোভাবই প্রতিফলিত হয়েছে। অগ্নিবীণা হাতে তার প্রবেশ, ধূমকেতুর মতো তার প্রকাশ। যেমন লেখাতে বিদ্রোহী, তেমনই জীবনে কাজেই "বিদ্রোহী কবি"। তার জন্ম ও মৃত্যুবার্ষিকী বিশেষ মর্যাদার সঙ্গে উভয় বাংলাতে প্রতি বৎসর উদযাপিত হয়ে থাকে। 

Note that the Comma class includes commas, colons, and dashes; the Period class includes full stops, exclamation marks, and semicolons; and the Question class includes only question marks.
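The grouping above can be expressed as a lookup from a word's trailing mark to its class. Below is a hedged sketch of turning raw text into (word, label) pairs under that grouping. The label strings, the function name, and treating the Bangla danda (।) as a Period are assumptions made for illustration.

```python
# Raw mark -> class, following the grouping described above.
PUNCT_CLASS = {
    ",": "COMMA", ":": "COMMA", "-": "COMMA",
    "\u2013": "COMMA", "\u2014": "COMMA",              # en and em dashes
    ".": "PERIOD", "!": "PERIOD", ";": "PERIOD",
    "\u0964": "PERIOD",                                # Bangla danda
    "?": "QUESTION",
}
MARK_CHARS = "".join(PUNCT_CLASS)


def label_words(text):
    """Return (word, label) pairs, labeling each word by the mark that follows it."""
    pairs = []
    for token in text.split():
        word = token.rstrip(MARK_CHARS)
        mark = token[len(word):]
        if not word:
            # A free-standing mark (e.g. " - ") punctuates the previous word.
            if pairs:
                pairs[-1] = (pairs[-1][0], PUNCT_CLASS[mark[-1]])
            continue
        pairs.append((word, PUNCT_CLASS[mark[-1]] if mark else "O"))
    return pairs


print(label_words("Is it done? Yes: it is, mostly; really!"))
```

Only trailing marks are stripped, so hyphenated words such as "High-and" stay intact and are labeled O.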

Test

Trained models can be evaluated on processed data with the test module to produce results.

For example, to test the best-performing English model, run the following command:

python src/test.py --pretrained-model=roberta-large --lstm-dim=-1 --use-crf=False --data-path=data/test \
--weight-path=weights/roberta-large-en.pt --sequence-length=256 --save-path=out

Pass the same values for pretrained-model, lstm-dim, and use-crf that were used when training the model. The test is run on all data available in the data-path directory.

Cite this work

@inproceedings{alam-etal-2020-punctuation,
    title = "Punctuation Restoration using Transformer Models for High-and Low-Resource Languages",
    author = "Alam, Tanvirul  and
      Khan, Akib  and
      Alam, Firoj",
    booktitle = "Proceedings of the Sixth Workshop on Noisy User-generated Text (W-NUT 2020)",
    month = nov,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.wnut-1.18",
    pages = "132--142",
}

punctuation-restoration's People

Contributors

akibkhan619, davidavdav, xashru

punctuation-restoration's Issues

Wrong shape for input_ids (shape torch.Size([1, 256])) or attention_mask (shape torch.Size([256]))

I am trying to use the pretrained RoBERTa-large model for English, but I get the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/uvicorn/protocols/http/h11_impl.py", line 396, in run_asgi
    result = await app(self.scope, self.receive, self.send)
  File "/usr/local/lib/python3.6/site-packages/uvicorn/middleware/proxy_headers.py", line 45, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.6/site-packages/fastapi/applications.py", line 201, in __call__
    await super().__call__(scope, receive, send)  # pragma: no cover
  File "/usr/local/lib/python3.6/site-packages/starlette/applications.py", line 111, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.6/site-packages/starlette/middleware/errors.py", line 181, in __call__
    raise exc from None
  File "/usr/local/lib/python3.6/site-packages/starlette/middleware/errors.py", line 159, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.6/site-packages/starlette/exceptions.py", line 82, in __call__
    raise exc from None
  File "/usr/local/lib/python3.6/site-packages/starlette/exceptions.py", line 71, in __call__
    await self.app(scope, receive, sender)
  File "/usr/local/lib/python3.6/site-packages/starlette/routing.py", line 566, in __call__
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.6/site-packages/starlette/routing.py", line 227, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.6/site-packages/starlette/routing.py", line 41, in app
    response = await func(request)
  File "/usr/local/lib/python3.6/site-packages/fastapi/routing.py", line 202, in app
    dependant=dependant, values=values, is_coroutine=is_coroutine
  File "/usr/local/lib/python3.6/site-packages/fastapi/routing.py", line 148, in run_endpoint_function
    return await dependant.call(**values)
  File "./src/main.py", line 58, in predict
    p = inference(deep_punctuation, use_crf, sequence_length, tokenizer, token_idx, model_save_path, device, text)
  File "./src/punctuation/src/inference.py", line 49, in inference
    y_predict = deep_punctuation(x, attn_mask)
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "./src/punctuation/src/model.py", line 28, in forward
    x = self.bert_layer(x, attention_mask=attn_masks)[0]
  File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 706, in forward
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
  File "/usr/local/lib/python3.6/site-packages/transformers/modeling_utils.py", line 204, in get_extended_attention_mask
    input_shape, attention_mask.shape
ValueError: Wrong shape for input_ids (shape torch.Size([1, 256])) or attention_mask (shape torch.Size([256]))

Python 3.6, torch.device("cpu"), and requirements.txt:

python-multipart==0.0.5
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cpu
torchaudio==0.8.1
transformers==v2.11.0
pytorch-crf==0.7.2
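The error says input_ids already has a batch dimension, shape (1, 256), while the attention mask is still one-dimensional, shape (256). Both tensors need the same (batch, seq_len) shape. The sketch below builds them together. The helper name make_batch is hypothetical, and pad_id=1 is an assumption that matches RoBERTa's pad_token_id. If a 1-D mask tensor already exists, attn_mask.unsqueeze(0) adds the missing dimension.

```python
import torch


def make_batch(token_ids, seq_len=256, pad_id=1):
    """Pad one sequence and return (input_ids, attn_mask), both shaped (1, seq_len)."""
    token_ids = token_ids[:seq_len]
    n_pad = seq_len - len(token_ids)
    ids = token_ids + [pad_id] * n_pad
    mask = [1] * len(token_ids) + [0] * n_pad
    # Wrapping each list in an outer list gives BOTH tensors a batch dimension of 1.
    return torch.tensor([ids]), torch.tensor([mask])


x, attn_mask = make_batch([0, 100, 200, 2])
print(x.shape, attn_mask.shape)
```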

ERROR: Failed building wheel for tokenizers

Hello. Thank you very much for the installation instructions and examples. However, I am getting this error. Any idea how to solve it?

I am using Windows 10 and pip


C:\punctuation-restoration>pip install -r requirements.txt
Collecting transformers==v2.11.0
Using cached transformers-2.11.0-py3-none-any.whl (674 kB)
Collecting pytorch-crf
Using cached pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Requirement already satisfied: sentencepiece in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (0.1.97)
Requirement already satisfied: requests in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (2.21.0)
Collecting sacremoses
Using cached sacremoses-0.0.53.tar.gz (880 kB)
Preparing metadata (setup.py) ... done
Requirement already satisfied: packaging in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (21.3)
Requirement already satisfied: filelock in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (3.8.0)
Requirement already satisfied: numpy in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (1.23.3)
Requirement already satisfied: tqdm>=4.27 in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (4.64.1)
Requirement already satisfied: regex!=2019.12.17 in c:\python399\lib\site-packages (from transformers==v2.11.0->-r requirements.txt (line 1)) (2022.9.13)
Collecting tokenizers==0.7.0
Using cached tokenizers-0.7.0.tar.gz (81 kB)
Installing build dependencies ... done
Getting requirements to build wheel ... done
Preparing metadata (pyproject.toml) ... done
Requirement already satisfied: colorama in c:\python399\lib\site-packages (from tqdm>=4.27->transformers==v2.11.0->-r requirements.txt (line 1)) (0.4.5)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in c:\python399\lib\site-packages (from packaging->transformers==v2.11.0->-r requirements.txt (line 1)) (3.0.9)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in c:\python399\lib\site-packages (from requests->transformers==v2.11.0->-r requirements.txt (line 1)) (3.0.4)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in c:\python399\lib\site-packages (from requests->transformers==v2.11.0->-r requirements.txt (line 1)) (1.24.3)
Requirement already satisfied: idna<2.9,>=2.5 in c:\python399\lib\site-packages (from requests->transformers==v2.11.0->-r requirements.txt (line 1)) (2.8)
Requirement already satisfied: certifi>=2017.4.17 in c:\python399\lib\site-packages (from requests->transformers==v2.11.0->-r requirements.txt (line 1)) (2022.9.14)
Requirement already satisfied: six in c:\python399\lib\site-packages (from sacremoses->transformers==v2.11.0->-r requirements.txt (line 1)) (1.12.0)
Collecting click
Using cached click-8.1.3-py3-none-any.whl (96 kB)
Collecting joblib
Using cached joblib-1.2.0-py3-none-any.whl (297 kB)
Building wheels for collected packages: tokenizers
Building wheel for tokenizers (pyproject.toml) ... error
error: subprocess-exited-with-error

× Building wheel for tokenizers (pyproject.toml) did not run successfully.
│ exit code: 1
╰─> [46 lines of output]
running bdist_wheel
running build
running build_py
creating build
creating build\lib.win-amd64-cpython-39
creating build\lib.win-amd64-cpython-39\tokenizers
copying tokenizers\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers
creating build\lib.win-amd64-cpython-39\tokenizers\models
copying tokenizers\models\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\models
creating build\lib.win-amd64-cpython-39\tokenizers\decoders
copying tokenizers\decoders\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\decoders
creating build\lib.win-amd64-cpython-39\tokenizers\normalizers
copying tokenizers\normalizers\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\normalizers
creating build\lib.win-amd64-cpython-39\tokenizers\pre_tokenizers
copying tokenizers\pre_tokenizers\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\pre_tokenizers
creating build\lib.win-amd64-cpython-39\tokenizers\processors
copying tokenizers\processors\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\processors
creating build\lib.win-amd64-cpython-39\tokenizers\trainers
copying tokenizers\trainers\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\trainers
creating build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\base_tokenizer.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\bert_wordpiece.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\byte_level_bpe.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\char_level_bpe.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\sentencepiece_bpe.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\implementations\__init__.py -> build\lib.win-amd64-cpython-39\tokenizers\implementations
copying tokenizers\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers
copying tokenizers\models\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\models
copying tokenizers\decoders\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\decoders
copying tokenizers\normalizers\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\normalizers
copying tokenizers\pre_tokenizers\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\pre_tokenizers
copying tokenizers\processors\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\processors
copying tokenizers\trainers\__init__.pyi -> build\lib.win-amd64-cpython-39\tokenizers\trainers
running build_ext
running build_rust
error: can't find Rust compiler

  If you are using an outdated pip version, it is possible a prebuilt wheel is available for this package but pip is not able to install from it. Installing from the wheel would avoid the need for a Rust compiler.

  To update pip, run:

      pip install --upgrade pip

  and then retry package installation.

  If you did intend to build this package from source, try installing a Rust compiler from your system package manager and ensure it is on the PATH during installation. Alternatively, rustup (available at https://rustup.rs) is the recommended way to download and update the Rust compiler toolchain.
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for tokenizers
Failed to build tokenizers
ERROR: Could not build wheels for tokenizers, which is required to install pyproject.toml-based projects

RuntimeError: Error(s) in loading state_dict for DeepPunctuation: Missing key(s) in state_dict: "bert_layer.embeddings.position_ids".

I have tried with my current installation, and here is the error:

C:\punctuation-restoration\src>python inference.py --pretrained-model=roberta-large --weight-path=roberta-large-en.pt --language=en --in-file=data/test_en.txt --out-file=data/test_en_out.txt
C:\Python399\lib\site-packages\torchaudio\backend\utils.py:62: UserWarning: No audio backend is available.
warnings.warn("No audio backend is available.")
loading file vocab.json from cache at C:\Users\King/.cache\huggingface\hub\models--roberta-large\snapshots\5069d8a2a32a7df4c69ef9b56348be04152a2341\vocab.json
loading file merges.txt from cache at C:\Users\King/.cache\huggingface\hub\models--roberta-large\snapshots\5069d8a2a32a7df4c69ef9b56348be04152a2341\merges.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at None
loading configuration file config.json from cache at C:\Users\King/.cache\huggingface\hub\models--roberta-large\snapshots\5069d8a2a32a7df4c69ef9b56348be04152a2341\config.json
Model config RobertaConfig {
"_name_or_path": "roberta-large",
"architectures": [
"RobertaForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"model_type": "roberta",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"position_embedding_type": "absolute",
"transformers_version": "4.22.1",
"type_vocab_size": 1,
"use_cache": true,
"vocab_size": 50265
}

loading configuration file config.json from cache at C:\Users\King/.cache\huggingface\hub\models--roberta-large\snapshots\5069d8a2a32a7df4c69ef9b56348be04152a2341\config.json
Model config RobertaConfig {
"architectures": [
"RobertaForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 514,
"model_type": "roberta",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"pad_token_id": 1,
"position_embedding_type": "absolute",
"transformers_version": "4.22.1",
"type_vocab_size": 1,
"use_cache": true,
"vocab_size": 50265
}

loading weights file pytorch_model.bin from cache at C:\Users\King/.cache\huggingface\hub\models--roberta-large\snapshots\5069d8a2a32a7df4c69ef9b56348be04152a2341\pytorch_model.bin
Some weights of the model checkpoint at roberta-large were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']

- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of RobertaModel were initialized from the model checkpoint at roberta-large.
If your task is similar to the task the model of the checkpoint was trained on, you can already use RobertaModel for predictions without further training.
Traceback (most recent call last):
  File "C:\punctuation-restoration\src\inference.py", line 105, in <module>
    inference()
  File "C:\punctuation-restoration\src\inference.py", line 41, in inference
    deep_punctuation.load_state_dict(torch.load(model_save_path))
  File "C:\Python399\lib\site-packages\torch\nn\modules\module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DeepPunctuation:
    Missing key(s) in state_dict: "bert_layer.embeddings.position_ids".

C:\punctuation-restoration\src>
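The log reports transformers 4.22.1, but the repository's requirements.txt pins transformers v2.11.0. Newer transformers releases register position_ids as a buffer in the embeddings module, so that key appears in the model's state_dict but is absent from a checkpoint saved with the older version. There are two hedged options. One is to install the pinned version. The other is to load with strict=False and then confirm that position_ids is the only missing key, since the model's constructor rebuilds that buffer anyway. The toy module and the helper load_ignoring_position_ids below are hypothetical. They only demonstrate the strict=False mechanics and are not the repository's model.

```python
import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    """Stand-in for an embeddings module that registers a position_ids buffer."""

    def __init__(self, vocab=10, dim=4, max_len=8):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab, dim)
        # Rebuilt deterministically here, so it need not come from the checkpoint.
        self.register_buffer("position_ids", torch.arange(max_len).unsqueeze(0))


def load_ignoring_position_ids(model, state_dict):
    """Load non-strictly, but fail if anything other than position_ids mismatches."""
    result = model.load_state_dict(state_dict, strict=False)
    unexpected = list(result.unexpected_keys)
    missing = [k for k in result.missing_keys if not k.endswith("position_ids")]
    if missing or unexpected:
        raise RuntimeError(f"missing={missing} unexpected={unexpected}")
    return model


old_checkpoint = {"word_embeddings.weight": torch.randn(10, 4)}  # no position_ids key
model = load_ignoring_position_ids(ToyEmbeddings(), old_checkpoint)
```

Applied to inference.py, this would mean passing strict=False to the deep_punctuation.load_state_dict(...) call shown in the traceback and checking the returned missing_keys.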

What if I want to load a pre-trained model from local disk?

Hello,

I would like to run this model on a local GPU server with no internet access.
After downloading one of the Hugging Face PyTorch models on another machine and moving it to the server's local disk, I do not know how to change the path information. Could you give me a hint?

Thanks,

I have created my own training data, and I got this error:

2021-04-23 21:22:11.699111: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Traceback (most recent call last):
  File "src/train.py", line 41, in <module>
    token_style=token_style, is_train=True, augment_rate=ar, augment_type=aug_type)
  File "/content/drive/My Drive/punctuation-restoration/src/dataset.py", line 77, in __init__
    self.data = parse_data(files, tokenizer, sequence_len, token_style)
  File "/content/drive/My Drive/punctuation-restoration/src/dataset.py", line 45, in parse_data
    y.append(punctuation_dict[punc])
KeyError: '0'

HELP PLEASE.
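The KeyError reports the label '0' (the digit zero), so a likely cause is that the custom files use 0 where the loader expects a label name such as the letter O. The check below reports offending lines before training. It assumes a one-token-per-line, tab-separated word/label layout with the labels O, COMMA, PERIOD, and QUESTION, matching the four classes described in the README. That label set and the function name find_bad_lines are assumptions, so compare them against the files in data/en first.

```python
VALID_LABELS = {"O", "COMMA", "PERIOD", "QUESTION"}  # assumed label set


def find_bad_lines(lines):
    """Yield (line_number, line) for lines whose label is not a known label."""
    for n, line in enumerate(lines, start=1):
        line = line.rstrip("\n")
        if not line.strip():
            continue                      # ignore blank lines
        parts = line.split("\t")
        if len(parts) != 2 or parts[1] not in VALID_LABELS:
            yield n, line


sample = ["hello\tCOMMA\n", "world\t0\n", "bye\tPERIOD\n"]
for n, line in find_bad_lines(sample):
    print(f"line {n}: {line!r}")   # flags line 2, whose label is the digit 0
```

For a real file, pass the open file object (opened with encoding="utf-8") instead of the sample list.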

FileNotFoundError

While running the test script:
python src/test.py --pretrained-model=roberta-large --lstm-dim=-1 --use-crf=False --data-path=data/test --weight-path=weights/roberta-large-en.pt --sequence-length=256 --save-path=out

I'm getting the following error:

Traceback (most recent call last):
  File "src/test.py", line 33, in <module>
    test_files = os.listdir(args.data_path)
FileNotFoundError: [Errno 2] No such file or directory: 'data/test'

CUDA memory allocation error!

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 11.17 GiB total capacity; 10.41 GiB already allocated; 5.81 MiB free; 10.54 GiB reserved in total by PyTorch)

How can I set a limit on the memory allocated by CUDA?

I have also tried this: export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128

But it still fails!

Stack trace:

Traceback (most recent call last):
  File "src/train.py", line 264, in <module>
    train()
  File "src/train.py", line 211, in train
    y_predict = deep_punctuation(x, att)
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/punctuation-restoration/src/model.py", line 28, in forward
    x = self.bert_layer(x, attention_mask=attn_masks)[0]
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/banglaDariComma/lib/python3.7/site-packages/transformers/modeling_bert.py", line 734, in forward
    encoder_attention_mask=encoder_extended_attention_mask,
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/banglaDariComma/lib/python3.7/site-packages/transformers/modeling_bert.py", line 408, in forward
    hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/banglaDariComma/lib/python3.7/site-packages/transformers/modeling_bert.py", line 369, in forward
    self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/banglaDariComma/lib/python3.7/site-packages/transformers/modeling_bert.py", line 315, in forward
    hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
  File "/root/banglaDariComma/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/banglaDariComma/lib/python3.7/site-packages/transformers/modeling_bert.py", line 236, in forward
    attention_scores = attention_scores / math.sqrt(self.attention_head_size)

Unable to predict using inference.py file

I tried using inference.py to predict a sample text with a sequence length of 100, and I get the following runtime error. Could you help me solve it?

python inference.py --pretrained-model=roberta-large --weight-path=roberta-large-en.pt --language=en --in-file=data/test_en.txt --out-file=data/test_en_out.txt

Traceback (most recent call last):
  File "src/inference.py", line 106, in <module>
    inference()
  File "src/inference.py", line 42, in inference
    deep_punctuation.load_state_dict(torch.load(model_save_path))
  File "C:\Users\SMVP\Anaconda3\envs\my_torch\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DeepPunctuation:
    Missing key(s) in state_dict: "bert_layer.embeddings.position_ids".

What is the 5th index value in the Precision, Recall, and F1 lists?

Hi,
I got the below values for one of the tests.

Precision: [0.98068136 0.71008241 0.87272477 0.83011424 0.81440205]
Recall: [0.98558426 0.67705532 0.84840781 0.84434041 0.79418293]
F1 score: [0.9831267 0.69317568 0.86039451 0.83716689 0.80416542]
Accuracy:0.9535017031737457
Confusion Matrix
[[297677 2467 1236 651]
[ 3773 12839 1806 545]
[ 1366 2295 25364 871]
[ 725 480 657 10100]]

Calculating manually, the first four indices of Precision, Recall, and F1 score correspond to the four punctuation classes. But what does the 5th one refer to? It's definitely not the average of the first four representing an "overall" value.
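The posted confusion matrix can be used to recompute the values. Its rows are true labels and its columns are predictions, presumably in the order O, Comma, Period, Question, which is the order the first four values match. The recomputation shows that the fifth value equals the micro-averaged score over the three punctuation classes, with the no-punctuation class O excluded. That is, the pooled true positives are divided by the pooled predicted punctuation (precision) or the pooled true punctuation (recall). This interpretation is inferred from the numbers, not from the repository's test.py. Below is a pure-Python check of that reading.

```python
cm = [
    [297677, 2467, 1236, 651],
    [3773, 12839, 1806, 545],
    [1366, 2295, 25364, 871],
    [725, 480, 657, 10100],
]  # rows = true class, columns = predicted class (O, Comma, Period, Question)

k = len(cm)
tp = [cm[i][i] for i in range(k)]
pred = [sum(cm[r][c] for r in range(k)) for c in range(k)]  # column sums
true = [sum(row) for row in cm]                              # row sums

precision = [tp[i] / pred[i] for i in range(k)]
recall = [tp[i] / true[i] for i in range(k)]

# Fifth entry: micro-average over the punctuation classes only (indices 1-3).
punct = range(1, k)
precision.append(sum(tp[i] for i in punct) / sum(pred[i] for i in punct))
recall.append(sum(tp[i] for i in punct) / sum(true[i] for i in punct))

f1 = [2 * p * r / (p + r) for p, r in zip(precision, recall)]
accuracy = sum(tp) / sum(true)

print([round(v, 8) for v in precision])  # fifth value: 48303 / 59311
```

The per-class entries reproduce the first four values in the issue. Accuracy counts every class, including O, which is why it is much higher than the punctuation-only scores.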
