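The PyTorch `Optimizer` class has changed in recent releases: its wrapped `step()` now iterates over the optimizer's `_optimizer_step_pre_hooks` (and `_optimizer_step_post_hooks`) dictionaries. The `SWA` wrapper in `utils/optim.py` never initializes these attributes, which leads to the following error: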
[...]
stepping every 16 training passes, cycling lr every 1 epochs
checkin at 2 epochs to match lr scheduler
Traceback (most recent call last):
File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 141, in <module>
run_cv(X, y, f'eval-{task}-{target}.txt', n_splits)
File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 95, in run_cv
model = train_model()
File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 62, in train_model
model.fit(epochs=1000, losscurve=False)
File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 228, in fit
self.train()
File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 140, in train
self.optimizer.step()
File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
return wrapped(*args, **kwargs)
File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/optimizer.py", line 271, in wrapper
for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'. Did you mean: '_optimizer_step_code'?
diff --git a/utils/optim.py b/utils/optim.py
index 33008dd..18224ea 100644
--- a/utils/optim.py
+++ b/utils/optim.py
@@ -1,6 +1,7 @@
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
from itertools import chain
from torch.optim import Optimizer
+from typing import Callable, Dict
import torch
import warnings
import numpy as np
@@ -116,6 +117,8 @@ class SWA(Optimizer):
self.optimizer = optimizer
self.defaults = self.optimizer.defaults
+ self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
+ self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
self.param_groups = self.optimizer.param_groups
self.state = defaultdict(dict)
self.opt_state = self.optimizer.state