
Comments (14)

ArthurZucker commented on September 12, 2024

cc @ylacombe as well
Also feel free to open a PR for a fix!


blademoon commented on September 12, 2024

@ArthurZucker I've isolated the problem, and everything works now. But I still have a question: how do I properly instantiate an untrained model?

At first I did it as follows:

model_name = "openai/whisper-tiny"
....
from transformers import AutoConfig, AutoModel

configuration = AutoConfig.from_pretrained(model_name)
model = AutoModel.from_config(configuration)

This is what was causing the problem.

What is the correct way to do this for Whisper so the model can be trained from scratch?


blademoon commented on September 12, 2024

@ArthurZucker I don't know how to fix it.


blademoon commented on September 12, 2024

I have studied the code carefully; apparently the difference is in a single function. The code from the publication still runs, but the code for fine-tuning the model on two languages does not.

Here is the only function that is different:

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # get the language of our text
    tokenizer.set_prefix_tokens(language=batch["language"], task="transcribe") 
   
    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch
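
For context, this function is applied over the whole dataset with the datasets library's map (a rough sketch below; the common_voice DatasetDict and its column layout are assumptions carried over from the fine-tuning guide):

from datasets import Audio

# decode the audio column at 16 kHz so it matches the feature extractor
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# run prepare_dataset over every example and drop the raw columns
common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice.column_names["train"],
)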

Honestly, I don't really see how adding prefix tokens could break training.


ylacombe commented on September 12, 2024

Hey @blademoon, you can get inspiration from the code snippet provided in the Whisper documentation:

from transformers import WhisperConfig, WhisperModel

# Initializing a Whisper tiny style configuration
configuration = WhisperConfig()

# Initializing a model (with random weights) from the tiny style configuration
model = WhisperModel(configuration)

# Accessing the model configuration
configuration = model.config

In your own case, you can first load the config from the repository id and then instantiate your model from the config:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration(configuration)
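
If you want to double-check that the weights really are randomly initialized, a minimal sanity check could look like this (the conv1 layer is just one example parameter, and torch is assumed to be available):

import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration

model_name = "openai/whisper-tiny"

# freshly initialized model vs. the pretrained checkpoint
random_model = WhisperForConditionalGeneration(WhisperConfig.from_pretrained(model_name))
pretrained_model = WhisperForConditionalGeneration.from_pretrained(model_name)

# the encoder's first conv weights should differ if the initialization is truly random
print(torch.allclose(
    random_model.model.encoder.conv1.weight,
    pretrained_model.model.encoder.conv1.weight,
))  # expected: False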

Hope it helps!

cc @eustlb for visibility


blademoon commented on September 12, 2024

@ylacombe Your solution works, but another problem arises.

If we instantiate the model as you suggested:

from transformers import WhisperConfig, WhisperForConditionalGeneration

model_name = "openai/whisper-tiny"

configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)

then configure the model as in the notebook:

model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

and start training, an exception occurs:


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[36], line 1
----> 1 trainer.train()

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:1929, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1926 try:
   1927     # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
   1928     hf_hub_utils.disable_progress_bars()
-> 1929     return inner_training_loop(
   1930         args=args,
   1931         resume_from_checkpoint=resume_from_checkpoint,
   1932         trial=trial,
   1933         ignore_keys_for_eval=ignore_keys_for_eval,
   1934     )
   1935 finally:
   1936     hf_hub_utils.enable_progress_bars()

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2356, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2353     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
   2354     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2356     self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2357 else:
   2358     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2804, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2802 metrics = None
   2803 if self.control.should_evaluate:
-> 2804     metrics = self._evaluate(trial, ignore_keys_for_eval)
   2806 if self.control.should_save:
   2807     self._save_checkpoint(model, trial, metrics=metrics)

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:2761, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2760 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2761     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2762     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2764     # Run delayed LR scheduler now that metrics are populated

File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:180, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
    178 self.gather_function = self.accelerator.gather
    179 self._gen_kwargs = gen_kwargs
--> 180 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3666, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3663 start_time = time.time()
   3665 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3666 output = eval_loop(
   3667     eval_dataloader,
   3668     description="Evaluation",
   3669     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3670     # self.args.prediction_loss_only
   3671     prediction_loss_only=True if self.compute_metrics is None else None,
   3672     ignore_keys=ignore_keys,
   3673     metric_key_prefix=metric_key_prefix,
   3674 )
   3676 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3677 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File ~/.local/lib/python3.10/site-packages/transformers/trainer.py:3857, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3854         batch_size = observed_batch_size
   3856 # Prediction step
-> 3857 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3858 main_input_name = getattr(self.model, "main_input_name", "input_ids")
   3859 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

File ~/.local/lib/python3.10/site-packages/transformers/trainer_seq2seq.py:310, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
    302 if (
    303     "labels" in generation_inputs
    304     and "decoder_input_ids" in generation_inputs
    305     and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
    306 ):
    307     generation_inputs = {
    308         k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
    309     }
--> 310 generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
    312 # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
    313 # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
    314 # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
    315 if self.model.generation_config._from_model_config:

File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:542, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, prompt_condition_type, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
    536 self._set_prompt_condition_type(
    537     generation_config=generation_config,
    538     prompt_condition_type=prompt_condition_type,
    539 )
    541 # pass self.config for backward compatibility
--> 542 init_tokens = self._retrieve_init_tokens(
    543     input_features,
    544     batch_size=batch_size,
    545     generation_config=generation_config,
    546     config=self.config,
    547     num_segment_frames=num_segment_frames,
    548     kwargs=kwargs,
    549 )
    550 # passing `decoder_input_ids` is deprecated - the only exception is for assisted generation
    551 # where the input ids are handled explicitly by the generate method
    552 self._check_decoder_input_ids(kwargs=kwargs)

File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357, in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
   1355 if task is not None:
   1356     if task in TASK_IDS:
-> 1357         init_tokens[i].append(generation_config.task_to_id[generation_config.task])
   1358         task_id = generation_config.task_to_id[generation_config.task]
   1360         # if task is defined it'll overwrite task ids that might have already been defined via the generation_config

AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'


blademoon commented on September 12, 2024

@ylacombe If I comment out model.generation_config.forced_decoder_ids = None:

from transformers import WhisperConfig, WhisperForConditionalGeneration

configuration = WhisperConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)

model.generation_config.task = "transcribe"
# model.generation_config.forced_decoder_ids = None

I get a warning followed by the same exception while training:

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[16], line 1
----> 1 trainer.train()

(... traceback identical to the one above, ending in ...)

File ~/.local/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1357, in WhisperGenerationMixin._retrieve_init_tokens(self, input_features, batch_size, generation_config, config, num_segment_frames, kwargs)
   1355 if task is not None:
   1356     if task in TASK_IDS:
-> 1357         init_tokens[i].append(generation_config.task_to_id[generation_config.task])
   1358         task_id = generation_config.task_to_id[generation_config.task]
   1360         # if task is defined it'll overwrite task ids that might have already been defined via the generation_config

AttributeError: 'GenerationConfig' object has no attribute 'task_to_id'


ylacombe commented on September 12, 2024

Hey @blademoon,
It's true that, by default, the generation config created for a transformers model doesn't match the Whisper generation config. You might want to start from the pretrained Whisper generation config instead:

from transformers import WhisperConfig, WhisperForConditionalGeneration
from transformers.generation.configuration_utils import GenerationConfig


model_name = "openai/whisper-tiny"

configuration = WhisperConfig.from_pretrained(model_name)
generation_config = GenerationConfig.from_pretrained(model_name)
model = WhisperForConditionalGeneration(configuration)
model.generation_config = generation_config

You might also want to look at the tiny model's generation config to see whether it fits your needs, and modify some parameters in the code above if necessary.
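
As a quick sanity check, the attributes that generate relies on should now be present on the attached config (the id values below are indicative):

# these were missing from the default GenerationConfig and caused the AttributeError
print(model.generation_config.task_to_id)  # e.g. {"transcribe": 50359, "translate": 50358}
print(model.generation_config.lang_to_id)  # e.g. {"<|en|>": 50259, ...}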


blademoon commented on September 12, 2024

@ylacombe Hello. I'm testing the variant you suggested.

A small question: if the model with random weights is trained on two languages at once, will the proposed variant also work correctly?


ylacombe commented on September 12, 2024

Most probably, but you might want to guide the model a bit more by using language tokens.
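
For example, at generation time the language and task tokens can be forced explicitly (a sketch; input_features stands for a batch of log-Mel features and processor for a WhisperProcessor, both assumed to exist):

# force the language and task prefix tokens when decoding
generated_ids = model.generate(input_features, language="russian", task="transcribe")
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)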


blademoon commented on September 12, 2024

@ylacombe I am using this guide from @sanchit-gandhi.

Now I get the same warning, but training works.



blademoon commented on September 12, 2024

@ylacombe It may be easier to understand if you can see the big picture.


ylacombe commented on September 12, 2024

If the training does what you want (i.e. transcribing Russian and English, I guess?), then you can ignore the warning!
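
If you would rather silence it, one option (an assumption, not something verified in this thread) is to clear the legacy forced_decoder_ids now that the task is set on a proper generation config:

# drop the legacy forced_decoder_ids so they no longer conflict with task/language
model.generation_config.forced_decoder_ids = None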


blademoon commented on September 12, 2024

@ylacombe OK. I'll check it out and come back with feedback. Thank you.

