Comments (22)
@mathigatti Here is working code, Mathias...
https://github.com/asigalov61/tegridy-tools/blob/main/tegridy-tools/minGPT.py
Also, do not forget to save the state_dict, because torch seems to need it...
torch.save(model.state_dict(), full_path_to_save_model_to)
And if you have checkpoints w/o it, you can fix it by loading the model with torch.load(path) and then re-saving it with the above torch.save code. This worked for me.
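A minimal sketch of the re-saving step described above, assuming the old checkpoint was written with torch.save(model, path), i.e. the whole pickled module. The nn.Sequential model and the file names are hypothetical stand-ins; with a real minGPT checkpoint the model class has to be importable when torch.load runs. The weights_only=False argument (available in recent PyTorch releases) is needed to unpickle a full module and should only be used on files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


def resave_as_state_dict(old_path, new_path):
    """Load a fully pickled model and re-save only its state_dict."""
    # weights_only=False unpickles arbitrary objects: trusted files only.
    loaded = torch.load(old_path, map_location="cpu", weights_only=False)
    torch.save(loaded.state_dict(), new_path)
    return loaded


# Hypothetical stand-in for a trained model saved "the old way".
model = nn.Sequential(nn.Linear(4, 4))
tmp_dir = tempfile.mkdtemp()
old_path = os.path.join(tmp_dir, "old_checkpoint.pt")
new_path = os.path.join(tmp_dir, "state_dict.pt")
torch.save(model, old_path)

resave_as_state_dict(old_path, new_path)

# The new file holds a plain mapping from parameter names to tensors.
state = torch.load(new_path, map_location="cpu")
print(sorted(state.keys()))
```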
Btw, please check out my profile and repos. If you are into AI and Music, I think you will enjoy it :)
from mingpt.
@mathigatti Sorry, here is the right colab link. The first one is the python version... https://colab.research.google.com/drive/1JeY5Lr41wnM2eL0WvaCL29bxGamPHDJb?usp=sharing
And thank you for complimenting my GitHub repos. It means a lot to me :)
from mingpt.
@mathigatti Glad to be of help! I am happy to know you figured it out :)
Yes, check out my repos as all of them have samples + models and datasets as well.
My Efficient/Intelligent Virtuoso is the best performing work of mine so definitely see it here: https://github.com/asigalov61/Intelligent-VIRTUOSO/tree/main/Samples
If you want more samples, I have a Soundcloud too:
https://soundcloud.com/aleksandr-sigalov-61/
AITEXTGEN did not show any good results as it is not a very good implementation of GPT2. In fact, minGPT is the best performing implementation for Music AI that I have used and tried. So definitely consider trying minGPT with music. And if you want to try the latest stuff with NLP and text you can use my TMIDI module and any of the available SOTA implementations like Reformer or Performer or Routing Transformer.
My TMIDI module is basically an advanced version of midi2text, abc2text, and similar code.
https://github.com/asigalov61/tegridy-tools/blob/main/tegridy-tools/TMIDI.py
Hope this is useful :)
from mingpt.
Lines 53 to 58 in 5433de6
from mingpt.
Yes, I saw that but still I would like a simple example script.py like the Jupyter Notebook example that includes the loading and saving of the last checkpoint or, like with Char-RNN, the one I point to. I'm not the best at understanding Pytorch code, sorry.
from mingpt.
@aletote it's used in the training loop for saving checkpoints. You can see this usage in trainer.py:
Lines 123 to 129 in 8909e1b
from mingpt.
Set the path in trainer.py:
    ckpt_path = '/content/drive/MyDrive/ck/ck'
Create the function load_checkpoint inside class Trainer:
    def load_checkpoint(self):
        if self.config.ckpt_path is not None:
            ckpt_model = self.model.module if hasattr(self.model, "module") else self.model
            logger.info("loading %s", self.config.ckpt_path)
            ck = torch.load(self.config.ckpt_path)
            ckpt_model.load_state_dict(ck)
and finally call it right before the epoch loop:
    self.load_checkpoint()
    for epoch in range(config.max_epochs):
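The same idea as a stand-alone sketch that can be run outside the Trainer class. The helper names (unwrap, save_checkpoint, load_checkpoint) and the nn.Linear model are hypothetical; the point is only that the state_dict is saved from, and loaded into, the underlying module, so a model wrapped in nn.DataParallel and a plain model can share one checkpoint file.

```python
import os
import tempfile

import torch
import torch.nn as nn


def unwrap(model):
    """Return the underlying module if the model is wrapped (e.g. DataParallel)."""
    return model.module if hasattr(model, "module") else model


def save_checkpoint(model, path):
    # Save only the parameters, without any "module." prefix on the keys.
    torch.save(unwrap(model).state_dict(), path)


def load_checkpoint(model, path):
    """Load weights if the file exists; report whether anything was loaded."""
    if path is None or not os.path.exists(path):
        return False
    state = torch.load(path, map_location="cpu")
    unwrap(model).load_state_dict(state)
    return True


ckpt_path = os.path.join(tempfile.mkdtemp(), "ck.pt")

# Save from a wrapped model, then load into a fresh, unwrapped one.
wrapped = nn.DataParallel(nn.Linear(3, 2))
save_checkpoint(wrapped, ckpt_path)

fresh = nn.Linear(3, 2)
loaded = load_checkpoint(fresh, ckpt_path)
missing = load_checkpoint(fresh, os.path.join(tempfile.mkdtemp(), "none.pt"))
print(loaded, missing)
```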
from mingpt.
I'm trying it in Google Colab, and it doesn't seem to be loading the model successfully. Is anyone experiencing a similar problem?
After training the model a little bit
from mingpt.trainer import Trainer, TrainerConfig
# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=10000, batch_size=124, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20, final_tokens=200*len(train_dataset)*block_size,
                      ckpt_path='test.pt',
                      num_workers=4)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()
If I restart the kernel and I try loading the model again like this
checkpoint = torch.load('test.pt')
model.load_state_dict(checkpoint)
Then the predictions are completely random and the loss is as high as it was at the beginning of the training.
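One way to narrow this down is to check whether the weights in the file really ended up in the model. The sketch below is only a diagnostic under assumptions: checkpoint_matches is a hypothetical helper, and nn.Linear stands in for the GPT model, which after a kernel restart has to be rebuilt with the same configuration before load_state_dict is called.

```python
import os
import tempfile

import torch
import torch.nn as nn


def checkpoint_matches(model, path):
    """True if every tensor stored in the file equals the model's current tensor."""
    saved = torch.load(path, map_location="cpu")
    current = model.state_dict()
    if saved.keys() != current.keys():
        return False
    return all(torch.equal(saved[k], current[k].cpu()) for k in saved)


path = os.path.join(tempfile.mkdtemp(), "test.pt")

# Stand-in for the trained model.
torch.manual_seed(0)
trained = nn.Linear(8, 8)
torch.save(trained.state_dict(), path)

# Stand-in for the model rebuilt after a restart (different random init).
torch.manual_seed(1)
fresh = nn.Linear(8, 8)
before = checkpoint_matches(fresh, path)

fresh.load_state_dict(torch.load(path, map_location="cpu"))
fresh.eval()  # switch off dropout before generating predictions
after = checkpoint_matches(fresh, path)
print(before, after)
```

If the check passes after loading and the predictions are still random, the weights are probably not the problem and the data pipeline around the model is worth a look.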
from mingpt.
Amazing! Thanks, I will test it. From a quick look, though, I don't see a difference between your save_checkpoint implementation and the one from this repository. Am I missing something?
By the way awesome github projects :)
from mingpt.
oh okay, I saw you used a deepcopy while loading the checkpoint, maybe that helps, I will try that.
from mingpt.
:'( It's still failing to load the model successfully, I created this simple colab to replicate my problem.
from mingpt.
@mathigatti Apologies for the delayed response...
I think that maybe you are not loading the checkpoint properly?
The trick here is to match the loading code to the way the checkpoint was saved. In other words, if you saved the state_dict, torch.load on its own is not going to work, in my experience, because it returns the state_dict rather than a model.
For example,
To save the model/checkpoint, I use:
    torch.save(model, full_path_to_save_model_to)
OR
    torch.save(model.state_dict(), full_path_to_save_model_to)
If you saved the whole model (the first option), you can load it with the code below:
    model = torch.load(full_path_to_model_checkpoint)
    model.eval()
If you saved the state_dict (the second option), build the model first and then call model.load_state_dict(torch.load(full_path_to_model_checkpoint)).
Also, from my experience with minGPT, you need to initialize all training functions prior to loading the saved checkpoint. I am not sure why it is so, but this is how it works on my end.
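To make the "match the loader to the saver" point concrete, here is a small sketch that accepts either kind of file. load_model and build_model are hypothetical names, and nn.Sequential stands in for the real network. The weights_only=False argument (recent PyTorch releases) is what allows a fully pickled module to be loaded, so use it only on checkpoints you created yourself.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_model(path, build_model):
    """Load either a fully pickled model or a state_dict saved at `path`."""
    obj = torch.load(path, map_location="cpu", weights_only=False)
    if isinstance(obj, nn.Module):
        model = obj  # saved with torch.save(model, path)
    else:
        model = build_model()  # saved with torch.save(model.state_dict(), path)
        model.load_state_dict(obj)
    model.eval()
    return model


def build_model():
    # Must recreate exactly the architecture that was trained.
    return nn.Sequential(nn.Linear(4, 2))


tmp_dir = tempfile.mkdtemp()
full_path = os.path.join(tmp_dir, "full_model.pt")
dict_path = os.path.join(tmp_dir, "state_dict.pt")

trained = build_model()
torch.save(trained, full_path)
torch.save(trained.state_dict(), dict_path)

from_full = load_model(full_path, build_model)
from_dict = load_model(dict_path, build_model)

x = torch.ones(1, 4)
with torch.no_grad():
    same = torch.equal(from_full(x), from_dict(x))
print(same)
```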
Here is the example colab for you to check out, with my stand-alone minGPT module that works fine AFAIK.
https://github.com/asigalov61/tegridy-tools/blob/main/Examples/efficient_virtuoso.py
from mingpt.
It's weird, I don't see any difference with this colab, it saves the model while training like this. And it loads the model as you suggested.
from mingpt.
Yeah, if you are saving the state_dict, torch.load alone will not restore the model with minGPT, in my experience. You need to load it with load_state_dict, as the torch docs suggest.
Also, make sure you are passing the checkpoint to all functions and that you are loading the model correctly; otherwise it will not work.
For example, in my stand-alone module, I use two options to load the checkpoint, because a saved checkpoint will hold a state_dict. Therefore, you need to load it through the main training functions, and only then can you use it. If you look at my Google Colab, you will see that there is a place for it in the main loader cell. Otherwise, you do not need a state_dict and you can use the standard torch.load and torch.save functions.
I am sorry that I use my own version of minGPT to give examples but I do not use the official release myself anymore so I can only give examples with what I use and what works for me. So try my stand-alone module in colab. Maybe it will work for you better as opposed to using the official version. Just a suggestion here...
And as I said before, DO NOT FORGET to properly load the checkpoint or saved model as it will not work otherwise.
As far as I can tell from your colab, you are trying to load a state_dict checkpoint with a regular loading function, which will not work. The model output in your colab also appears garbled, so either you have not trained the model properly or you have not loaded it properly.
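Garbled output after a restart can also come from outside the model, so one thing worth ruling out is the character vocabulary. This is an assumption, nothing in this thread confirms it: if the character-to-index table is built by iterating over a Python set of characters, its order can change between interpreter sessions (string hashing is randomized per process by default), and correctly loaded weights would then be decoded with the wrong table. The sketch below, with hypothetical helper names, builds the table in sorted order and stores it next to the checkpoint.

```python
import json
import os
import tempfile


def build_vocab(text):
    """Build a deterministic char<->index mapping (sorted, not raw set order)."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


def save_vocab(stoi, path):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stoi, f, ensure_ascii=False)


def load_vocab(path):
    with open(path, encoding="utf-8") as f:
        stoi = json.load(f)
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos


text = "hello world"
stoi, itos = build_vocab(text)

vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.json")
save_vocab(stoi, vocab_path)
stoi2, itos2 = load_vocab(vocab_path)

encoded = [stoi[ch] for ch in text]
decoded = "".join(itos2[i] for i in encoded)
print(decoded)
```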
Hope this is helpful and feel free to ask more questions as I would love to help :)
from mingpt.
Thank you very much! I tried your colab with my shakespeare txt and it loaded the model successfully :D
By the way have you uploaded the midi composition results somewhere? I tried a pretty similar thing with aitextgen package and these scripts: text2midi & midi2text
from mingpt.
Here is the colab modified a little bit for regular text generation
from mingpt.
Here is the colab modified a little bit for regular text generation
Hey mate, how would I switch the encoding to utf8 for English text?
from mingpt.
utf8 already supports english, are you experiencing issues if you just try to open the file with utf8?
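For reference, opening the training file with an explicit encoding looks roughly like this. The file name and sample text are hypothetical; the encoding argument to open() is standard Python.

```python
import os
import tempfile


def read_text(path):
    """Read a training corpus as UTF-8, independent of the platform default."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# Hypothetical sample file: plain English plus a few non-ASCII characters.
sample = "To be, or not to be: café, naïve, 生成\n"
input_path = os.path.join(tempfile.mkdtemp(), "input.txt")
with open(input_path, "w", encoding="utf-8") as f:
    f.write(sample)

text = read_text(input_path)
print(len(text), "characters,", len(set(text)), "unique")
```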
from mingpt.
utf8 already supports english, are you experiencing issues if you just try to open the file with utf8?
Thank you, it works perfectly now! But it doesn't produce checkpoints, only the model path?
Is there a way to retrain a checkpoint because it doesn't find one or completely ignores the checkpoint path that I specified?
from mingpt.
Not sure, the colab I shared used to allow training, saving the checkpoint and loading it again. I didn't try training again from that
from mingpt.
Not sure, the colab I shared used to allow training, saving the checkpoint and loading it again. I didn't try training again from that
Oh I see, is there any way that I can do this tho ?
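A generic PyTorch sketch of what "continuing training from a checkpoint" usually involves, not tied to the colab above. Saving the optimizer state and the epoch number alongside the weights is an addition here, not something the trainer in this thread is shown doing; the helper names, the nn.Linear model, and the random data are all hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_training_state(path, model, optimizer, epoch):
    """Store weights, optimizer state, and the last finished epoch in one file."""
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch}, path)


def load_training_state(path, model, optimizer):
    """Restore weights and optimizer state; return the epoch to resume from."""
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1


def train_one_epoch(model, optimizer, x, y):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
x, y = torch.randn(16, 4), torch.randn(16, 1)
path = os.path.join(tempfile.mkdtemp(), "resume.pt")

# First session: train for three epochs and checkpoint after the last one.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
for epoch in range(3):
    train_one_epoch(model, optimizer, x, y)
save_training_state(path, model, optimizer, epoch)

# Second session: rebuild the objects, restore, and keep going.
resumed = nn.Linear(4, 1)
resumed_opt = torch.optim.SGD(resumed.parameters(), lr=0.1, momentum=0.9)
start_epoch = load_training_state(path, resumed, resumed_opt)
weights_restored = torch.equal(resumed.weight, model.weight)
for epoch in range(start_epoch, 5):
    train_one_epoch(resumed, resumed_opt, x, y)
print(start_epoch, weights_restored)
```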
from mingpt.