RajK853 commented on August 19, 2024

We use the label vocabulary from GECToR, available here. Our project does not need the @@UNKNOWN@@ and @@PADDING@@ labels, which are added automatically because GECToR's implementation uses the AllenNLP library.

Weighted Sampler

The label vocabulary contains 5,000 labels in 9 categories: keep, delete, replace, append, merge, transform case, transform verb, transform agreement, and transform split. The number of labels in each category is shown in the graph below.

image

Sampling action labels uniformly at random would be inefficient for the following reasons:

  1. The probability of randomly guessing the correct label is very low because the action space is large.
  2. Some labels, such as keep, replace and append, occur more often than others.
  3. Randomly deleting tokens loses information from the original sentence.
  4. Randomly appending tokens changes the context of the original sentence.
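
To put a number on the first reason, here is a minimal sketch of the hit rate under uniform sampling. It assumes that exactly one label is correct for each token, which is a simplification, and the function name is hypothetical.

```python
NUM_LABELS = 5_000  # vocabulary size stated above


def uniform_hit_probability(num_labels: int, num_tokens: int = 1) -> float:
    """Chance that uniform sampling picks the one correct label for every token.

    Assumes a single correct label per token and independent draws.
    """
    return (1.0 / num_labels) ** num_tokens


# One token: 1 in 5,000.
print(uniform_hit_probability(NUM_LABELS))
# Three tokens in a row: the chance shrinks geometrically.
print(uniform_hit_probability(NUM_LABELS, num_tokens=3))
```

Even for a single token the chance is 0.02%, so almost every uniformly sampled action would be wrong.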

Therefore, we created a weighted sampler that generates some action labels more often than others.

DRL-GEC/src/sampler.py

Lines 5 to 48 in 2e9aff2

class WeightedSampler:
    def __init__(self, labels, weight_dict):
        self.labels = labels
        self.weight_dict = weight_dict
        self.num_labels = len(self.labels)
        self.labels_freq = self.gen_label_dist(labels)
        # Initialize label probabilities
        self.label_probs = None
        self.init_weights()

    def init_weights(self, weight_dict=None):
        if weight_dict:
            self.weight_dict = weight_dict
        # Normalize weights based on the number of labels i.e. some category like REPLACE or APPEND has multiple labels
        normalized_weights = {k: v / self.labels_freq[k] for k, v in self.weight_dict.items()}
        weights = np.ones(self.num_labels, dtype="float32")
        for i, label in enumerate(self.labels):
            for label_prefix, weight in normalized_weights.items():
                if label.startswith(label_prefix):
                    weights[i] = weight
                    break
        self.label_probs = weights / sum(weights)  # Ensure sum(label_probs) == 1.0

    @staticmethod
    def gen_label_dist(label_list):
        label_dist = defaultdict(int)
        for label in label_list:
            if label in ("$KEEP", "$DELETE"):
                label_dist[label] += 1
            elif any(label.startswith(k) for k in ("$REPLACE_", "$APPEND_", "$MERGE_")):
                label_type, *_ = label.split("_")
                label_dist[label_type] += 1
            elif label.startswith("$TRANSFORM_"):
                label_type = "_".join(label.split("_")[:2])
                label_dist[label_type] += 1
            else:
                raise NotImplementedError(f"Cannot handle {label} label!")
        return label_dist

    def sample(self, size=1):
        return np.random.choice(self.num_labels, size=size, p=self.label_probs)

    def __call__(self, size=1):
        return self.sample(size)
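
The normalization step is easier to follow on a tiny vocabulary. The sketch below is a simplified, standalone restatement of the same idea, not the repository code: the helper names and the six-label toy vocabulary are hypothetical, and unlike the class above it assumes every category has an entry in the weight dictionary.

```python
from collections import Counter

import numpy as np


def label_category(label: str) -> str:
    """Map a label to its category, e.g. "$REPLACE_the" -> "$REPLACE"."""
    parts = label.split("_")
    # $TRANSFORM labels keep their sub-type as part of the category name.
    return "_".join(parts[:2]) if parts[0] == "$TRANSFORM" else parts[0]


def label_probabilities(labels, weight_dict):
    """Split each category weight evenly over its labels, then normalize."""
    categories = [label_category(label) for label in labels]
    counts = Counter(categories)
    weights = np.array([weight_dict[c] / counts[c] for c in categories], dtype="float64")
    return weights / weights.sum()


# Hypothetical toy vocabulary: one category ($REPLACE) has three labels.
toy_labels = ["$KEEP", "$DELETE", "$REPLACE_a", "$REPLACE_the", "$REPLACE_an", "$APPEND_the"]
toy_weights = {"$KEEP": 3.0, "$DELETE": 0.3, "$REPLACE": 1.0, "$APPEND": 0.5}

probs = label_probabilities(toy_labels, toy_weights)
for label, prob in zip(toy_labels, probs):
    print(f"{label:<14}{prob:.4f}")

# Draw label indices for a 10-token sentence with a seeded generator.
rng = np.random.default_rng(0)
sampled = rng.choice(len(toy_labels), size=10, p=probs)
print([toy_labels[i] for i in sampled])
```

Because the $REPLACE weight is shared across its three labels, the category as a whole is sampled in proportion to its weight no matter how many labels it contains.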

Test

For the test, we used the following label weights in the WeightedSampler:

{
    "$KEEP": 3.0,
    "$MERGE": 0.1,
    "$DELETE": 0.3,
    "$APPEND": 0.5,
    "$REPLACE": 1.0, 
    "$TRANSFORM_SPLIT": 0.2,
    "$TRANSFORM_CASE": 1.0,
    "$TRANSFORM_VERB": 1.0,
    "$TRANSFORM_AGREEMENT": 1.0,
}
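
As a rough check on what these weights imply: each category weight is divided by the number of labels in that category, so the total probability of a category is its weight divided by the sum of all weights. The sketch below works through that arithmetic, assuming every label in the vocabulary belongs to one of the nine weighted categories; the variable names are illustrative.

```python
test_weights = {
    "$KEEP": 3.0,
    "$MERGE": 0.1,
    "$DELETE": 0.3,
    "$APPEND": 0.5,
    "$REPLACE": 1.0,
    "$TRANSFORM_SPLIT": 0.2,
    "$TRANSFORM_CASE": 1.0,
    "$TRANSFORM_VERB": 1.0,
    "$TRANSFORM_AGREEMENT": 1.0,
}

# Per-category probability mass: weight / sum of weights.
total_weight = sum(test_weights.values())
category_probs = {name: w / total_weight for name, w in test_weights.items()}

# Print the categories from most to least likely.
for name, prob in sorted(category_probs.items(), key=lambda item: -item[1]):
    print(f"{name:<22}{prob:.3f}")
```

Under that assumption, keep receives about 0.37 of the probability mass, while replace and the case, verb and agreement transforms each receive about 0.12.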

Using the above weights, we drew samples of different sizes.

image

The image above shows action labels sampled at random for a sentence with 10 tokens. Because of the high weight on the keep label, most of the sampled labels fall into this category.

image

The image above shows the likelihood of sampling each action label during training.
