bert-sklearn

scikit-learn wrapper to finetune BERT

A scikit-learn wrapper to finetune Google's BERT model for text and token sequence tasks, based on the Hugging Face PyTorch port.

  • Includes a configurable MLP as the final classifier/regressor for text and text pair tasks
  • Includes a token sequence classifier for NER, PoS, and chunking tasks
  • Includes SciBERT and BioBERT pretrained models for the scientific and biomedical domains

Try in Google Colab!

installation

Requires Python >= 3.5 and PyTorch >= 0.4.1.

git clone -b master https://github.com/charles9n/bert-sklearn
cd bert-sklearn
pip install .

basic operation

model.fit(X, y), i.e. finetune BERT

  • X: list, pandas dataframe, or numpy array of text, text pairs, or token lists

  • y: list, pandas dataframe, or numpy array of labels/targets
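To make the three input layouts concrete, here is a minimal sketch using plain Python lists. The sentences, labels, and helper names (`same_length`, `tags_match_tokens`) are made up for illustration and are not part of bert-sklearn. Supplying one tag per token as `y` for token tasks is an assumption based on the usual NER data layout.

```python
# Three illustrative input layouts (all values are made up for this sketch).

# 1. Single texts: one string per sample, one label per sample.
X_text = ["the movie was great", "the plot made no sense"]
y_text = ["pos", "neg"]

# 2. Text pairs: two strings per sample (e.g. premise and hypothesis).
X_pairs = [
    ["a man is playing guitar", "someone is making music"],
    ["a dog runs in the park", "the cat sleeps indoors"],
]
y_pairs = ["entailment", "neutral"]

# 3. Token lists: one list of tokens per sample.
#    Assumption: y holds one tag per token, in the same order.
X_tokens = [["John", "lives", "in", "Paris"], ["He", "works", "at", "Google"]]
y_tokens = [["B-PER", "O", "O", "B-LOC"], ["O", "O", "O", "B-ORG"]]


def same_length(X, y):
    """Hypothetical helper: True if X and y have the same number of samples."""
    return len(X) == len(y)


def tags_match_tokens(X, y):
    """Hypothetical helper: True if every sample has one tag per token."""
    return same_length(X, y) and all(len(t) == len(l) for t, l in zip(X, y))


print(same_length(X_text, y_text))
print(same_length(X_pairs, y_pairs))
print(tags_match_tokens(X_tokens, y_tokens))
```

A quick check like this before calling fit can catch misaligned samples or tags early, before an expensive fine-tuning run starts.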

from bert_sklearn import BertClassifier
from bert_sklearn import BertRegressor
from bert_sklearn import BertTokenClassifier
from bert_sklearn import load_model

# define model
model = BertClassifier()         # text/text pair classification
# model = BertRegressor()        # text/text pair regression
# model = BertTokenClassifier()  # token sequence classification

# finetune model
model.fit(X_train, y_train)

# make predictions
y_pred = model.predict(X_test)

# make probability predictions
y_proba = model.predict_proba(X_test)

# score model on test data
model.score(X_test, y_test)

# save model to disk
savefile='/data/mymodel.bin'
model.save(savefile)

# load model from disk
new_model = load_model(savefile)

# do stuff with new model
new_model.score(X_test, y_test)

See demo notebook.

model options

# try different options...
model.bert_model = 'bert-large-uncased'
model.num_mlp_layers = 3
model.max_seq_length = 196
model.epochs = 4
model.learning_rate = 4e-5
model.gradient_accumulation_steps = 4

# finetune
model.fit(X_train, y_train)

# do stuff...
model.score(X_test, y_test)

See options

hyperparameter tuning

from sklearn.model_selection import GridSearchCV

params = {'epochs':[3, 4], 'learning_rate':[2e-5, 3e-5, 5e-5]}

# wrap classifier in GridSearchCV
clf = GridSearchCV(BertClassifier(validation_fraction=0),
                   params,
                   scoring='accuracy',
                   verbose=True)

# fit gridsearch
clf.fit(X_train, y_train)

See demo_tuning_hyperparameters notebook.
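Every fit in a grid search is a separate fine-tuning run, so it can help to count the fits before launching one. The sketch below expands the grid from the example above and estimates the number of runs. It assumes 5-fold cross-validation (the scikit-learn default since version 0.22) and a final refit on the full training set (`refit=True`, also the default). `expand_grid` and `count_fits` are hypothetical helper names, not part of bert-sklearn or scikit-learn.

```python
from itertools import product

# Same grid as in the GridSearchCV example.
params = {'epochs': [3, 4], 'learning_rate': [2e-5, 3e-5, 5e-5]}


def expand_grid(grid):
    """Return every parameter combination in the grid as a list of dicts."""
    names = sorted(grid)
    value_lists = [grid[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]


def count_fits(grid, cv_folds=5, refit=True):
    """Count fits: one per candidate per fold, plus one optional final refit."""
    n_candidates = len(expand_grid(grid))
    return n_candidates * cv_folds + (1 if refit else 0)


candidates = expand_grid(params)
total_fits = count_fits(params)

print(len(candidates))  # 6 candidates (2 epochs settings x 3 learning rates)
print(total_fits)       # 31 fits (6 candidates x 5 folds + 1 refit)
```

Under these assumptions, even this small grid means 31 fine-tuning runs, so passing a smaller `cv` value or a narrower grid may be worth considering on limited hardware.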

GLUE datasets

The train and dev sets from the GLUE (General Language Understanding Evaluation) benchmark were used with the bert-base-uncased model, and the results were compared against those reported in the Google paper and on the GLUE leaderboard.

                         MNLI(m/mm)  QQP   QNLI  SST-2  CoLA  STS-B  MRPC  RTE
BERT base (leaderboard)  84.6/83.4   89.2  90.1  93.5   52.1  87.1   84.8  66.4
bert-sklearn             83.7/83.9   90.2  88.6  92.32  58.1  89.7   86.8  64.6

Individual runs can be found here.

CoNLL-2003 Named Entity Recognition (NER)

NER results for the CoNLL-2003 shared task:

              dev f1  test f1
BERT paper    96.4    92.4
bert-sklearn  96.04   91.97

Span level stats on test:

processed 46666 tokens with 5648 phrases; found: 5740 phrases; correct: 5173.
accuracy:  98.15%; precision:  90.12%; recall:  91.59%; FB1:  90.85
              LOC: precision:  92.24%; recall:  92.69%; FB1:  92.46  1676
             MISC: precision:  78.07%; recall:  81.62%; FB1:  79.81  734
              ORG: precision:  87.64%; recall:  90.07%; FB1:  88.84  1707
              PER: precision:  96.00%; recall:  96.35%; FB1:  96.17  1623
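
The overall precision, recall, and FB1 in the report follow from the three phrase counts on its first line. A minimal sketch of that arithmetic, using the standard span-level definitions (`span_metrics` is a hypothetical helper name, not part of bert-sklearn):

```python
def span_metrics(n_gold, n_found, n_correct):
    """Span-level precision, recall, and F1 from phrase counts.

    n_gold:    phrases in the reference annotation
    n_found:   phrases predicted by the model
    n_correct: predicted phrases that exactly match a reference phrase
    """
    precision = n_correct / n_found
    recall = n_correct / n_gold
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Counts taken from the first line of the report above.
precision, recall, f1 = span_metrics(n_gold=5648, n_found=5740, n_correct=5173)

print("precision: {:.2f}%; recall: {:.2f}%; FB1: {:.2f}".format(
    100 * precision, 100 * recall, 100 * f1))
# prints "precision: 90.12%; recall: 91.59%; FB1: 90.85"
```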

See ner_english notebook for a demo using the 'bert-base-cased' model.

NCBI Biomedical NER

NER results using bert-sklearn with SciBERT and BioBERT on the NCBI disease corpus name recognition task.

The previous state of the art (SOTA) for this task is an f1 of 87.34 on the test set.

                         test f1 (bert-sklearn)  test f1 (from papers)
BERT base cased          85.09                   85.49
SciBERT basevocab cased  88.29                   86.91
SciBERT scivocab cased   87.73                   86.45
BioBERT pubmed_v1.0      87.86                   87.38
BioBERT pubmed_pmc_v1.0  88.26                   89.36
BioBERT pubmed_v1.1      87.26                   NA

See ner_NCBI_disease_BioBERT_SciBERT notebook for a demo using SciBERT and BioBERT models.

See SciBERT paper and BioBERT paper for more info on the respective models.

tests

Run tests with pytest:

python -m pytest -sv tests/
