When I try to run the training script with accelerate, each of the four processes fails on the first step with a broadcast error raised from ema_unet.step():
08/19/2023 05:24:35 - INFO - __main__ - ***** Running training *****
08/19/2023 05:24:35 - INFO - __main__ - Num examples = 20072
08/19/2023 05:24:35 - INFO - __main__ - Num Epochs = 12
08/19/2023 05:24:35 - INFO - __main__ - Instantaneous batch size per device = 1
08/19/2023 05:24:35 - INFO - __main__ - Total train batch size (w. parallel, distributed & accumulation) = 16
08/19/2023 05:24:35 - INFO - __main__ - Gradient Accumulation steps = 4
08/19/2023 05:24:35 - INFO - __main__ - Total optimization steps = 15000
Steps: 0%| | 0/15000 [00:01<?, ?it/s, lr=1e-5, step_loss=44.7]
Traceback (most recent call last):
  File "/home/shkulkarni/segmind/distill-sd/distill_training.py", line 1208, in <module>
    main()
  File "/home/shkulkarni/segmind/distill-sd/distill_training.py", line 1100, in main
    ema_unet.step(unet.parameters())
  File "/opt/conda/envs/segmind/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/segmind/lib/python3.10/site-packages/diffusers/training_utils.py", line 194, in step
    s_param.sub_(one_minus_decay * (s_param - param))
RuntimeError: output with shape [320, 320, 1, 1] doesn't match the broadcast shape [320, 320, 3, 3]
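As far as I can tell, the failing line computes s_param - param, which broadcasts to [320, 320, 3, 3], and then writes that result in place into s_param, whose shape is [320, 320, 1, 1]. PyTorch (like NumPy) only allows an in-place op when the broadcast result has exactly the output's shape, so the EMA shadow parameter and the model parameter it is paired with must have different shapes. A plausible but unconfirmed explanation is that ema_unet was built from a UNet whose layers differ from the unet being trained, so the two parameter lists no longer line up. Below is a small pure-Python sketch that mirrors the broadcasting rule, plus a hypothetical helper, find_shape_mismatches (not part of diffusers), that reports every index where the two shape lists disagree.

```python
from itertools import zip_longest
from typing import Iterable, List, Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_shape(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Broadcast two shapes using the NumPy/PyTorch rules; raise ValueError if impossible."""
    result = []
    # Align dimensions from the right; a missing leading dimension acts like size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
    return tuple(reversed(result))


def inplace_ok(out: Sequence[int], other: Sequence[int]) -> bool:
    """True if out.sub_(other) could run: the broadcast shape must equal out's own shape."""
    try:
        return broadcast_shape(out, other) == tuple(out)
    except ValueError:
        return False


def find_shape_mismatches(
    shadow_shapes: Iterable[Sequence[int]],
    model_shapes: Iterable[Sequence[int]],
) -> List[Tuple[int, Optional[Shape], Optional[Shape]]]:
    """Hypothetical helper: list (index, shadow_shape, model_shape) for every pair that differs.

    None marks a position where one list is longer than the other.
    """
    mismatches = []
    for i, (s, p) in enumerate(zip_longest(shadow_shapes, model_shapes)):
        s_shape = None if s is None else tuple(s)
        p_shape = None if p is None else tuple(p)
        if s_shape != p_shape:
            mismatches.append((i, s_shape, p_shape))
    return mismatches


# The failing pair from the traceback: shadow [320, 320, 1, 1] vs. parameter [320, 320, 3, 3].
print(broadcast_shape((320, 320, 1, 1), (320, 320, 3, 3)))  # (320, 320, 3, 3)
print(inplace_ok((320, 320, 1, 1), (320, 320, 3, 3)))  # False
```

Assuming ema_unet is a diffusers EMAModel (its shadow_params attribute holds the EMA copies), the helper could be called as find_shape_mismatches([p.shape for p in ema_unet.shadow_params], [p.shape for p in unet.parameters()]); torch.Size is a tuple subclass, so the shapes compare directly. An empty list would mean the shapes line up and the problem lies elsewhere; a non-empty one would suggest creating the EMA model from the same unet that is being optimized.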
Thanks in advance.