
Comments (7)

mingxingtan commented on May 18, 2024

Aha, I see. I just submitted a change, c470de8, that lets you exclude some variables from being restored from the checkpoint.

For example, I have verified this command line works:

python main.py --training_file_pattern=/coco_tfrecord/train-00000-of-00256.tfrecord \
    --val_json_file=/coco_tfrecord/annotations/instances_val2017.json \
    --model_name=efficientdet-d0 \
    --model_dir=/tmp/test/ \
    --ckpt=/ckpt/efficientdet/efficientdet-d0 \
    --hparams="use_bfloat16=false,num_classes=10,var_exclude_expr=r'.*/class-predict/.*'" \
    --use_tpu=False
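Conceptually, var_exclude_expr is a regular expression tested against variable names, and any variable that matches is left out of the restore so it keeps its fresh initialization. A minimal sketch of that filtering step, assuming re.match semantics (anchored at the start of the name); the variable names are hypothetical and not read from a real checkpoint, and filter_restorable is an illustrative helper, not part of automl:

```python
import re

# Hypothetical variable names; a real EfficientDet checkpoint uses its own naming.
VARIABLE_NAMES = [
    "efficientnet-b0/stem/conv2d/kernel",
    "resample_p6/conv2d/kernel",
    "class_net/class-0/depthwise_kernel",
    "class_net/class-predict/pointwise_kernel",
    "class_net/class-predict/bias",
    "box_net/box-predict/pointwise_kernel",
]


def filter_restorable(names, exclude_expr):
    """Return the names to restore, dropping any that match exclude_expr.

    Assumes re.match semantics, i.e. the pattern is matched from the start
    of each name. An empty expression excludes nothing.
    """
    if not exclude_expr:
        return list(names)
    matcher = re.compile(exclude_expr)
    return [name for name in names if not matcher.match(name)]


if __name__ == "__main__":
    for name in filter_restorable(VARIABLE_NAMES, r".*/class-predict/.*"):
        print(name)
```

With the pattern from the command above, only the final classification layer (whose shape depends on num_classes) is skipped; the backbone, the feature network, the box head and the intermediate class-head layers are still restored.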

from automl.

magi-toneu commented on May 18, 2024

I am also interested in fine-tuning from an EfficientDet checkpoint, but with a different number of classes.

ancorasir commented on May 18, 2024

I have the same question. I'd like to fine-tune on my own dataset, which has a different number of classes. Instead of starting from random initialization, starting from the pretrained EfficientDet might be a better choice.

b03505036 commented on May 18, 2024

Watching the issue, too.

mingxingtan commented on May 18, 2024

@Ely-S helped add fine-tuning support in #49; you should be able to fine-tune by passing --ckpt=xxx instead of --backbone_ckpt=yy.

mingxingtan commented on May 18, 2024

A simple example of fine-tuning a model on a single shard of the COCO training data:

MODEL=efficientdet-d0
python main.py --training_file_pattern=/coco_tfrecord/train-00000-of-00256.tfrecord \
    --val_json_file=/coco_tfrecord/annotations/instances_val2017.json \
    --model_name=$MODEL \
    --model_dir=/tmp/test/ \
    --ckpt=/ckpt/efficientdet/efficientdet-d0 \
    --hparams="use_bfloat16=false" --use_tpu=False

This loads the weights from the COCO-pretrained checkpoint /ckpt/efficientdet/efficientdet-d0 (skipping the momentum variables) and then starts training from that initialization.
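Skipping the momentum variables amounts to separating optimizer slot entries from model weights before restoring. A sketch, assuming TF1-style Momentum slot names of the form <variable>/Momentum; the checkpoint names are hypothetical and split_weights_and_slots is an illustrative helper, not the automl implementation:

```python
# Hypothetical checkpoint entries. The "/Momentum" suffix is an assumption based on
# TF1 MomentumOptimizer slot naming, not read from an actual EfficientDet checkpoint.
CHECKPOINT_NAMES = [
    "efficientnet-b0/stem/conv2d/kernel",
    "efficientnet-b0/stem/conv2d/kernel/Momentum",
    "box_net/box-predict/bias",
    "box_net/box-predict/bias/Momentum",
]


def split_weights_and_slots(names, slot_suffix="/Momentum"):
    """Partition checkpoint names into (model_weights, optimizer_slots)."""
    weights = [n for n in names if not n.endswith(slot_suffix)]
    slots = [n for n in names if n.endswith(slot_suffix)]
    return weights, slots


weights, slots = split_weights_and_slots(CHECKPOINT_NAMES)
print("restore:", weights)
print("skip:", slots)
```

Because the momentum buffers are not restored, the optimizer starts with freshly initialized (zero) momentum while the weights start from the pretrained values.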

TomHeaven commented on May 18, 2024

> @Ely-S helped add fine-tuning support in #49; you should be able to fine-tune by passing --ckpt=xxx instead of --backbone_ckpt=yy.

The problem is that when the number of classes differs from that of the pretrained checkpoint (90 for the COCO-pretrained EfficientDet), there is currently no way to load the checkpoint successfully.

For example, when I train with:

MODEL=efficientdet-d0
python3 ../main.py --mode=train_and_eval --training_file_pattern=../../../data/tfrecord_train/train* \
        --validation_file_pattern=../../../data/tfrecord_val/val*  \
        --val_json_file=../../../data/train/val.json \
        --model_name=$MODEL \
        --backbone_ckpt=../../pretrained_weights/$MODEL \
        --model_dir=../../weights/$MODEL \
        --train_batch_size=8  \
        --num_examples_per_epoch=5543 \
        --num_epochs=12 \
        --hparams="use_bfloat16=false,num_classes=5,learning_rate=0.0025" --use_tpu=False

I get the following error:

  File "../main.py", line 396, in <module>
    tf.app.run(main)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/usr/local/lib/python3.6/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "../main.py", line 364, in main
    steps=int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3035, in train
    rendezvous.raise_errors()
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 143, in raise_errors
    six.reraise(typ, value, traceback)
  File "/usr/local/lib/python3.6/site-packages/six.py", line 703, in reraise
    raise value
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3030, in train
    saving_listeners=saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 374, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1164, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1198, in _train_model_default
    saving_listeners)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1493, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 604, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1038, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 749, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1231, in __init__
    _WrappedSession.__init__(self, self._create_session())
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1236, in _create_session
    return self._sess_creator.create_session()
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 902, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 669, in create_session
    init_fn=self._scaffold.init_fn)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/session_manager.py", line 294, in prepare_session
    config=config)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/session_manager.py", line 224, in _restore_checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1326, in restore
    err, "a mismatch between the current graph and the graph")
tensorflow.python.framework.errors_impl.InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [1,1,64,45] rhs shape= [1,1,64,810]
         [[node save/Assign_371 (defined at /usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py:1493) ]]
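The mismatched dimension is the output channel count of the class-prediction layer, which is num_anchors × num_classes. Assuming the default of 9 anchors per location (3 scales × 3 aspect ratios) and 64 input channels for D0, the two shapes in the error line up exactly with the 90 COCO classes and the num_classes=5 override; class_predict_kernel_shape is a hypothetical helper for illustration:

```python
# Assumed anchor configuration: 3 scales x 3 aspect ratios = 9 anchors per location.
NUM_SCALES = 3
NUM_ASPECT_RATIOS = 3
IN_CHANNELS_D0 = 64  # third dimension of both shapes in the error message


def class_predict_kernel_shape(num_classes,
                               num_scales=NUM_SCALES,
                               num_aspect_ratios=NUM_ASPECT_RATIOS,
                               in_channels=IN_CHANNELS_D0):
    """Shape of a 1x1 class-prediction kernel: [1, 1, in_channels, anchors * classes]."""
    num_anchors = num_scales * num_aspect_ratios
    return [1, 1, in_channels, num_anchors * num_classes]


print(class_predict_kernel_shape(5))   # [1, 1, 64, 45]  (lhs: the new 5-class model)
print(class_predict_kernel_shape(90))  # [1, 1, 64, 810] (rhs: the COCO checkpoint)
```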

Changing --backbone_ckpt to --ckpt results in a similar error. I know some weight shapes do not match because the number of classes changed, but I cannot find any documentation on how to fix this.
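One general workaround, independent of any particular automl flag, is to restore only the variables whose checkpoint shape matches the current model's shape and leave the rest at their initial values. In TensorFlow, tf.train.list_variables(ckpt) returns (name, shape) pairs for a checkpoint; the sketch below takes plain dictionaries instead so it runs without TensorFlow. The helper restorable_variables and the variable names are hypothetical; only the two class-predict shapes are copied from the error above:

```python
def restorable_variables(ckpt_shapes, model_shapes):
    """Split model variables into (restore, skip) lists based on checkpoint shapes.

    ckpt_shapes / model_shapes: dict mapping variable name -> shape (list or tuple).
    A variable is restored only if the checkpoint has it with an identical shape.
    """
    restore, skip = [], []
    for name, shape in model_shapes.items():
        ckpt_shape = ckpt_shapes.get(name)
        if ckpt_shape is not None and list(ckpt_shape) == list(shape):
            restore.append(name)
        else:
            skip.append(name)
    return restore, skip


# Hypothetical names; the class-predict shapes come from the error message above.
CKPT_SHAPES = {
    "box_net/box-predict/pointwise_kernel": [1, 1, 64, 36],
    "class_net/class-predict/pointwise_kernel": [1, 1, 64, 810],
}
MODEL_SHAPES = {
    "box_net/box-predict/pointwise_kernel": [1, 1, 64, 36],
    "class_net/class-predict/pointwise_kernel": [1, 1, 64, 45],
}

restore, skip = restorable_variables(CKPT_SHAPES, MODEL_SHAPES)
print("restore:", restore)
print("skip:", skip)
```

Skipping mismatches silently can also hide genuine naming mistakes, so it is worth logging the skipped list before training.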
