Comments (14)
cc @ylacombe as well
Also feel free to open a PR for a fix!
@ArthurZucker I've isolated the problem. Everything works now. But I still have a question: how do I properly instantiate an untrained model?
At first I did it as follows:
model_name = "openai/whisper-tiny"
....
from transformers import AutoConfig, AutoModel
configuration = AutoConfig.from_pretrained(model_name)
model = AutoModel.from_config(configuration)
This is what was causing the problem.
What is the correct way to do this for Whisper when training the model from scratch?
@ArthurZucker I don't know how to fix it, unfortunately.
I have carefully studied the code; apparently the only difference is in one function. The original code still runs, but the version adapted for fine-tuning on two languages does not.
Here is the only function that is different:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from the input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # set the prefix tokens according to the language of this example
    tokenizer.set_prefix_tokens(language=batch["language"], task="transcribe")

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch
Honestly, I don't really see how adding prefix tokens could break training...
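Incidentally, a quick way to see what set_prefix_tokens actually does (a minimal standalone sketch, not part of the training code) is to encode the same sentence for different languages and look at the first few label ids:

from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
for lang in ["english", "russian"]:
    tokenizer.set_prefix_tokens(language=lang, task="transcribe")
    ids = tokenizer("hello world").input_ids
    # prefix: <|startoftranscript|>, the language token, <|transcribe|>, <|notimestamps|>
    print(lang, tokenizer.convert_ids_to_tokens(ids[:4]))

Only the language token in the label prefix changes between examples; the rest of the encoding stays the same.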
Hey @blademoon, you can get inspiration from the code snippet provided in the Whisper documentation:
from transformers import WhisperConfig, WhisperModel
# Initializing a Whisper tiny style configuration
configuration = WhisperConfig()
# Initializing a model (with random weights) from the tiny style configuration
model = WhisperModel(configuration)
# Accessing the model configuration
configuration = model.config
In your own case, you can first load the config from the repository id and then instantiate your model from the config:
from transformers import WhisperConfig, WhisperForConditionalGeneration
configuration = WhisperConfig.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration(configuration)
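If you want to double-check that this really gives you randomly initialized weights rather than the pretrained ones (a hypothetical sanity check, not from the documentation snippet), you can compare one tensor against the checkpoint:

import torch
pretrained = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
# `model` is the randomly initialized one from the snippet above
print(torch.allclose(model.model.encoder.conv1.weight,
                     pretrained.model.encoder.conv1.weight))  # expected: False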
Hope it helps!
cc @eustlb for visibility
@ylacombe Your solution works, but another problem arises.
If we instantiate the model as you suggested:
from transformers import WhisperConfig, WhisperForConditionalGeneration
model_name = "openai/whisper-tiny"
configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
then configure the model as in the notebook:
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
and start training, the following exception occurs:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[36], line 1
----> 1 trainer.train()
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1926 try:
1927 # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
1928 hf_hub_utils.disable_progress_bars()
-> 1929 return inner_training_loop(
1930 args=args,
1931 resume_from_checkpoint=resume_from_checkpoint,
1932 trial=trial,
1933 ignore_keys_for_eval=ignore_keys_for_eval,
1934 )
1935 finally:
1936 hf_hub_utils.enable_progress_bars()
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2356, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2353 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2354 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2356 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2357 else:
2358 self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2804, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2802 metrics = None
2803 if self.control.should_evaluate:
-> 2804 metrics = self._evaluate(trial, ignore_keys_for_eval)
2806 if self.control.should_save:
2807 self._save_checkpoint(model, trial, metrics=metrics)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2761, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
2760 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2761 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2762 self._report_to_hp_search(trial, self.state.global_step, metrics)
2764 # Run delayed LR scheduler now that metrics are populated
File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:180, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
178 self.gather_function = self.accelerator.gather
179 self._gen_kwargs = gen_kwargs
--> 180 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3666, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
3663 start_time = time.time()
3665 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3666 output = eval_loop(
3667 eval_dataloader,
3668 description="Evaluation",
3669 # No point gathering the predictions if there are no metrics, otherwise we defer to
3670 # self.args.prediction_loss_only
3671 prediction_loss_only=True if self.compute_metrics is None else None,
3672 ignore_keys=ignore_keys,
3673 metric_key_prefix=metric_key_prefix,
3674 )
3676 total_batch_size = self.args.eval_batch_size * self.args.world_size
3677 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3857, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
3854 batch_size = observed_batch_size
3856 # Prediction step
-> 3857 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
3858 main_input_name = getattr(self.model, "main_input_name", "input_ids")
3859 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:310, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
302 if (
303 "labels" in generation_inputs
304 and "decoder_input_ids" in generation_inputs
305 and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
306 ):
307 generation_inputs = {
308 k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
309 }
--> 310 generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
312 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
313 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
314 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
315 if self.model.generation_config._from_model_config:
File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:542, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
536 self._set_prompt_condition_type(
537 generation_config=generation_config,
538 prompt_condition_type=prompt_condition_type,
539 )
541 # pass self.config for backward compatibility
--> 542 init_tokens = self._retrieve_init_tokens(
543 input_features,
544 batch_size=batch_size,
545 generation_config=generation_config,
546 config=self.config,
547 num_segment_frames=num_segment_frames,
548 kwargs=kwargs,
549 )
550 # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
551 # where the input ids are handled explicitly by the generate method
552 self._check_decoder_input_ids(kwargs=kwargs)
File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357, in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
1355 if task is not None:
1356 if task in TASK_IDS:
-> 1357 init_tokens[i].append(generation_config.task_to_id[generation_config.task])
1358 task_id = generation_config.task_to_id[generation_config.task]
1360 # if task is defined it'll overwrite task ids that might have already been defined via the generation_config
AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'
@ylacombe If I comment out model.generation_config.forced_decoder_ids = None:
from transformers import WhisperConfig, WhisperForConditionalGeneration
configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config.task = "transcribe"
# model.generation_config.forced_decoder_ids = None
I get another exception while training:
You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The traceback is identical to the one above and again ends with:
AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'
Hey @blademoon,
It's true that, by default, the generation config of a freshly initialized transformers model doesn't match the Whisper generation config. I think you might want to start from the pretrained Whisper generation config:
from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers.generation.configuration_utils import GenerationConfig
model_name = "openai/whisper-tiny"
configuration = WhisperConfig.from_pretrained(model_name)
generation_config = GenerationConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config = generation_config
What you might want to do as well is look at the tiny model's generation config to see whether it fits your needs, and modify some parameters in the code above if necessary!
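For example (a hypothetical illustration of the kind of entries worth checking, not a prescribed setup), the multilingual checkpoints carry the Whisper-specific mappings the error above complained about, and you can override anything that does not suit your training:

# Whisper-specific entries that a plain GenerationConfig() would not have
print(model.generation_config.task_to_id)          # {"transcribe": ..., "translate": ...}
print(list(model.generation_config.lang_to_id)[:5])

# adjust to your setup if needed
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None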
@ylacombe Hello. I'm testing the variant you suggested.
A small question: if the model with random weights is trained on two languages at once, will the proposed variant still work correctly?
Most probably, but you might want to guide the model a bit more by using language tokens
@ylacombe I'm using this guide from @sanchit-gandhi.
Now I get the same warning, but training works.
@ylacombe It may be easier to understand if you can see the big picture.
If the training does what you want (i.e. transcribing Russian and English, I guess?), then you can ignore the warning!
@ylacombe OK. I'll check it out and come back with feedback. Thank you.