style-knnlm's Introduction

k-Nearest-Neighbor Language Models with Style Attributes

This repository is a fork of the knnlm repository (see commit) and adds support for style attributes.

How to use

Preprocessing is exactly the same as knnlm.

LM training, LM evaluation, training of FAISS indices and kNN-LM evaluation have some changed or additional parameters.

Data format for style attributes

Style attribute files mirror the raw text files {name}.{split}.tokens:

  • The file name should be {name}.{split}.style.
  • Values are separated by line breaks, i.e. one value per line.
  • There is one value for each line in the raw text.
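As a minimal sketch of this layout, the following Python writes a .style file next to a .tokens file and checks that both have the same number of lines. The helper names write_style_file and check_alignment, the file name demo.train.tokens and the numeric example values are hypothetical and not part of the repository.

```python
import os
import tempfile


def write_style_file(tokens_path, values):
    """Write one style value per line next to a {name}.{split}.tokens file."""
    root, ext = os.path.splitext(tokens_path)
    if ext != ".tokens":
        raise ValueError("expected a {name}.{split}.tokens file")
    style_path = root + ".style"
    with open(style_path, "w", encoding="utf-8") as f:
        for value in values:
            f.write(f"{value}\n")
    return style_path


def check_alignment(tokens_path, style_path):
    """Return True if both files have the same number of lines."""
    with open(tokens_path, encoding="utf-8") as f:
        n_text = sum(1 for _ in f)
    with open(style_path, encoding="utf-8") as f:
        n_style = sum(1 for _ in f)
    return n_text == n_style


# Hypothetical two-line corpus with one scalar style value per line
# (numeric values are an assumption made for this example).
tmp_dir = tempfile.mkdtemp()
tokens_path = os.path.join(tmp_dir, "demo.train.tokens")
with open(tokens_path, "w", encoding="utf-8") as f:
    f.write("a first line of text\na second line of text\n")

style_path = write_style_file(tokens_path, [0.25, 0.75])
aligned = check_alignment(tokens_path, style_path)
```

Running a check like this before training can catch a style file that has drifted out of sync with its text file.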

LM training

Training parameters for all previous models remain unchanged. For style attribute support we add the architecture transformer_lm_style as an extension of transformer_lm_wiki103.

Following the example in knnlm:

python train.py \
    $BIN \
    --task language_modeling \
    --save-dir checkpoints/ \
    --arch transformer_lm_style \
    --style-input-dim 1 \
    --style-embed-dim 32 \
    --style-path $TEXT \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
  • --style-input-dim: the dimension of style attributes. Currently only 1 is supported.
  • --style-embed-dim: the embedding dimension of style attributes in the LM.
  • --style-path: the folder where style attribute files are saved.

LM evaluation

python eval_lm.py \
    $BIN \
    --style-path $TEXT \
    --path checkpoints/checkpoint_best.pt \
    --sample-break-mode complete --max-tokens 3072 \
    --context-window 2560 --softmax-batch 1024 \
    --gen-subset valid

Populating the datastore

python eval_lm.py \
    $BIN \
    --path checkpoints/checkpoint_best.pt \
    --style-path $TEXT \
    --sample-break-mode none --max-tokens 3072 \
    --softmax-batch 1024 --gen-subset train \
    --context-window 1536 --tokens-per-sample 1536 \
    --dstore-mmap checkpoints/dstore --knn-keytype 'last_ffn_input' \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    --save-knnlm-dstore --dstore-save-style \
    --fp16 --dstore-fp16
  • The --dstore-size parameter was removed, since we calculate the required size automatically and embed shape and dtype into the memmap.
  • We added an argument --dstore-fp16 to enable saving in half precision.
  • Use --dstore-save-style to save style attributes for the keys and values in the datastore. This is necessary for style metrics during evaluation.

Training the FAISS index

The index training script build_index.py was rewritten and many parameters were renamed. Use the --help argument for more info.

python build_index.py \
    --dstore_mmap checkpoints/dstore \
    --filepath checkpoints/knn.index \
    --batch-size-add 1000000 \
    --batch-size-write 5000000 \
    --starting_point 0 \
    --move-dstore-to-mem \
    --dstore-fp16

kNN-LM evaluation

python eval_lm.py \
    $BIN \
    --path checkpoints/checkpoint_best.pt \
    --style-path $TEXT \
    --sample-break-mode complete --max-tokens 3072 \
    --context-window 2560 --softmax-batch 1024 \
    --gen-subset valid --dstore-mmap checkpoints/dstore \
    --indexfile checkpoints/knn.index  \
    --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
    --k 1024 --lmbda 0.25 --knn-keytype last_ffn_input \
    --probe 32 --knnlm --fp16 --move-dstore-to-mem
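The --lmbda value in the command above weights the kNN distribution against the LM distribution. A minimal sketch of that interpolation in log space, following the usual kNN-LM formulation p = lmbda * p_knn + (1 - lmbda) * p_lm; the function name interpolate_log_probs and the example probabilities are hypothetical, and the repository's own implementation may differ in detail.

```python
import math


def interpolate_log_probs(lm_log_prob, knn_log_prob, lmbda):
    """Return log(lmbda * p_knn + (1 - lmbda) * p_lm) for 0 < lmbda < 1."""
    if not 0.0 < lmbda < 1.0:
        raise ValueError("lmbda must lie strictly between 0 and 1")
    a = math.log(1.0 - lmbda) + lm_log_prob
    b = math.log(lmbda) + knn_log_prob
    # Numerically stable log-sum-exp of the two weighted terms.
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# Hypothetical probabilities for a single target token.
combined = interpolate_log_probs(math.log(0.2), math.log(0.6), lmbda=0.25)
```

With these inputs the combined probability is 0.75 * 0.2 + 0.25 * 0.6 = 0.3.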

Custom metrics

We added some custom evaluation metrics to eval_lm.py. Each can be enabled individually with --report-metrics and is reported during and after evaluation.

  • style: Style MAE/MBE. The mean absolute/bias error of retrieved style vs. requested style.
  • topk-ds-precision: Top-k datastore retrieval precision. Requires the additional parameter --top-k to be set.
  • Top-k LM precision of probabilities. If --knnlm is used, all three probabilities are evaluated (LM, datastore, interpolated). Otherwise only the LM probabilities are evaluated. Requires the additional parameter --top-k to be set.
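To make the style metric concrete, here is a minimal sketch of mean absolute error and mean bias error over scalar style values. The function name style_errors is hypothetical, and the sign convention (retrieved minus requested) is an assumption, not something the README specifies.

```python
def style_errors(retrieved, requested):
    """Return (MAE, MBE) of retrieved vs. requested scalar style values."""
    if len(retrieved) != len(requested) or not retrieved:
        raise ValueError("inputs must be non-empty and of equal length")
    # Assumed sign convention: positive bias means retrieved style is too high.
    diffs = [r - q for r, q in zip(retrieved, requested)]
    mae = sum(abs(d) for d in diffs) / len(diffs)
    mbe = sum(diffs) / len(diffs)
    return mae, mbe


# Hypothetical values: differences are -0.5, +0.5, -0.5.
mae, mbe = style_errors([0.5, 1.0, 0.0], [1.0, 0.5, 0.5])
```

MAE measures how far the retrieved style is from the request on average, while MBE shows whether the retrieval is systematically too high or too low.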

Saving intermediate variables

To support saving intermediate variables we adapt some of efficient-knnlm's code.

Variables can be saved to --save-vars-dir with --save-vars. Options for --save-vars are:

  • predictions: The top-k predictions (requires --top-k).
    If --knnlm is used, this includes predictions from LM, datastore and interpolated probabilities. Otherwise only from LM probabilities.
  • dictionary: the vocabulary used
  • knns: the indices of retrieved kNNs
  • dists: the distances of the retrieved kNNs to the queries
  • reftargets: the reference targets
  • refstyle: the reference style
  • style: the retrieved style
  • correctness: Boolean mask where the retrieved kNNs match the reference target token.

Known issues

  • Calculating/saving top-k probabilities fails with large models due to CUDA OOM errors.

style-knnlm's Issues

dtype mismatch in eval_lm.py

Description

kNN-LM evaluation fails because of a dtype mismatch in eval_lm.py.

Error

Traceback (most recent call last):                                              
  File "/home/d8xa/knnlm/eval_lm.py", line 11, in <module>
    cli_main()
  File "/home/d8xa/knnlm/fairseq_cli/eval_lm.py", line 289, in cli_main
    main(args)
  File "/home/d8xa/knnlm/fairseq_cli/eval_lm.py", line 178, in main
    hypos = scorer.generate(models, sample, knn_dstore=knn_dstore)
  File "/home/d8xa/anaconda3/envs/nlp/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/d8xa/knnlm/fairseq/sequence_scorer.py", line 101, in generate
    yhat_knn_prob = dstore.get_knn_log_prob(
  File "/home/d8xa/knnlm/fairseq/knnlm.py", line 114, in get_knn_log_prob
    full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob
RuntimeError: Index put requires the source and destination dtypes match, got Long for the destination and Float for the source.

Steps to reproduce

The following command was used:

python eval_lm.py $DEST \
     --path $CHECKPOINTS/$MODELNAME \
     --sample-break-mode complete --max-tokens 3072 \
     --context-window 1536 --softmax-batch 1024 \
     --gen-subset valid --dstore-filename $CHECKPOINTS/dstore \
     --indexfile $CHECKPOINTS/knn.index  \
     --model-overrides "{'knn_keytype': 'last_ffn_input'}" \
     --k 128 --lmbda 0.25 --dstore-size 103225485 --knn-keytype last_ffn_input \
     --probe 32 --knnlm --fp16

But theoretically the error should occur regardless of parameters -- any time the code reaches this segment:

yhat_knn_prob = torch.logsumexp(probs + index_mask, dim=-1).clone()
full_yhat_knn_prob = torch.full([qshape[0]*qshape[1]], -10000).cuda()
full_yhat_knn_prob[tgt != pad_idx] = yhat_knn_prob

yhat_knn_prob is created as Float, while full_yhat_knn_prob is created as Long.
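One possible fix is to create the destination buffer with the same dtype as the source. The sketch below runs on CPU and assumes PyTorch is installed; the function name scatter_knn_log_probs and the example tensors are hypothetical and not the repository's code.

```python
import torch


def scatter_knn_log_probs(yhat_knn_prob, tgt, pad_idx, fill_value=-10000.0):
    """Scatter per-token kNN log-probs into a padded buffer of matching dtype."""
    full = torch.full(
        (tgt.numel(),),
        fill_value,
        dtype=yhat_knn_prob.dtype,    # match the source dtype explicitly
        device=yhat_knn_prob.device,  # stay on the same device as the source
    )
    # Only non-padding positions receive a kNN log-probability.
    full[tgt.reshape(-1) != pad_idx] = yhat_knn_prob
    return full


# Hypothetical batch: pad_idx is 1, so three positions are real tokens.
tgt = torch.tensor([[5, 7, 1], [9, 1, 1]])
probs = torch.tensor([-0.1, -0.2, -0.3])
out = scatter_knn_log_probs(probs, tgt, pad_idx=1)
```

Passing dtype explicitly avoids depending on how torch.full infers a dtype from the integer fill value -10000.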

Unit tests fail

The current unit tests largely don't support the knnlm implementation.

Saving dstore values fails in last sample if tokens don't fit

Description

Values in the last sample don't fit in the dstore file.

Error

Traceback (most recent call last):                                              
  File "/home/d8xa/knnlm/eval_lm.py", line 11, in <module>
    cli_main()
  File "/home/d8xa/knnlm/fairseq_cli/eval_lm.py", line 289, in cli_main
    main(args)
  File "/home/d8xa/knnlm/fairseq_cli/eval_lm.py", line 194, in main
    dstore_vals[dstore_idx:shape[0]+dstore_idx] = hypo['tokens'].view(
ValueError: could not broadcast input array from shape (10,1) into shape (1,1)

Relevant code segment:

if dstore_idx + shape[0] > args.dstore_size:
    shape = [args.dstore_size - dstore_idx]
    hypo['dstore_keys'] = hypo['dstore_keys'][:shape[0]]
if args.dstore_fp16:
    dstore_keys[dstore_idx:shape[0]+dstore_idx] = hypo['dstore_keys'].view(
        -1, args.decoder_embed_dim).cpu().numpy().astype(np.float16)
    dstore_vals[dstore_idx:shape[0]+dstore_idx] = hypo['tokens'].view(
        -1, 1).cpu().numpy().astype(np.int16)

While hypo['dstore_keys'] is truncated if necessary, hypo['tokens'] is not.
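A minimal sketch of one way to avoid this, truncating keys and tokens to the same length before writing. It uses plain NumPy arrays in place of the memmaps and tensors; the function name write_to_dstore and the array sizes are hypothetical.

```python
import numpy as np


def write_to_dstore(dstore_keys, dstore_vals, dstore_idx, keys, tokens):
    """Copy keys and tokens into the datastore, truncating both to the space left."""
    n = min(keys.shape[0], dstore_keys.shape[0] - dstore_idx)
    dstore_keys[dstore_idx:dstore_idx + n] = keys[:n]
    # Truncate the tokens with the same n so both slices stay aligned.
    dstore_vals[dstore_idx:dstore_idx + n] = tokens[:n].reshape(-1, 1)
    return dstore_idx + n


# Hypothetical datastore with one free slot left and a final sample of 10 tokens,
# mirroring the (10,1) into (1,1) shapes from the error above.
dstore_keys = np.zeros((11, 4), dtype=np.float16)
dstore_vals = np.zeros((11, 1), dtype=np.int16)
keys = np.ones((10, 4), dtype=np.float16)
tokens = np.arange(1, 11, dtype=np.int16)

next_idx = write_to_dstore(dstore_keys, dstore_vals, 10, keys, tokens)
```

Deriving a single length n and applying it to both arrays keeps keys and values aligned when the last sample overflows the datastore.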

Improve mmap usage in datastore

Feature Request

Use np.save to handle memory-mapped files.

Motivation

Currently the datastore is saved and loaded with np.memmap, which requires the user to specify the dimensions.

Pitch

Using np.save to memmap files removes the need to manually specify dimensions of the datastore because they will be stored in the file. Cf. https://stackoverflow.com/a/36749821.
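A minimal sketch of the idea, assuming NumPy's .npy format is acceptable for the datastore. Since np.save writes an array that is already in memory, the sketch creates the file with np.lib.format.open_memmap, which stores shape and dtype in the header and still allows incremental writes. The file name and array size are hypothetical.

```python
import os
import tempfile

import numpy as np

path = os.path.join(tempfile.mkdtemp(), "dstore_keys.npy")

# Create a .npy file backed by a memory map; shape and dtype go into the header.
keys = np.lib.format.open_memmap(path, mode="w+", dtype=np.float16, shape=(8, 4))
keys[:] = 1.0  # in practice this would be filled batch by batch
keys.flush()
del keys

# Reopen as a read-only memory map without specifying shape or dtype.
loaded = np.load(path, mmap_mode="r")
```

The reader then only needs the file path, which removes the need for a separate size argument.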

UnboundLocalError in train.py when all samples in a batch are skipped

2021-12-13 04:38:51 | WARNING | fairseq.data.data_utils | 24065 samples have invalid sizes and will be skipped, max_positions=20, first few sample ids=[24064, 22597, 4716, 10717, 22414, 10206, 11045, 1925, 15812, 10117]
Traceback (most recent call last):
  File "/home/d8xa/knnlm/train.py", line 11, in <module>
    cli_main()
  File "/home/d8xa/knnlm/fairseq_cli/train.py", line 307, in cli_main
    main(args)
  File "/home/d8xa/knnlm/fairseq_cli/train.py", line 102, in main
    train(args, trainer, task, epoch_itr)
  File "/home/d8xa/anaconda3/envs/nlp/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/d8xa/knnlm/fairseq_cli/train.py", line 194, in train
    progress.print(stats, tag='train', step=num_updates)
UnboundLocalError: local variable 'num_updates' referenced before assignment
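The traceback suggests that num_updates is only assigned inside the loop over batches, so the name is undefined when the loop body never runs. A minimal sketch of one possible guard follows; train_epoch, get_num_updates and the log list are hypothetical stand-ins, not the fairseq code.

```python
def train_epoch(batches, get_num_updates, log):
    """Run one epoch and tolerate an empty batch iterator."""
    # Assign before the loop so the name exists even if no batch is processed.
    num_updates = get_num_updates()
    for _batch in batches:
        # A real training step would go here.
        num_updates += 1
    # Logging after the loop is now safe for an empty epoch as well.
    log.append(("train", num_updates))
    return num_updates


# All samples skipped: the batch iterator is empty.
log = []
result = train_epoch([], lambda: 0, log)
```

An alternative would be to skip the final logging call entirely when no update was made.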

LMContextWindowDataset has incorrect ordered indices when using different break mode

Description

MonolingualDataset orders indices by sample length. Ties are broken randomly if shuffle is enabled, and by index otherwise:

def ordered_indices(self):
    """Return an ordered list of indices. Batches will be constructed based
    on this order."""
    if self.shuffle:
        order = [np.random.permutation(len(self))]
    else:
        order = [np.arange(len(self))]
    order.append(self.sizes)
    return np.lexsort(order)

LMContextWindowDataset ignores this order:

def ordered_indices(self):
    # NOTE we don't shuffle the data to retain access to the previous dataset elements
    return np.arange(len(self.dataset))

When using --sample-break-mode none, samples of identical size are not reordered, except for the last one if it has fewer tokens than tokens-per-sample. Consider this example:

# sample indices and sizes in TokenBlockDataset
array([     0,      1,      2, ..., 148836, 148837, 148838])
array([20, 20, 20, ..., 20, 20, 12])

# resulting ordered indices produced by MonolingualDataset:
array([148838,      0,      1, ..., 148835, 148836, 148837])

# LMContextWindowDataset then overwrites this, resulting in:
array([     0,      1,      2, ..., 148836, 148837, 148838])

But for a different sample break mode (e.g. complete) or when shuffling samples, the result is not as expected. Consider this example using --sample-break-mode complete_doc:

# Sample indices and sizes in TokenBlockDataset:
array([     0,      1,      2, ..., 137507, 137508, 137509])
array([11, 15, 13, ..., 17,  8, 13])

# resulting ordered indices produced by MonolingualDataset:
array([   22,    75,    93, ..., 93948, 60582, 98161])

# in LMContextWindowDataset:
array([     0,      1,      2, ..., 137507, 137508, 137509])
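The difference between the two orderings can be reproduced on a tiny example. This sketch only mirrors the two ordered_indices bodies quoted above with NumPy; the sizes are hypothetical.

```python
import numpy as np

# Hypothetical sample sizes: three full blocks and a shorter last block.
sizes = np.array([20, 20, 20, 12])

# MonolingualDataset-style order without shuffling:
# np.lexsort uses the last key (sizes) as primary key and the index as tie-breaker.
by_size = np.lexsort([np.arange(len(sizes)), sizes])

# LMContextWindowDataset-style order: plain positional order.
positional = np.arange(len(sizes))

differs = not np.array_equal(by_size, positional)
```

Here by_size is [3, 0, 1, 2] while positional is [0, 1, 2, 3], so any code that assumes both datasets iterate in the same order sees mismatched samples.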
