Comments (22)

asigalov61 avatar asigalov61 commented on May 13, 2024 2

@mathigatti Here is working code, Mathias...
https://github.com/asigalov61/tegridy-tools/blob/main/tegridy-tools/minGPT.py

Also, do not forget to save the state_dict, because torch seems to need it:
torch.save(model.state_dict(), full_path_to_save_model_to)

And if you have checkpoints w/o it, you can fix it by loading the model with torch.load(path) and then re-saving it with the above torch.save code. This worked for me.
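A minimal sketch of that re-save step, assuming the old checkpoint was written with torch.save(model, path) and that the installed PyTorch accepts the weights_only argument of torch.load (1.13 or newer). nn.Linear stands in for the minGPT model, and the file names are hypothetical:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained network; a real run would use the minGPT model.
model = nn.Linear(4, 2)

workdir = tempfile.mkdtemp()
full_model_path = os.path.join(workdir, "old_full_model.pt")  # hypothetical name
state_dict_path = os.path.join(workdir, "new_state_dict.pt")  # hypothetical name

# The old-style checkpoint: the whole module object is pickled.
torch.save(model, full_model_path)

# Load the pickled module. weights_only=False is needed for full-model
# pickles on recent PyTorch; only use it on files you trust.
loaded = torch.load(full_model_path, map_location="cpu", weights_only=False)

# Re-save just the parameters.
torch.save(loaded.state_dict(), state_dict_path)

# A fresh model of the same architecture can now restore the weights.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
```

The re-saved file no longer depends on the model class being importable at load time, only on the architecture matching.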

Btw, please check out my profile and repos. If you are into AI and Music, I think you will enjoy it :)

from mingpt.

asigalov61 avatar asigalov61 commented on May 13, 2024 2

@mathigatti Sorry, here is the right colab link. The first one is the python version... https://colab.research.google.com/drive/1JeY5Lr41wnM2eL0WvaCL29bxGamPHDJb?usp=sharing

And thank you for complimenting my GitHub repos. It means a lot to me :)

from mingpt.

asigalov61 avatar asigalov61 commented on May 13, 2024 1

@mathigatti Glad to be of help! I am happy to know you figured it out :)

Yes, check out my repos as all of them have samples + models and datasets as well.
My Efficient/Intelligent Virtuoso is the best performing work of mine so definitely see it here: https://github.com/asigalov61/Intelligent-VIRTUOSO/tree/main/Samples

If you want more samples, I have a Soundcloud too:
https://soundcloud.com/aleksandr-sigalov-61/

AITEXTGEN did not show any good results as it is not a very good implementation of GPT2. In fact, minGPT is the best performing implementation for Music AI that I have used and tried. So definitely consider trying minGPT with music. And if you want to try the latest stuff with NLP and text you can use my TMIDI module and any of the available SOTA implementations like Reformer or Performer or Routing Transformer.

My TMIDI module is basically an advanced version of midi2text, abc2text, and similar code.
https://github.com/asigalov61/tegridy-tools/blob/main/tegridy-tools/TMIDI.py

Hope this is useful :)

from mingpt.

glenn-jocher avatar glenn-jocher commented on May 13, 2024

minGPT/mingpt/trainer.py

Lines 53 to 58 in 5433de6

def save_checkpoint(self):
    if self.config.ckpt_path is not None:
        ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(ckpt_model.state_dict(), self.config.ckpt_path)
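The .module check in that snippet matters because nn.DataParallel wraps the real model and prefixes every state_dict key with "module.". A small sketch of the effect, using nn.Linear as a stand-in for the minGPT model (the variable names are illustrative only):

```python
import torch.nn as nn

plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(plain)  # constructing the wrapper also works without a GPU

plain_keys = list(plain.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())
print(plain_keys)
print(wrapped_keys)

# Same unwrapping idiom as in save_checkpoint: save the inner model so the
# checkpoint keys do not carry the "module." prefix.
raw = wrapped.module if hasattr(wrapped, "module") else wrapped
raw_keys = list(raw.state_dict().keys())
```

Saving the unwrapped model keeps the checkpoint loadable whether or not the model is wrapped again later.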

from mingpt.

aletote avatar aletote commented on May 13, 2024

Yes, I saw that, but I would still like a simple example script.py, like the Jupyter Notebook example, that includes the loading and saving of the last checkpoint or, like with Char-RNN, the one I point to. I'm not the best at understanding PyTorch code, sorry.

from mingpt.

glenn-jocher avatar glenn-jocher commented on May 13, 2024

@aletote it's used in the training loop for saving checkpoints. You can see this usage in trainer.py:

minGPT/mingpt/trainer.py

Lines 123 to 129 in 8909e1b

for epoch in range(config.max_epochs):
    run_epoch('train')
    if self.test_dataset is not None:
        run_epoch('test')
    self.save_checkpoint()

from mingpt.

aletote avatar aletote commented on May 13, 2024

Yes, I saw that, but I would still like a simple example script.py, like the Jupyter Notebook example, that includes the loading and saving of the last checkpoint or, like with Char-RNN, the one I point to. I'm not the best at understanding PyTorch code, sorry.

https://github.com/davidbau/how-to-read-pytorch

https://github.com/evantahler/Dont-be-a-Jerk

from mingpt.

fabiovila avatar fabiovila commented on May 13, 2024

Set ckpt_path in trainer.py:

ckpt_path = '/content/drive/MyDrive/ck/ck'

Create the function load_checkpoint inside the Trainer class:

    def load_checkpoint(self):
        if self.config.ckpt_path is not None:
            ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
            logger.info("loading %s", self.config.ckpt_path)
            ck = torch.load(self.config.ckpt_path)
            ckpt_model.load_state_dict(ck)

and finally call it before the training loop:

        self.load_checkpoint()
        for epoch in range(config.max_epochs):
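For anyone who wants this outside the Trainer class, here is a hedged stand-alone sketch of the same idea. It additionally skips loading when the file does not exist yet (so the first run starts from scratch) and maps tensors to CPU on load; both are additions of this sketch, not part of the snippet above. nn.Linear stands in for the minGPT model and the path is hypothetical:

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(model, ckpt_path):
    """Load a state_dict checkpoint into model; return True if one was loaded."""
    if ckpt_path is None or not os.path.exists(ckpt_path):
        return False  # nothing to resume from, keep the fresh weights
    # Unwrap DataParallel, mirroring save_checkpoint.
    raw_model = model.module if hasattr(model, "module") else model
    state = torch.load(ckpt_path, map_location="cpu")
    raw_model.load_state_dict(state)
    return True


ckpt_path = os.path.join(tempfile.mkdtemp(), "ck.pt")  # hypothetical path
trained = nn.Linear(4, 2)  # stand-in for a trained model
fresh = nn.Linear(4, 2)    # stand-in for a newly constructed model

found_before = load_checkpoint(fresh, ckpt_path)  # no file yet
torch.save(trained.state_dict(), ckpt_path)
found_after = load_checkpoint(fresh, ckpt_path)   # file exists now
```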

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

I'm trying it in Google Colab and it doesn't seem to be loading the model successfully. Is anyone experiencing a similar problem?

After training the model a little bit

from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=10000, batch_size=124, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(train_dataset)*block_size,
                      ckpt_path='test.pt',
                      num_workers=4)

trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

If I restart the kernel and I try loading the model again like this

checkpoint = torch.load('test.pt')
model.load_state_dict(checkpoint)

Then the predictions are completely random and the loss is as high as it was at the beginning of the training.
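One thing worth ruling out with these symptoms (an assumption, not something confirmed in this thread): if the character vocabulary is rebuilt after the restart from an unordered set, for example list(set(text)), the character-to-index mapping can differ between Python processes because string hashing is randomized per process. The weights would then load correctly but be used with a different mapping. A minimal sketch of building the mapping in a stable order and storing it next to the checkpoint; the helper names and file name are hypothetical:

```python
import json
import os
import tempfile


def build_vocab(text):
    """Map each character to an index in a stable (sorted) order."""
    chars = sorted(set(text))
    return {ch: i for i, ch in enumerate(chars)}


def save_vocab(stoi, path):
    """Write the character-to-index mapping as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False)


def load_vocab(path):
    """Read the mapping back; keys are characters, values are ints."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


text = "to be or not to be"
stoi = build_vocab(text)

vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.json")  # hypothetical name
save_vocab(stoi, vocab_path)
restored_stoi = load_vocab(vocab_path)
```

Reloading the stored mapping after a kernel restart, instead of rebuilding it, guarantees the indices match the ones the model was trained with.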

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

Amazing! Thanks, I will test it. Checking it quickly, though, I don't see a difference between your save_checkpoint implementation and the one from this repository. Am I missing something?

By the way, awesome GitHub projects :)

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

Oh okay, I saw you used a deepcopy while loading the checkpoint. Maybe that helps, I will try that.

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

:'( It's still failing to load the model successfully, I created this simple colab to replicate my problem.

from mingpt.

asigalov61 avatar asigalov61 commented on May 13, 2024

@mathigatti Apologies for the delayed response...

I think that maybe you are not loading the checkpoint properly?

The trick here is to match the saved checkpoint to the right loading code. In other words, if you saved the state_dict, you need to load it with load_state_dict; in my experience, torch.load on its own is not going to work here.

For example,

To save the model/checkpoint, I use:

torch.save(model, full_path_to_save_model_to)

OR

torch.save(model.state_dict(), full_path_to_save_model_to)

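A saved state_dict has to be loaded into an already constructed model with model.load_state_dict(torch.load(path)). If you saved the whole model (the first option), you can load it with the code below: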

model = torch.load(full_path_to_model_checkpoint)
model.eval()
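To make the difference between the two save styles concrete, here is a minimal sketch that pairs each one with its matching loader. nn.Linear stands in for the minGPT model, the file names are hypothetical, and weights_only=False assumes PyTorch 1.13 or newer:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the minGPT model
workdir = tempfile.mkdtemp()

# Option 1: pickle the whole module, then torch.load returns a module.
full_path = os.path.join(workdir, "full_model.pt")  # hypothetical name
torch.save(model, full_path)
model_a = torch.load(full_path, map_location="cpu", weights_only=False)
model_a.eval()

# Option 2: save only the parameters. torch.load now returns a dict of
# tensors, so the model must be rebuilt first and filled via load_state_dict.
sd_path = os.path.join(workdir, "state_dict.pt")  # hypothetical name
torch.save(model.state_dict(), sd_path)
state = torch.load(sd_path, map_location="cpu")
model_b = nn.Linear(4, 2)
model_b.load_state_dict(state)
model_b.eval()
```

Calling .eval() directly on the result of torch.load only works for option 1; with option 2 the loaded object is a plain dictionary.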

Also, from my experience with minGPT, you need to initialize all training functions prior to loading the saved checkpoint. I am not sure why it is so, but this is how it works on my end.

Here is the example colab for you to check out with my stand-alone minGPT module that works fine AFAIK.
https://github.com/asigalov61/tegridy-tools/blob/main/Examples/efficient_virtuoso.py

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

It's weird, I don't see any difference with this colab: it saves the model while training like this, and it loads the model as you suggested.

from mingpt.

asigalov61 avatar asigalov61 commented on May 13, 2024

Yeah, if you are saving the state_dict, it will not load using torch.load alone with minGPT, in my experience. You need to load it with load_state_dict, as the torch docs suggest.

Also, make sure you are passing the checkpoint to all functions and also, make sure you are loading the model correctly otherwise it will not work.

For example, in my stand-alone module, I use two options to load the checkpoint because a checkpoint save will have a state_dict. Therefore, you need to load it through the main training functions, and only then can you use it. If you look at my Google Colab, you will see that there is a place for it in the main loader cell. Otherwise, you do not need a state_dict and you can use the standard torch.load and torch.save functions.

I am sorry that I use my own version of minGPT to give examples but I do not use the official release myself anymore so I can only give examples with what I use and what works for me. So try my stand-alone module in colab. Maybe it will work for you better as opposed to using the official version. Just a suggestion here...

And as I said before, DO NOT FORGET to properly load the checkpoint or saved model as it will not work otherwise.

As far as I can tell from your colab, you are trying to load a state_dict checkpoint with a regular loading function, and this will not work. The model output in your colab also appears garbled, so either you have not trained the model properly or you have not loaded it properly.

Hope this is helpful and feel free to ask more questions as I would love to help :)

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

Thank you very much! I tried your colab with my shakespeare txt and it loaded the model successfully :D

By the way have you uploaded the midi composition results somewhere? I tried a pretty similar thing with aitextgen package and these scripts: text2midi & midi2text

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

Here is the colab modified a little bit for regular text generation

from mingpt.

Marcus-Arcadius avatar Marcus-Arcadius commented on May 13, 2024

Here is the colab modified a little bit for regular text generation

Hey mate, how would I switch the encoding to English into utf8?

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

utf8 already supports english, are you experiencing issues if you just try to open the file with utf8?
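For reference, a minimal sketch of reading a training text file with an explicit UTF-8 encoding in Python. The file name and sample text are made up for illustration:

```python
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "input.txt")  # hypothetical file

# Write a small sample that mixes plain English with non-ASCII characters.
with open(path, "w", encoding="utf-8") as f:
    f.write("To be, or not to be: café ♪")

# Passing encoding explicitly avoids depending on the platform default.
with open(path, "r", encoding="utf-8") as f:
    text = f.read()
```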

from mingpt.

Marcus-Arcadius avatar Marcus-Arcadius commented on May 13, 2024

utf8 already supports english, are you experiencing issues if you just try to open the file with utf8?

Thank you, it works perfectly now! But it doesn't produce checkpoints, only the model path?
Is there a way to continue training from a checkpoint? It either doesn't find one or completely ignores the checkpoint path that I specified.

from mingpt.

mathigatti avatar mathigatti commented on May 13, 2024

Not sure, the colab I shared used to allow training, saving the checkpoint and loading it again. I didn't try training again from that

from mingpt.

Marcus-Arcadius avatar Marcus-Arcadius commented on May 13, 2024

Not sure, the colab I shared used to allow training, saving the checkpoint and loading it again. I didn't try training again from that

Oh I see, is there any way that I can do this though?
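One common way to continue training in plain PyTorch is to store the optimizer state alongside the model weights and restore both before resuming. The sketch below shows that pattern on a toy model; it is not taken from the colab or from the minGPT Trainer, and the checkpoint keys ("model", "optimizer", "epoch") and the path are choices made for this example:

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
ckpt_path = os.path.join(tempfile.mkdtemp(), "resume.pt")  # hypothetical path


def train_steps(model, optimizer, n_steps):
    """Run a few toy regression steps on random data."""
    for _ in range(n_steps):
        x = torch.randn(8, 4)
        loss = (model(x) - x.sum(dim=1, keepdim=True)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


# First session: train a little, then write a checkpoint.
model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
train_steps(model, optimizer, 5)
torch.save({"model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": 1}, ckpt_path)

# Second session: rebuild the objects, restore their state, keep training.
model2 = nn.Linear(4, 1)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-2)
ckpt = torch.load(ckpt_path, map_location="cpu")
model2.load_state_dict(ckpt["model"])
optimizer2.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
weights_at_resume = model2.weight.detach().clone()
train_steps(model2, optimizer2, 5)
```

Restoring the optimizer state keeps its moment estimates, so training continues from where it stopped rather than restarting the optimizer from scratch.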

from mingpt.
