
linjieli222 commented on July 17, 2024

This looks like an index-out-of-range error. I suspect it happens in input_ids, type_ids, or position_ids, since each of them indexes an embedding table.

  1. input_ids for VCR are augmented with 81 special tokens, so directly applying the word embeddings from pretrain.py may not work: the added tokens get ids past the last row of the original word embedding table.
    Please check the VCR model to see how the special tokens are added:

    UNITER/model/vcr.py, lines 43 to 51 in 8b8e181

    ```python
    def init_word_embedding(self, num_special_tokens):
        orig_word_num = self.uniter.embeddings.word_embeddings.weight.size(0)
        new_emb = nn.Embedding(
            orig_word_num + num_special_tokens, self.uniter.config.hidden_size)
        new_emb.apply(self.init_weights)
        emb = self.uniter.embeddings.word_embeddings.weight.data
        new_emb.weight.data[:orig_word_num, :].copy_(emb)
        self.uniter.embeddings.word_embeddings = new_emb
    ```

  2. text_type_ids for VCR are set to 0 for the question, 2 for answers, and 3 for rationales; image_type_ids are set to 1.
    The pretraining model only has 2 type ids, 0 for text and 1 for image. Please check the VCR model to see how the extra type ids are added (rows 0 and 1 are copied from the pretrained table, and the new rows 2 and 3 are initialized from the text row 0):

    UNITER/model/vcr.py, lines 31 to 42 in 8b8e181

    ```python
    def init_type_embedding(self):
        new_emb = nn.Embedding(4, self.uniter.config.hidden_size)
        new_emb.apply(self.init_weights)
        for i in [0, 1]:
            emb = self.uniter.embeddings.token_type_embeddings.weight.data[i, :]
            new_emb.weight.data[i, :].copy_(emb)
        emb = self.uniter.embeddings.token_type_embeddings.weight.data[0, :]
        new_emb.weight.data[2, :].copy_(emb)
        new_emb.weight.data[3, :].copy_(emb)
        self.uniter.embeddings.token_type_embeddings = new_emb
    ```

  3. position_ids must be in the range [0, 511], i.e. within the 512 rows of the position embedding table.
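
All three cases come down to an id that is negative or at least the number of rows in the embedding table it indexes. To find which input is responsible, a minimal sketch like the one below checks each id field against its table size before the forward pass. `check_id_ranges` and `flatten_ids` are hypothetical helpers, not part of UNITER; the sketch assumes the ids have already been converted to nested Python lists (for example with `tensor.tolist()`), and the table sizes in the usage part are toy values.

```python
from typing import Dict, List


def flatten_ids(ids) -> List[int]:
    """Flatten an int or an arbitrarily nested list of ints into one list."""
    if isinstance(ids, int):
        return [ids]
    flat: List[int] = []
    for item in ids:
        flat.extend(flatten_ids(item))
    return flat


def check_id_ranges(batch: Dict[str, list], limits: Dict[str, int]) -> List[str]:
    """Report every id field whose values fall outside [0, limit).

    batch:  field name (e.g. "input_ids") -> nested lists of ints
    limits: field name -> number of rows in the embedding table it indexes
    """
    problems: List[str] = []
    for name, limit in limits.items():
        values = flatten_ids(batch[name])
        if not values:
            continue  # nothing to check for an empty field
        lo, hi = min(values), max(values)
        if lo < 0 or hi >= limit:
            problems.append(
                f"{name}: values span [{lo}, {hi}] but the embedding table "
                f"only has {limit} rows (valid ids are 0..{limit - 1})"
            )
    return problems


# Toy sizes for illustration only: a 100-token vocabulary, the 2 type ids of
# the pretraining model, and 512 positions.
batch = {
    "input_ids": [[1, 5, 103]],      # 103 stands in for an added special token
    "token_type_ids": [[0, 2, 3]],   # VCR-style type ids 2 and 3
    "position_ids": [[0, 1, 2]],
}
limits = {"input_ids": 100, "token_type_ids": 2, "position_ids": 512}
for msg in check_id_ranges(batch, limits):
    print(msg)
# input_ids: values span [1, 103] but the embedding table only has 100 rows (valid ids are 0..99)
# token_type_ids: values span [0, 3] but the embedding table only has 2 rows (valid ids are 0..1)
```

After init_word_embedding and init_type_embedding have been applied, the corresponding limits become the original vocabulary size plus 81 and 4 type ids, and the same batch should pass the check.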

You will also need to call the two functions above in your main script, similar to

UNITER/train_vcr.py, lines 166 to 199 in 8b8e181

```python
# Prepare model
if opts.checkpoint and opts.checkpoint_from == "pretrain":
    checkpoint = torch.load(opts.checkpoint)
else:
    checkpoint = {}
all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
toker = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
assert all(toker == json.load(open(f'{db}/meta.json'))['bert']
           for db in all_dbs)
model = UniterForVisualCommonsenseReasoning.from_pretrained(
    opts.model_config, checkpoint, img_dim=IMG_DIM)
model.init_type_embedding()
model.init_word_embedding(NUM_SPECIAL_TOKENS)
if opts.checkpoint_from == "vcr_pretrain":
    checkpoint = torch.load(opts.checkpoint)
    state_dict = checkpoint.get('model_state', checkpoint)
    matched_state_dict = {}
    unexpected_keys = set()
    missing_keys = set()
    for name, param in model.named_parameters():
        missing_keys.add(name)
    for key, data in state_dict.items():
        if key in missing_keys:
            matched_state_dict[key] = data
            missing_keys.remove(key)
        else:
            unexpected_keys.add(key)
    print("Unexpected_keys:", list(unexpected_keys))
    print("Missing_keys:", list(missing_keys))
    model.load_state_dict(matched_state_dict, strict=False)
del checkpoint
model.to(device)
# make sure every process has same model parameters in the beginning
```
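
In this excerpt both init functions run before a `vcr_pretrain` checkpoint is loaded. One reason the order can matter: a checkpoint saved after the embeddings were expanded would hold a 4-row type table, and PyTorch's `load_state_dict` raises on shape mismatches even with `strict=False` (that flag only relaxes missing and unexpected keys). The sketch below mirrors the excerpt's key bookkeeping in plain Python and adds a shape check. `split_state_dict_keys` is a hypothetical helper; dicts of shape tuples stand in for `model.named_parameters()` and the checkpoint, and the names `head.weight` and `pooler.weight` are made up for illustration.

```python
from typing import Dict, Set, Tuple

Shape = Tuple[int, ...]


def split_state_dict_keys(
    model_params: Dict[str, Shape], state_dict: Dict[str, Shape]
) -> Tuple[Dict[str, Shape], Set[str], Set[str], Set[str]]:
    """Split checkpoint keys the way the train_vcr.py excerpt does.

    Returns (matched, missing, unexpected, shape_mismatch):
      matched: checkpoint entries whose name is also a model parameter
      missing: model parameters that the checkpoint does not provide
      unexpected: checkpoint entries with no matching model parameter
      shape_mismatch: matched names whose shapes differ; load_state_dict
        would raise on these even with strict=False
    """
    missing = set(model_params)
    matched: Dict[str, Shape] = {}
    unexpected: Set[str] = set()
    for key, shape in state_dict.items():
        if key in missing:
            matched[key] = shape
            missing.remove(key)
        else:
            unexpected.add(key)
    shape_mismatch = {k for k, s in matched.items() if model_params[k] != s}
    return matched, missing, unexpected, shape_mismatch


# Toy shapes; only the token-type name follows the attribute path in vcr.py.
HIDDEN = 768
TYPE_EMB = "uniter.embeddings.token_type_embeddings.weight"
ckpt = {TYPE_EMB: (4, HIDDEN), "head.weight": (1, HIDDEN)}
model_before = {TYPE_EMB: (2, HIDDEN), "pooler.weight": (HIDDEN, HIDDEN)}

_, missing, unexpected, bad = split_state_dict_keys(model_before, ckpt)
print(sorted(missing))     # ['pooler.weight']
print(sorted(unexpected))  # ['head.weight']
print(sorted(bad))         # ['uniter.embeddings.token_type_embeddings.weight']

# After init_type_embedding() the table has 4 rows and the shapes agree.
model_after = {**model_before, TYPE_EMB: (4, HIDDEN)}
print(sorted(split_state_dict_keys(model_after, ckpt)[3]))  # []
```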


jaeyun95 commented on July 17, 2024

Oh, I missed the init_type_embedding and init_word_embedding functions.

Thank you for your kind response! :)
