
Comments (3)

ningding97 commented on June 4, 2024

Hi, thank you for sharing the data and code!

I noticed that input words are not being correctly tokenized by the word tokenizer:

In word_tokenizer.py, each word is converted directly to a token id:

for raw_tokens in raw_tokens_list:
     indexed_tokens = self.tokenizer.convert_tokens_to_ids(raw_tokens)

However, a word may consist of several word pieces and should be tokenized first:

for raw_tokens in raw_tokens_list:
     for word in raw_tokens:
            word_tokens = self.tokenizer.tokenize(word)

Converting whole words directly to token ids produces lots of [UNK] tokens and makes the performance drop a lot.
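
To make the failure mode concrete, here is a minimal sketch (my own illustration, assuming the HuggingFace transformers BERT tokenizer; the example word is arbitrary):

    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    word = 'tokenization'  # not in BERT's WordPiece vocabulary as a whole word
    # Direct id lookup treats the whole word as one vocabulary entry,
    # so any out-of-vocabulary word collapses to [UNK].
    print(tokenizer.convert_tokens_to_ids([word]))  # [100], the [UNK] id
    # Tokenizing first splits the word into in-vocabulary word pieces.
    print(tokenizer.tokenize(word))                 # ['token', '##ization']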

Thanks, we will fix it soon.

chandan047 commented on June 4, 2024

Hi. I have implemented the tokenization function. The performance of the prototypical network in the (inter) 5 way 5~10 shot setting jumped from 52.42 to 60.09; similarly, (inter) 5 way 1~2 shot improved from 37.49 to 44.45. These are single runs only, but I think the tokenization is worth changing.

Here is my implementation:

import numpy as np  # required for the mask arrays

def tokenize(self, raw_tokens, tags):
    # BERT's uncased vocabulary expects lower-cased input.
    raw_tokens = [token.lower() for token in raw_tokens]
    indexed_tokens_list = []
    tag_list = []
    mask_list = []
    text_mask_list = []

    curr_split = ['[CLS]']
    tag_split = []
    mask_split = np.zeros(self.max_length, dtype=np.int32)
    text_mask_split = np.zeros(self.max_length, dtype=np.int32)

    for word, tag in zip(raw_tokens, tags):
        # Split each word into word pieces instead of looking it up whole.
        tokens = self.tokenizer.tokenize(word)

        # Flush the current split if adding this word would overflow
        # max_length (reserving one slot for the trailing [SEP]).
        if len(curr_split) + len(tokens) >= self.max_length:
            indexed_tokens = self.tokenizer.convert_tokens_to_ids(curr_split + ['[SEP]'])
            mask_split[:len(indexed_tokens)] = 1  # mark real tokens before padding
            while len(indexed_tokens) < self.max_length:
                indexed_tokens.append(0)  # pad with [PAD] (id 0)

            indexed_tokens_list.append(indexed_tokens)
            tag_list.append(tag_split)
            mask_list.append(mask_split)
            text_mask_list.append(text_mask_split)

            curr_split = ['[CLS]']
            tag_split = []
            mask_split = np.zeros(self.max_length, dtype=np.int32)
            text_mask_split = np.zeros(self.max_length, dtype=np.int32)

        # Mark the first word piece of each word; the tags align with
        # these marked positions.
        text_mask_split[len(curr_split)] = 1
        curr_split.extend(tokens)
        tag_split.append(tag)

    # Flush the final split.
    if tag_split:
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(curr_split + ['[SEP]'])
        mask_split[:len(indexed_tokens)] = 1  # mark real tokens before padding
        while len(indexed_tokens) < self.max_length:
            indexed_tokens.append(0)

        indexed_tokens_list.append(indexed_tokens)
        tag_list.append(tag_split)
        mask_list.append(mask_split)
        text_mask_list.append(text_mask_split)

    return indexed_tokens_list, mask_list, text_mask_list, tag_list
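
For reference, here is a hypothetical harness for exercising the function above; the WordTokenizer wrapper, the checkpoint name, and the toy sentence and tags are my own assumptions for illustration, not code from the repo:

    from transformers import BertTokenizer

    class WordTokenizer:
        # Hypothetical wrapper; only tokenize(), self.tokenizer, and
        # self.max_length appear in the snippet above.
        def __init__(self, max_length=32):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.max_length = max_length

    WordTokenizer.tokenize = tokenize  # attach the function defined above

    raw_tokens = ['The', 'European', 'Space', 'Agency', 'launched', 'Hubble']
    tags = ['O', 'organization', 'organization', 'organization', 'O', 'product']

    ids, masks, text_masks, tag_splits = WordTokenizer().tokenize(raw_tokens, tags)
    # One entry per split; text_masks[k] marks the first word piece of each
    # word, so text_masks[k].sum() == len(tag_splits[k]) for every split k.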

ningding97 commented on June 4, 2024

Thanks, this is very helpful; we will update the results soon.
