Hi! Great work. I'm having trouble loading the state_dict for the model; details follow:
Traceback (most recent call last):
  File "/code/prompt/RSPrompter-cky/./tools/predict.py", line 49, in <module>
    main()
  File "/code/prompt/RSPrompter-cky/./tools/predict.py", line 45, in main
    runner.run(args.status, ckpt_path=args.ckpt_path)
  File "/code/prompt/RSPrompter-cky/tools/../mmpl/engine/runner/pl_runner.py", line 323, in run
    return trainer_func(model=self.model, datamodule=self.datamodule, *args, **kwargs)
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 805, in predict
    return call._call_and_handle_interrupt(
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 847, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 901, in _run
    self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 396, in _restore_modules_and_callbacks
    self.restore_model()
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 278, in restore_model
    trainer.strategy.load_model_state_dict(self._loaded_checkpoint)
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 351, in load_model_state_dict
    self.lightning_module.load_state_dict(checkpoint["state_dict"])
  File "/opt/conda/envs/RSPrompter/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SegSAMAnchorPLer:
	size mismatch for panoptic_head.roi_head.mask_head.point_emb.8.weight: copying a param with shape torch.Size([2560, 256]) from checkpoint, the shape in current model is torch.Size([2048, 256]).
	size mismatch for panoptic_head.roi_head.mask_head.point_emb.8.bias: copying a param with shape torch.Size([2560]) from checkpoint, the shape in current model is torch.Size([2048]).
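To double-check the mismatch, I dumped the point_emb shapes from the checkpoint with a small snippet (the path is just where I saved the file locally; the fallback to the raw dict is in case the file is not a Lightning checkpoint):

import torch

# Load the checkpoint on CPU and print the shapes of the point_emb layers.
ckpt = torch.load('NWPU_anchor.pth', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # Lightning wraps weights under 'state_dict'
for name, tensor in state_dict.items():
    if 'point_emb' in name:
        print(name, tuple(tensor.shape))

point_emb.8.weight indeed comes out as (2560, 256), matching the error above, so the checkpoint really was saved with a wider final layer than my current model builds.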
The config file is predict_rsprompter_anchor_nwpu.py, and the mask_head config is:
mask_head=dict(
    type='SAMPromptMaskHead',
    per_query_point=prompt_shape[1],
    with_sincos=True,
    class_agnostic=True,
    loss_mask=dict(
        type='mmdet.CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
prompt_shape is set to [60, 4], and the .pth file is NWPU_anchor.pth, downloaded from Hugging Face. How can I fix this?
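My current guess, assuming SAMPromptMaskHead sizes the last point_emb layer as per_query_point * 256 * 2 when with_sincos=True (I have not verified this in the source, so please correct me if it is wrong), is that the checkpoint was trained with per_query_point = 5 while my config builds it with 4:

# Assumption: final point_emb width = per_query_point * embed_dim * (2 if with_sincos else 1)
embed_dim = 256
sincos_factor = 2  # with_sincos=True would double the per-point embedding
print(4 * embed_dim * sincos_factor)  # 2048 -> my current model, prompt_shape = [60, 4]
print(5 * embed_dim * sincos_factor)  # 2560 -> the checkpoint

If that is right, changing prompt_shape to [60, 5] should make the state_dict load, but I'd like to confirm which value NWPU_anchor.pth was actually trained with.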