
Comments (3)

hiyouga avatar hiyouga commented on August 15, 2024

Thanks for your question. We perform random shuffling on the labels to mitigate the bias introduced by the position embeddings in the BERT models. In other words, we make the label representations independent of their order. This operation changes the real label in neither binary nor multi-class classification.
DualCL performs well on the sentiment classification and question classification tasks, but its performance on other tasks is yet to be confirmed. The unsatisfactory results on the dialogue intention recognition task may be due to the complicated nature of that task.

from dual-contrastive-learning.

OPilgrim avatar OPilgrim commented on August 15, 2024

Thanks for your reply!
Well, I must point out that the multi-class case may indeed be problematic...
Here is the current implementation:

import random

from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, raw_data, label_dict, tokenizer, model_name, method):
        dataset = list()
        for data in raw_data:
            tokens = data['text'].lower().split(' ')
            label_id = label_dict[data['label']]
            dataset.append((tokens, label_id))
        self._dataset = dataset
        self._method = method
        self.sep_token = ['[SEP]'] if model_name == 'bert' else ['</s>']
        self._label_list = list(label_dict.keys())
        self._num_classes = len(label_dict)

    def __getitem__(self, index):
        tokens, label_id = self._dataset[index]
        if self._method not in ['ce', 'scl']:
            # shuffle the label tokens that get prepended to the input
            rand_idx = [i for i in range(self._num_classes)]
            random.shuffle(rand_idx)
            label_list = [self._label_list[i] for i in rand_idx]
            tokens = label_list + self.sep_token + tokens
            # remap the target label to match the shuffled order
            label_id = rand_idx[label_id]
        return tokens, label_id

    def __len__(self):
        return len(self._dataset)

Assuming we have 6 categories, self._label_list would be ["1", "2", "3", "4", "5", "6"] and rand_idx starts as [0, 1, 2, 3, 4, 5]. Suppose the shuffle produces [3, 5, 0, 1, 4, 2]; then label_list becomes ["4", "6", "1", "2", "5", "3"]. Because label_id = rand_idx[label_id], if the original label_id is 0, the corresponding label is "1", but the new label_id becomes rand_idx[0] = 3, and position 3 of the shuffled list is "2". Doesn't that make the label wrong?
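A minimal standalone check of the example above (pure Python; the `rand_idx.index` fix shown here is just one possible correction using the inverse permutation, not necessarily the repository's actual patch):

```python
label_list = ["1", "2", "3", "4", "5", "6"]
rand_idx = [3, 5, 0, 1, 4, 2]  # the example shuffle above
shuffled = [label_list[i] for i in rand_idx]  # ["4", "6", "1", "2", "5", "3"]

label_id = 0  # original label is "1"
wrong = rand_idx[label_id]        # 3 -> shuffled[3] is "2", a different label
right = rand_idx.index(label_id)  # 2 -> shuffled[2] is "1", the inverse mapping
```

So `rand_idx[label_id]` applies the permutation itself where the inverse permutation is needed to track where the original label token ended up.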


hiyouga avatar hiyouga commented on August 15, 2024

Thanks very much! You are right, it was problematic. We have removed the random shuffling and instead set the position ids of all the label tokens to zero, so the model's prediction is independent of the label order. The implementation has been updated.

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):

    def __init__(self, raw_data, label_dict, tokenizer, model_name, method):
        # prepend the label tokens in a fixed order; no shuffling anymore
        label_list = list(label_dict.keys()) if method not in ['ce', 'scl'] else []
        sep_token = ['[SEP]'] if model_name == 'bert' else ['</s>']
        dataset = list()
        for data in raw_data:
            tokens = data['text'].lower().split(' ')
            label_id = label_dict[data['label']]
            dataset.append((label_list + sep_token + tokens, label_id))
        self._dataset = dataset

    def __getitem__(self, index):
        return self._dataset[index]

    def __len__(self):
        return len(self._dataset)


def my_collate(batch, tokenizer, method, num_classes):
    tokens, label_ids = map(list, zip(*batch))
    text_ids = tokenizer(tokens,
                         padding=True,
                         truncation=True,
                         max_length=256,
                         is_split_into_words=True,
                         add_special_tokens=True,
                         return_tensors='pt')
    if method not in ['ce', 'scl']:
        # assign position id 0 to all label tokens, so their position
        # embeddings are identical and the label order carries no signal
        positions = torch.zeros_like(text_ids['input_ids'])
        positions[:, num_classes:] = torch.arange(0, text_ids['input_ids'].size(1) - num_classes)
        text_ids['position_ids'] = positions
    return text_ids, torch.tensor(label_ids)
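Per row, the positions tensor built in my_collate amounts to the following (a minimal pure-Python sketch; seq_len is a hypothetical padded length chosen for illustration):

```python
num_classes = 6
seq_len = 12  # hypothetical padded sequence length after tokenization

# Every label token shares position 0; the remaining tokens are
# numbered 0, 1, 2, ... exactly as the torch.arange assignment above does.
position_ids = [0] * num_classes + list(range(seq_len - num_classes))
```

Because all label tokens receive the same position embedding, swapping two label tokens leaves their representations unchanged, which is why the prediction no longer depends on the order in which the labels are listed.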

