Aha, I see. I just submitted change c470de8, which lets you exclude some variables when restoring from a checkpoint.
For example, I have verified that this command works:
python main.py --training_file_pattern=/coco_tfrecord/train-00000-of-00256.tfrecord \
--val_json_file=/coco_tfrecord/annotations/instances_val2017.json \
--model_name=efficientdet-d0 \
--model_dir=/tmp/test/ \
--ckpt=/ckpt/efficientdet/efficientdet-d0 \
--hparams="use_bfloat16=false,num_classes=10,var_exclude_expr=r'.*/class-predict/.*'" \
--use_tpu=False
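The `var_exclude_expr` hparam takes a regular expression. As a rough sketch of how such a filter behaves, assuming the expression is compiled as a Python regex and tested against each variable name with `re.match` (an assumption, not taken from the repository's code), the snippet below splits a list of variable names into those that would be restored and those left at their fresh initialization. The helper `split_by_exclude_expr` and all variable names are hypothetical.

```python
import re


def split_by_exclude_expr(var_names, exclude_expr):
    """Split names into (restored, excluded) using re.match (assumed semantics)."""
    pattern = re.compile(exclude_expr)
    restored, excluded = [], []
    for name in var_names:
        # A name matching the expression is skipped when loading the checkpoint.
        (excluded if pattern.match(name) else restored).append(name)
    return restored, excluded


# Hypothetical variable names, for illustration only.
names = [
    "efficientnet-b0/stem/conv2d/kernel",
    "class_net/class-predict/pointwise_kernel",
    "class_net/class-predict/bias",
    "box_net/box-predict/bias",
]
restored, excluded = split_by_exclude_expr(names, r".*/class-predict/.*")
```

With `.*/class-predict/.*`, only the class prediction head is excluded, so the backbone, feature network and box head can still be loaded even when `num_classes` changes.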
I am also interested in fine-tuning from an EfficientDet checkpoint, but with a different number of classes.
I have the same question. I'd like to fine-tune on my own dataset, which has a different number of classes. Rather than training from random initialization, starting from the pretrained EfficientDet weights might be a better choice.
Watching the issue, too.
@Ely-S helped add this fine-tuning support in #49. You should be able to fine-tune by passing "--ckpt=xxx" instead of "--backbone_ckpt=yy".
A simple example of fine-tuning a model on a single shard of the COCO training data:
MODEL=efficientdet-d0
python main.py --training_file_pattern=/coco_tfrecord/train-00000-of-00256.tfrecord \
--val_json_file=/coco_tfrecord/annotations/instances_val2017.json \
--model_name=efficientdet-d0 \
--model_dir=/tmp/test/ \
--ckpt=/ckpt/efficientdet/efficientdet-d0 \
--hparams="use_bfloat16=false" --use_tpu=False
It loads the weights from the COCO-pretrained checkpoint /ckpt/efficientdet/efficientdet-d0 (skipping the momentum variables) and then starts training from this initialization.
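Skipping the momentum variables means the optimizer's slot variables restart from zero while the model weights come from the checkpoint. A minimal sketch, assuming TF1-style momentum slots are stored under names ending in `/Momentum`, with a hypothetical helper `build_init_map` and hypothetical variable names:

```python
def build_init_map(ckpt_names, model_names):
    """Map checkpoint names to model names, skipping momentum slot variables.

    Assumes slot variables follow the TF1 naming "<variable>/Momentum".
    Names missing from the model (e.g. global_step here) are also skipped.
    """
    model_set = set(model_names)
    return {
        name: name
        for name in ckpt_names
        if not name.endswith("/Momentum") and name in model_set
    }


# Hypothetical names, for illustration only.
ckpt_names = [
    "box_net/box-predict/bias",
    "box_net/box-predict/bias/Momentum",
    "global_step",
]
model_names = ["box_net/box-predict/bias", "box_net/box-predict/bias/Momentum"]
init_map = build_init_map(ckpt_names, model_names)
```

A dictionary of this shape is the kind of `assignment_map` that `tf.compat.v1.train.init_from_checkpoint` accepts; whether the repository builds its map exactly this way is not shown here.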
The problem is that when the number of classes differs from the checkpoint's (90, the COCO default), there is currently no way to load the checkpoint successfully.
For example, when I train with:
MODEL=efficientdet-d0
python3 ../main.py --mode=train_and_eval --training_file_pattern=../../../data/tfrecord_train/train* \
--validation_file_pattern=../../../data/tfrecord_val/val* \
--val_json_file=../../../data/train/val.json \
--model_name=$MODEL \
--backbone_ckpt=../../pretrained_weights/$MODEL \
--model_dir=../../weights/$MODEL \
--train_batch_size=8 \
--num_examples_per_epoch=5543 \
--num_epochs=12 \
--hparams="use_bfloat16=false,num_classes=5,learning_rate=0.0025" --use_tpu=False
I get the following error:
File "../main.py", line 396, in <module>
tf.app.run(main)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/platform/app.py", line 40, in run
_run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
File "/usr/local/lib/python3.6/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/usr/local/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "../main.py", line 364, in main
steps=int(FLAGS.num_examples_per_epoch / FLAGS.train_batch_size))
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3035, in train
rendezvous.raise_errors()
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/error_handling.py", line 143, in raise_errors
six.reraise(typ, value, traceback)
File "/usr/local/lib/python3.6/site-packages/six.py", line 703, in reraise
raise value
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/tpu/tpu_estimator.py", line 3030, in train
saving_listeners=saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 374, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1164, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1198, in _train_model_default
saving_listeners)
File "/usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1493, in _train_with_estimator_spec
log_step_count_steps=log_step_count_steps) as mon_sess:
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 604, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1038, in __init__
stop_grace_period_secs=stop_grace_period_secs)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 749, in __init__
self._sess = _RecoverableSession(self._coordinated_creator)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1231, in __init__
_WrappedSession.__init__(self, self._create_session())
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 1236, in _create_session
return self._sess_creator.create_session()
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 902, in create_session
self.tf_sess = self._session_creator.create_session()
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/monitored_session.py", line 669, in create_session
init_fn=self._scaffold.init_fn)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/session_manager.py", line 294, in prepare_session
config=config)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/session_manager.py", line 224, in _restore_checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1326, in restore
err, "a mismatch between the current graph and the graph")
tensorflow.python.framework.errors_impl.InvalidArgumentError: Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Assign requires shapes of both tensors to match. lhs shape= [1,1,64,45] rhs shape= [1,1,64,810]
[[node save/Assign_371 (defined at /usr/local/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py:1493) ]]
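The two shapes in the error follow from the size of the class prediction head. The `[1,1,64,...]` prefix reads as a 1×1 kernel over 64 input channels. Assuming EfficientDet's default of 3 anchor scales × 3 aspect ratios (9 anchors per location), the last dimension is `num_anchors * num_classes`: 9 × 90 = 810 for the COCO checkpoint and 9 × 5 = 45 for `num_classes=5`. The helper below is hypothetical and only checks that arithmetic.

```python
def class_predict_shape(num_classes, in_channels=64, num_scales=3, num_aspect_ratios=3):
    """Return the assumed [1, 1, in_channels, out_channels] class-head kernel shape."""
    num_anchors = num_scales * num_aspect_ratios  # assumed 3 x 3 = 9 anchors
    return [1, 1, in_channels, num_anchors * num_classes]


checkpoint_shape = class_predict_shape(90)  # COCO-pretrained head (rhs in the error)
finetune_shape = class_predict_shape(5)     # num_classes=5 from the command (lhs)
```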
In addition, changing --backbone_ckpt to --ckpt results in a similar error. I know some weight shapes do not match because the number of classes changed, but I cannot find any documentation on how to fix the problem.
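One way to work around this, sketched here as an assumption rather than as what the repository does, is to restore only the variables whose names and shapes agree between the checkpoint and the new model, leaving the mismatched class head freshly initialized (the `var_exclude_expr` hparam achieves a similar effect by name). In TensorFlow, `tf.train.list_variables` returns `(name, shape)` pairs for a checkpoint; the dictionaries below stand in for that output and for the model's own variables. The helper `split_by_shape` and all variable names are hypothetical, and the box-head shape assumes 9 anchors × 4 box coordinates.

```python
def split_by_shape(ckpt_shapes, model_shapes):
    """Return (restorable, skipped) variable names.

    A variable is restorable only if it exists in the checkpoint with
    exactly the same shape as in the model.
    """
    restorable, skipped = [], []
    for name, shape in model_shapes.items():
        if name in ckpt_shapes and list(ckpt_shapes[name]) == list(shape):
            restorable.append(name)
        else:
            skipped.append(name)
    return restorable, skipped


# Hypothetical names; class-head shapes taken from the error message.
ckpt_shapes = {
    "class_net/class-predict/pointwise_kernel": [1, 1, 64, 810],
    "box_net/box-predict/pointwise_kernel": [1, 1, 64, 36],  # assumed 9 x 4
}
model_shapes = {
    "class_net/class-predict/pointwise_kernel": [1, 1, 64, 45],
    "box_net/box-predict/pointwise_kernel": [1, 1, 64, 36],
}
restorable, skipped = split_by_shape(ckpt_shapes, model_shapes)
```

Filtering by shape catches every head whose size depends on `num_classes` without having to know its name in advance, at the cost of silently skipping any variable that mismatches for some other reason.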