I wanted to test your repository and noticed that inference fails on CPU because of the grouped convolution. Minimal example to reproduce:
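```python
import tensorflow as tf

# NFNet comes from this repository; `variant` is the name of the model
# variant whose weights were downloaded.
model = NFNet(num_classes=1000, variant=variant)
model.build((None, 320, 320, 3))
model.load_weights(f"{variant}_NFNet/{variant}_NFNet")

# A dummy batch is enough to trigger the failure on CPU.
test_image = tf.zeros(shape=(1, 320, 320, 3), dtype=tf.float32)
model(test_image)
```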
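I already started debugging, and as far as I can tell, the error occurs in `conv1` of the second block, when calling `WSConv2D()`. There, the inputs have shape `(1, 68, 120, 256)`, while the weights have shape `(3, 3, 128, 256)`. Because the kernel's input depth (128) is half the input depth (256), TensorFlow treats this as a grouped convolution with 2 groups.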
I am not that familiar with grouped convolutions or NFNets in general, so I thought you might already know how to solve this issue, if possible.
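One possible workaround, offered as a minimal sketch and not as the repository's code: assuming the failure comes from TensorFlow's CPU `Conv2D` kernel not supporting grouped convolutions, the grouped convolution can be emulated by splitting the input channels and the kernel's output filters into `groups` equal slices, running an ordinary convolution on each pair, and concatenating the results. The helper name `grouped_conv2d_split` is hypothetical, and the sketch assumes `WSConv2D` ends up calling `tf.nn.conv2d` with NHWC inputs and a kernel of shape `(kh, kw, in_channels // groups, out_channels)`, as the shapes above suggest.

```python
import tensorflow as tf


def grouped_conv2d_split(inputs, kernel, strides=1, padding="SAME"):
    """Hypothetical CPU-friendly grouped 2-D convolution.

    inputs: (batch, height, width, in_channels), NHWC layout.
    kernel: (kh, kw, in_channels // groups, out_channels), the layout
            tf.nn.conv2d uses for a grouped convolution.
    """
    in_channels = inputs.shape[-1]
    group_in = kernel.shape[2]
    if in_channels % group_in != 0:
        raise ValueError("input depth must be a multiple of the kernel input depth")
    groups = in_channels // group_in
    if groups == 1:
        # Not grouped: a single ordinary convolution is enough.
        return tf.nn.conv2d(inputs, kernel, strides=strides, padding=padding)
    # Slice input channels and output filters into `groups` equal parts.
    input_groups = tf.split(inputs, groups, axis=-1)
    kernel_groups = tf.split(kernel, groups, axis=-1)
    # Each slice is an ordinary (non-grouped) convolution.
    outputs = [
        tf.nn.conv2d(x, k, strides=strides, padding=padding)
        for x, k in zip(input_groups, kernel_groups)
    ]
    # Concatenate along channels to restore out_channels.
    return tf.concat(outputs, axis=-1)


# Shapes from the failing call in the second block.
x = tf.zeros((1, 68, 120, 256))
w = tf.zeros((3, 3, 128, 256))
out = grouped_conv2d_split(x, w)  # shape (1, 68, 120, 256)
```

This trades one fused op for `groups` smaller convolutions plus a concatenation, which is likely slower than a native grouped kernel but only uses non-grouped convolutions.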