Comments (3)
It's caused by swin_transformer_v2.py#L95 `attn = keras.layers.Add()([attn, mask])`. It can be fixed using `attn = attn + mask` instead:
# Use `attn = attn + mask` instead of `attn = keras.layers.Add()([attn, mask])`
from tensorflow import keras
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Small_ns()
mm.save('aa.h5')
bb = keras.models.load_model('aa.h5')
But then it throws an error when saving the model on TPU with bfloat16... Not sure if any method fits both situations.
Currently, you may reload using `load_weights`:
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Small_ns(input_shape=..., num_classes=...)
# Any other layers
mm.load_weights("{pretrained.h5}")
Maybe wrapping `make_window_attention_mask` as a layer can work for both situations. Will give it a try.
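A minimal sketch of the "wrap it as a layer" idea, not the library's actual implementation: the class name `AddBlockMask`, its `block_size` and `mask_value` parameters, and the block-diagonal mask are all hypothetical stand-ins for whatever `make_window_attention_mask` really computes. The point is only that the mask is rebuilt inside the layer from a small config, so saving never has to serialize a constant tensor fed to `keras.layers.Add`. The cast in `call` is an assumption about what a bfloat16 / mixed precision setup would need.

```python
import numpy as np
import tensorflow as tf


class AddBlockMask(tf.keras.layers.Layer):
    """Hypothetical layer: adds a constant block-diagonal mask to attention scores.

    The mask is rebuilt in `build` from `block_size`, so only plain config
    values are stored with the model, not the mask tensor itself.
    """

    def __init__(self, block_size, mask_value=-100.0, **kwargs):
        super().__init__(**kwargs)
        self.block_size = block_size
        self.mask_value = mask_value

    def build(self, input_shape):
        seq_len = int(input_shape[-1])
        block_ids = np.arange(seq_len) // self.block_size
        same_block = block_ids[:, None] == block_ids[None, :]
        # 0 where two positions share a block, mask_value everywhere else
        self.attn_mask = np.where(same_block, 0.0, self.mask_value).astype("float32")
        super().build(input_shape)

    def call(self, inputs):
        # Cast the mask to the input dtype before adding (assumed to be what
        # a float16 / bfloat16 compute dtype would require)
        return inputs + tf.cast(self.attn_mask, inputs.dtype)

    def get_config(self):
        config = super().get_config()
        config.update({"block_size": self.block_size, "mask_value": self.mask_value})
        return config


layer = AddBlockMask(block_size=2)
scores = tf.zeros((1, 4, 4))
masked = layer(scores).numpy()

# Rebuild the layer from its config only, as model loading would do
restored = AddBlockMask.from_config(layer.get_config())
```

When loading a saved model that contains such a layer, it would have to be passed through `custom_objects` in `keras.models.load_model`, since it is not a built-in layer.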
from keras_cv_attention_models.
It should work now for both situations. You may save and load your model again:
from tensorflow import keras
from keras_cv_attention_models import swin_transformer_v2
mm = swin_transformer_v2.SwinTransformerV2Small_ns(input_shape=..., num_classes=...)
# Any other layers
mm.load_weights("{pretrained.h5}")
mm.save("aa.h5")
bb = keras.models.load_model('aa.h5')
It works now, amazing job, thanks!