

Handwritten Character Recognition - an unofficial implementation of the paper

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models


This is an unofficial implementation of TrOCR based on the Hugging Face transformers library and the TrOCR paper. The authors of the paper also provide their own repository. The code in this repository is simply a lightweight wrapper for getting started quickly with training and deploying this model for character recognition tasks.

 

Results:

(Figure: example predictions)

After training on a dataset of 2,000 samples for 8 epochs, the model reached an accuracy of 96.5%. Neither the training nor the validation dataset was completely clean; with cleaner data, even higher accuracy would likely have been possible.
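The README does not spell out how accuracy is measured. As a minimal sketch, assuming accuracy means the share of samples whose predicted text matches the label exactly, it could be computed as below. The character error rate (CER) is included only as a common complementary OCR metric. Neither function is taken from this repository, and all names are hypothetical.

```python
from typing import Sequence


def exact_match_accuracy(predictions: Sequence[str], labels: Sequence[str]) -> float:
    """Fraction of samples whose prediction equals the label exactly."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        return 0.0
    hits = sum(pred == ref for pred, ref in zip(predictions, labels))
    return hits / len(labels)


def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) between two strings."""
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        current = [i]
        for j, cb in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,               # delete ca
                current[j - 1] + 1,            # insert cb
                previous[j - 1] + (ca != cb),  # substitute (free if equal)
            ))
        previous = current
    return previous[-1]


def character_error_rate(predictions: Sequence[str], labels: Sequence[str]) -> float:
    """Total edit distance divided by the total number of label characters."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    errors = sum(levenshtein(pred, ref) for pred, ref in zip(predictions, labels))
    total = sum(len(ref) for ref in labels)
    return errors / total if total else 0.0
```

Exact-match accuracy is strict (one wrong character makes the whole sample wrong), while CER shows how close the wrong predictions were.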

 

Architecture:

(Figure: TrOCR architecture, taken from the original paper.)

TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models, Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei, Preprint 2021.

 

 

 



 

1. Setup

Clone the repository and make sure you have conda or Miniconda installed. Then change into the directory of the cloned repository and run

conda env create -n trocr --file environment.yml
conda activate trocr

This should install all necessary libraries.

Training without GPU:

Using a CUDA-capable GPU is highly recommended, but everything also works on the CPU. In that case, create the environment from environment-cpu.yml instead.

If the process terminates with the message "Killed" (typically because it ran out of memory), reduce the batch size until training fits into the available memory.
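To check which device a run will actually use, a quick test like the following can help. This is a sketch, not part of the repository; it assumes PyTorch is the backend installed by the environment files, and pick_device is a hypothetical helper name.

```python
def pick_device() -> str:
    """Return "cuda" if PyTorch can see a CUDA GPU, otherwise "cpu"."""
    try:
        import torch  # assumption: installed via environment.yml / environment-cpu.yml
    except ImportError:
        # PyTorch is missing; report "cpu" so the check itself still runs.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(f"Training would run on: {pick_device()}")
```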

 

2. Using the repository

There are three modes: inference, validation and training. Each of them can start either from a local model at the configured path (see src/configs/paths.py) or from the pretrained model on Hugging Face. Inference and validation use the local model by default; training starts from the Hugging Face model by default.
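Choosing between the two model sources can be sketched as follows. This is only an illustration under assumptions: LOCAL_MODEL_DIR stands in for whatever path the repository configures, microsoft/trocr-base-handwritten is an example public checkpoint rather than necessarily the one this repository uses, and resolve_checkpoint and load_model are hypothetical names. The loading calls themselves are standard transformers APIs.

```python
from pathlib import Path
from typing import Union

# Hypothetical values: the real local path is defined in the repository's paths config,
# and the hub name is only an example of a public TrOCR checkpoint.
LOCAL_MODEL_DIR = Path("model")
HUB_MODEL_NAME = "microsoft/trocr-base-handwritten"


def resolve_checkpoint(use_local_model: bool,
                       local_dir: Path = LOCAL_MODEL_DIR,
                       hub_name: str = HUB_MODEL_NAME) -> Union[Path, str]:
    """Pick the local checkpoint directory or the Hugging Face hub name."""
    if use_local_model:
        if not local_dir.is_dir():
            raise FileNotFoundError(f"No local model found at {local_dir}")
        return local_dir
    return hub_name


def load_model(checkpoint: Union[Path, str]):
    """Load processor and model; transformers is imported only when needed."""
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    processor = TrOCRProcessor.from_pretrained(str(checkpoint))
    model = VisionEncoderDecoderModel.from_pretrained(str(checkpoint))
    return processor, model
```

Importing transformers inside load_model keeps the path logic usable and testable even where transformers is not installed.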

 

Inference (Prediction):

python -m src predict <image_files>  # predict image files using the trained local model
python -m src predict data/img1.png data/img2.png  # pass the image files explicitly
python -m src predict data/*  # also works with shell expansion
python -m src predict data/* --no-local-model  # uses the pretrained huggingface model
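For comparison, recognizing text with a pretrained checkpoint using plain transformers calls looks roughly like this. It is a sketch, not the code behind the predict command: the checkpoint name, the default batch size and the helper names batched and predict_texts are assumptions. Processing images in batches is also where the batch size advice for low-memory machines applies.

```python
from typing import Iterator, List, Sequence


def batched(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def predict_texts(image_paths: Sequence[str],
                  checkpoint: str = "microsoft/trocr-base-handwritten",
                  batch_size: int = 8) -> List[str]:
    """Recognize the text in each image file (hypothetical helper)."""
    # Imported lazily so that batched() works without these packages.
    import torch
    from PIL import Image
    from transformers import TrOCRProcessor, VisionEncoderDecoderModel

    processor = TrOCRProcessor.from_pretrained(checkpoint)
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    model.eval()

    texts: List[str] = []
    for batch in batched(image_paths, batch_size):
        # Convert to RGB so grayscale or RGBA files have three channels.
        images = [Image.open(path).convert("RGB") for path in batch]
        pixel_values = processor(images=images, return_tensors="pt").pixel_values
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        texts.extend(processor.batch_decode(generated_ids, skip_special_tokens=True))
    return texts
```

Usage would look like predict_texts(["data/img1.png", "data/img2.png"], batch_size=2); a smaller batch size lowers peak memory at the cost of speed.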

Validation:

python -m src validate # uses the local model
python -m src validate --no-local-model # loads the pretrained model from Hugging Face

Training:

python -m src train  # starts from the pretrained Hugging Face model
python -m src train --local-model  # starts from the local model

 

For validation and training, the input images should be in the directories train and val, and the labels should be in gt/labels.csv. Each row of the CSV should contain the image file name followed by its label, for example img1.png,a (quote fields if necessary, e.g. when a label contains a comma).

Reading labels from a different source is also straightforward: just add the necessary code to load_filepaths_and_labels in src/dataset.py.
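As an illustration of the expected layout, the following hypothetical read_labels function pairs the images in one directory with the rows of gt/labels.csv. It is not the repository's load_filepaths_and_labels; it assumes the CSV has no header row and that each row starts with the image file name followed by the label. Python's csv module takes care of quoted labels.

```python
import csv
from pathlib import Path
from typing import List, Tuple


def read_labels(image_dir: Path, labels_csv: Path) -> List[Tuple[Path, str]]:
    """Pair every image in image_dir with its label from labels_csv.

    Each CSV row is expected to be: image file name, label (e.g. img1.png,a).
    Images without a CSV entry are skipped.
    """
    with open(labels_csv, newline="", encoding="utf-8") as f:
        labels = {row[0]: row[1] for row in csv.reader(f) if len(row) >= 2}

    pairs = []
    for image_path in sorted(image_dir.iterdir()):
        if image_path.is_file() and image_path.name in labels:
            pairs.append((image_path, labels[image_path.name]))
    return pairs
```

If the real CSV has a header row, it would need to be skipped before building the dictionary.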

To move a random subsample of the training data into the validation set, you can use this command:

mkdir -p val && find train -type f | shuf -n <num of val samples> | xargs -I '{}' mv '{}' val/
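Where find and shuf are not available (for example on Windows), the same split can be done in Python. The sketch below is a hypothetical helper, not part of the repository; unlike the shell command, it accepts an optional seed so that the split can be reproduced. As with mv, an existing file in val with the same name may be overwritten.

```python
import random
import shutil
from pathlib import Path
from typing import List, Optional


def move_random_subset(train_dir: Path, val_dir: Path, num_val: int,
                       seed: Optional[int] = None) -> List[Path]:
    """Move num_val randomly chosen files from train_dir (recursively) into val_dir."""
    files = sorted(p for p in train_dir.rglob("*") if p.is_file())
    if num_val > len(files):
        raise ValueError(f"Requested {num_val} files, but only {len(files)} exist")
    chosen = random.Random(seed).sample(files, num_val)
    val_dir.mkdir(parents=True, exist_ok=True)
    moved = []
    for src in chosen:
        dst = val_dir / src.name  # flattened into val_dir, like mv ... val/
        shutil.move(str(src), str(dst))
        moved.append(dst)
    return moved
```

For example, move_random_subset(Path("train"), Path("val"), 200, seed=42) moves 200 files and returns their new paths.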

 

3. Integrating into other projects

If you want to use the predictions as part of a larger project, you can use the interface provided by the TrocrPredictor class in src/main.py. Make sure to run all code as Python modules (python -m ...).

See the following code example:

from PIL import Image
from trocr.src.main import TrocrPredictor

# load the images
image_names = ["data/img1.png", "data/img2.png"]
images = [Image.open(img_name) for img_name in image_names]

model = TrocrPredictor()

# predict directly on Pillow images ...
predictions = model.predict_images(images)

# ... or on file names; both return one prediction per image
predictions = model.predict_for_file_names(image_names)

# print the results
for file_name, prediction in zip(image_names, predictions):
    print(f'Prediction for {file_name}: {prediction}')
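To run the predictor over a whole folder, a small wrapper around predict_for_file_names is enough. This is a sketch: predict_folder and its extension list are hypothetical, and it assumes, as the example above suggests, that predict_for_file_names returns one prediction per file name in the same order.

```python
from pathlib import Path
from typing import Dict, Iterable


def predict_folder(predictor, folder: str,
                   extensions: Iterable[str] = (".png", ".jpg", ".jpeg")) -> Dict[str, str]:
    """Run predictor.predict_for_file_names on all images in folder.

    Returns a mapping from file name to predicted text.
    """
    suffixes = {ext.lower() for ext in extensions}
    file_names = sorted(str(p) for p in Path(folder).iterdir()
                        if p.is_file() and p.suffix.lower() in suffixes)
    if not file_names:
        return {}
    predictions = predictor.predict_for_file_names(file_names)
    return dict(zip(file_names, predictions))
```

Usage would be results = predict_folder(TrocrPredictor(), "data"). Because the predictor is passed in, the function can also be tested with a stand-in object.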

 

4. Adapting the Code

In general, it should be easy to adapt the code for other input formats or use cases.

  • Learning Rate, Batch size, Train Epoch Count, Logging, Word Len: src/configs/constants.py
  • Input Paths, Model Checkpoint Path: src/configs/paths.py
  • Different label format: src/dataset.py : load_filepaths_and_labels

The word length constant is very important: to enable batch training, all labels need to be padded to the same length. Finding a good value may take some experimentation; for our data, padding to 8 worked well.
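The padding step can be sketched as follows. The helper names and example token IDs are hypothetical, and the sketch assumes that word length counts tokens rather than characters. Replacing padding with -100 is a common Hugging Face convention, because PyTorch's cross-entropy loss ignores that index by default; whether this repository does the same is not stated here. With a Hugging Face tokenizer, the same padding and truncation can also be requested via padding="max_length", max_length=... and truncation=True.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def pad_label_ids(token_ids: Sequence[int], word_len: int, pad_id: int) -> List[int]:
    """Truncate or right-pad token IDs to exactly word_len entries."""
    ids = list(token_ids[:word_len])
    return ids + [pad_id] * (word_len - len(ids))


def mask_padding(label_ids: Sequence[int], pad_id: int) -> List[int]:
    """Replace padding IDs with IGNORE_INDEX so they do not contribute to the loss."""
    return [IGNORE_INDEX if t == pad_id else t for t in label_ids]
```

Truncating to a fixed length can cut off the end-of-sequence token of long labels, which is one reason the word length may need tuning.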

If you want to change specifics of the model, you can supply a TrOCRConfig object to the transformers interface. See https://huggingface.co/docs/transformers/model_doc/trocr#transformers.TrOCRConfig for more details.
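As a sketch of that option, the code below builds a model configuration with a smaller TrOCR decoder on top of a default ViT encoder configuration. The specific sizes are arbitrary example values, not recommendations, and build_small_config is a hypothetical name. A model created from such a configuration has randomly initialized weights, so it would need to be trained from scratch.

```python
def build_small_config():
    """Return a VisionEncoderDecoderConfig with a smaller TrOCR decoder, or None."""
    try:
        from transformers import TrOCRConfig, ViTConfig, VisionEncoderDecoderConfig
    except ImportError:
        return None  # transformers is not installed

    # Example values only: d_model must be divisible by decoder_attention_heads.
    decoder_config = TrOCRConfig(
        d_model=256,
        decoder_layers=6,
        decoder_attention_heads=4,
        decoder_ffn_dim=1024,
    )
    encoder_config = ViTConfig()  # default ViT encoder settings
    return VisionEncoderDecoderConfig.from_encoder_decoder_configs(
        encoder_config, decoder_config
    )


if __name__ == "__main__":
    config = build_small_config()
    if config is not None:
        # A randomly initialized model could then be created with:
        # from transformers import VisionEncoderDecoderModel
        # model = VisionEncoderDecoderModel(config=config)
        print(config.decoder.decoder_layers)
```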

 

5. Contact

If the setup fails, please let me know in a GitHub issue! Sub-dependencies sometimes release updates that are incompatible with other dependencies, in which case the dependency list has to be updated.
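When opening an issue about a setup problem, listing the installed versions of the main dependencies makes such incompatibilities easier to spot. The script below is a hypothetical helper, not part of the repository; the package list is an assumption and should be adjusted to match environment.yml.

```python
import platform
import sys
from importlib import metadata
from typing import Dict, Iterable, Optional

# Assumed package list; adjust it to the dependencies in environment.yml.
DEFAULT_PACKAGES = ("torch", "transformers", "Pillow")


def collect_versions(packages: Iterable[str] = DEFAULT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(f"Python {sys.version.split()[0]} on {platform.platform()}")
    for name, version in collect_versions().items():
        print(f"{name}: {version or 'not installed'}")
```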

Feel free to submit issues with questions about the implementation as well.

For questions about the paper or the architecture, please get in touch with the authors.
