import tensorflow as tf
from tensorflow import keras
from keras_cv_attention_models.common_layers import (
layer_norm, activation_by_name
)
from tensorflow.keras import initializers
from keras_cv_attention_models.attention_layers import (
conv2d_no_bias,
drop_block,
)
import math
# Normalization hyper-parameters shared across this file.
# NOTE(review): BATCH_NORM_DECAY / BATCH_NORM_EPSILON / TF_BATCH_NORM_EPSILON are
# not referenced anywhere in this chunk — presumably kept for parity with sibling
# model files in the package; confirm before removing.
BATCH_NORM_DECAY = 0.9
BATCH_NORM_EPSILON = 1e-5
TF_BATCH_NORM_EPSILON = 0.001
# Epsilon used by the LayerNormalization layers in SDTA_encoder below.
LAYER_NORM_EPSILON = 1e-5
@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class PositionalEncodingFourier(keras.layers.Layer):
    """Fourier (sine/cosine) 2D positional encoding projected to `dim` channels.

    DETR-style positional embedding: per-axis cumulative positions are
    normalized to (0, 1], scaled by 2*pi, expanded into interleaved sin/cos
    features of size `hidden_dim` per axis, concatenated (y first, then x),
    and mapped to `dim` channels with a 1x1 convolution.
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000, **kwargs):
        # Forward **kwargs (name, dtype, ...) to the Layer base class so Keras
        # can reconstruct this layer from its config.
        super(PositionalEncodingFourier, self).__init__(**kwargs)
        # 1x1 conv projecting the 2 * hidden_dim sin/cos features to `dim`.
        self.token_projection = tf.keras.layers.Conv2D(dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6  # guards the division when normalizing positions

    def __call__(self, B, H, W, *args, **kwargs):
        # NOTE: takes static integer sizes (not a tensor), so plain __call__ is
        # overridden deliberately instead of implementing Layer.call.
        mask_tf = tf.zeros([B, H, W])
        not_mask_tf = 1 - mask_tf
        # Cumulative sums turn the all-ones mask into 1..H / 1..W indices.
        y_embed_tf = tf.cumsum(not_mask_tf, axis=1)
        x_embed_tf = tf.cumsum(not_mask_tf, axis=2)
        # Normalize by the last index and scale to (0, 2*pi].
        y_embed_tf = y_embed_tf / (y_embed_tf[:, -1:, :] + self.eps) * self.scale
        x_embed_tf = x_embed_tf / (x_embed_tf[:, :, -1:] + self.eps) * self.scale
        dim_t_tf = tf.range(self.hidden_dim, dtype=tf.float32)
        # Geometric frequency schedule; pairs of indices share a frequency.
        dim_t_tf = self.temperature ** (2 * (dim_t_tf // 2) / self.hidden_dim)
        pos_x_tf = x_embed_tf[:, :, :, None] / dim_t_tf
        pos_y_tf = y_embed_tf[:, :, :, None] / dim_t_tf
        # Interleave sin (even channels) and cos (odd channels), then flatten.
        pos_x_tf = tf.reshape(
            tf.stack([tf.math.sin(pos_x_tf[:, :, :, 0::2]),
                      tf.math.cos(pos_x_tf[:, :, :, 1::2])], axis=4),
            shape=[B, H, W, self.hidden_dim])
        pos_y_tf = tf.reshape(
            tf.stack([tf.math.sin(pos_y_tf[:, :, :, 0::2]),
                      tf.math.cos(pos_y_tf[:, :, :, 1::2])], axis=4),
            shape=[B, H, W, self.hidden_dim])
        pos_tf = tf.concat([pos_y_tf, pos_x_tf], axis=-1)
        pos_tf = self.token_projection(pos_tf)
        return pos_tf

    def get_config(self):
        # BUG FIX: only plain-Python values belong in the config — the previous
        # version serialized the Conv2D sublayer (and derived scale/eps), which
        # breaks from_config round-tripping. Sublayers are rebuilt in __init__.
        base_config = super().get_config()
        base_config.update({"hidden_dim": self.hidden_dim, "dim": self.dim,
                            "temperature": self.temperature})
        return base_config
def EdgeNeXt(input_shape=(256, 256, 3), depths=[3, 3, 9, 3], dims=[24, 48, 88, 168],
             global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA'],
             drop_path_rate=1, layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4,
             kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False],
             use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], epsilon=1e-6, model_name='EdgeNeXt'):
    """Build the EdgeNeXt backbone as a functional Keras model.

    Four stages; stage ``i`` has ``depths[i]`` blocks of width ``dims[i]``. The
    last ``global_block[i]`` blocks of a stage use the global encoder named in
    ``global_block_type[i]`` (only 'SDTA' is implemented; anything else raises
    NotImplementedError); the remaining blocks are local Conv_Encoder blocks.
    Stages 1-3 are preceded by LayerNorm + stride-2 conv downsampling; the stem
    is a stride-4 conv + LayerNorm.

    NOTE(review): ``head_init_scale`` and ``use_pos_embd_global`` are accepted
    for parity with the reference implementation but are currently unused, and
    no classification head is attached. The default ``drop_path_rate=1`` looks
    suspect (the reference uses 0.0) — kept for backward compatibility.
    """
    # Batch size is fixed because PositionalEncodingFourier builds tensors from
    # static (non-None) batch/height/width sizes.
    inputs = keras.layers.Input(input_shape, batch_size=2)
    nn = conv2d_no_bias(inputs, dims[0], kernel_size=4, strides=4, padding="valid", name="stem_")
    nn = layer_norm(nn, epsilon=epsilon, name='stem_')
    # Per-block stochastic-depth rates, linearly spaced from 0 to drop_path_rate.
    # Endpoints cast to float so tf.linspace never sees integer dtypes.
    drop_connect_rates = tf.linspace(0.0, stop=float(drop_path_rate), num=int(sum(depths)))
    cur = 0
    for i in range(4):
        for j in range(depths[i]):
            if j > depths[i] - global_block[i] - 1:
                if global_block_type[i] == 'SDTA':
                    # BUG FIX: keep the encoder output — it was previously discarded,
                    # so the SDTA blocks never took part in the graph.
                    nn = SDTA_encoder(dim=dims[i], drop_path=drop_connect_rates[cur + j],
                                      expan_ratio=expan_ratio, scales=d2_scales[i],
                                      use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i],
                                      name='stage_' + str(i) + '_SDTA_encoder_' + str(j))(nn)
                else:
                    raise NotImplementedError
            else:
                if i != 0 and j == 0:
                    # Between-stage downsample: LayerNorm then stride-2 conv.
                    nn = layer_norm(nn, epsilon=epsilon, name='stage_' + str(i) + '_')
                    nn = conv2d_no_bias(nn, dims[i], kernel_size=2, strides=2, padding="valid",
                                        name='stage_' + str(i) + '_')
                # BUG FIX: keep the encoder output (previously discarded).
                nn = Conv_Encoder(dim=dims[i], drop_path=drop_connect_rates[cur + j],
                                  layer_scale_init_value=layer_scale_init_value,
                                  expan_ratio=expan_ratio, kernel_size=kernel_sizes[i],
                                  name='stage_' + str(i) + '_Conv_Encoder_' + str(j) + '_')(nn)
        # BUG FIX: advance the drop-path index past this stage's blocks; `cur`
        # was never updated, so every stage reused the first stage's rates.
        cur += depths[i]
    model = keras.models.Model(inputs, nn, name=model_name)
    return model
@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class Conv_Encoder(keras.layers.Layer):
    """ConvNeXt-style local encoder block.

    kxk conv -> LayerNorm -> pointwise expand (Dense, ratio `expan_ratio`) ->
    GELU -> pointwise project back to `dim`, with optional per-channel layer
    scale (`gamma`) and a stochastic-depth residual connection.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7, epsilon=1e-6,
                 name=''):
        super(Conv_Encoder, self).__init__()
        # `name` is kept as a plain prefix for inner layer names (the Layer base
        # name is left to Keras auto-naming, matching the original behavior).
        self.encoder_name = name
        self.layer_scale_init_value = layer_scale_init_value
        # Learnable per-channel scale on the residual branch (layer scale);
        # disabled when layer_scale_init_value <= 0.
        self.gamma = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                 name=name + 'gamma') if layer_scale_init_value > 0 else None
        self.drop_path = drop_path
        self.dim = dim
        self.expan_ratio = expan_ratio
        self.kernel_size = kernel_size
        self.epsilon = epsilon

    def __call__(self, x, *args, **kwargs):
        inputs = x
        x = keras.layers.Conv2D(self.dim, kernel_size=self.kernel_size, padding="SAME",
                                name=self.encoder_name + 'Conv2D')(x)
        x = layer_norm(x, epsilon=self.epsilon, name=self.encoder_name)
        x = keras.layers.Dense(self.expan_ratio * self.dim)(x)  # pointwise expand
        x = activation_by_name(x, activation="gelu")
        x = keras.layers.Dense(self.dim)(x)  # pointwise project back to dim
        if self.gamma is not None:
            x = self.gamma * x
        # BUG FIX: apply the configured stochastic-depth rate; previously the
        # rate was hard-coded to 0., silently disabling drop-path everywhere.
        x = inputs + drop_block(x, drop_rate=self.drop_path)
        return x

    def get_config(self):
        # BUG FIX: serialize only plain values that __init__ accepts; `gamma`
        # is a tf.Variable (rebuilt in __init__) and must not go in the config.
        base_config = super().get_config()
        base_config.update({"drop_path": self.drop_path, "dim": self.dim,
                            "expan_ratio": self.expan_ratio,
                            "kernel_size": self.kernel_size,
                            "layer_scale_init_value": self.layer_scale_init_value,
                            "epsilon": self.epsilon})
        return base_config
@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class SDTA_encoder(keras.layers.Layer):
    """Split Depth-wise Transpose Attention (SDTA) encoder block.

    Three sub-stages, each residual:
      1. Res2Net-style multi-scale convs: channels are split into `scales`
         groups; each group (except the last) is 3x3-convolved after adding the
         previous group's output.
      2. Cross-covariance attention (XCA) over flattened spatial positions,
         optionally with a Fourier positional encoding.
      3. Inverted-bottleneck MLP (1x1 expand -> activation -> 1x1 project).
    Branches 2 and 3 are scaled by learnable layer-scale vectors.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4,
                 use_pos_emb=True, num_heads=8, qkv_bias=True, attn_drop=0., drop=0., scales=1, zero_gamma=False,
                 activation='gelu', use_bias=False, name='sdf'):
        super(SDTA_encoder, self).__init__()
        self.expan_ratio = expan_ratio
        # Width of each of the first (scales - 1) channel groups; the last group
        # takes the remainder so the widths always sum to `dim`.
        self.width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales)))
        self.width_list = [self.width] * (scales - 1)
        self.width_list.append(dim - self.width * (scales - 1))
        self.dim = dim
        self.scales = scales
        self.num_heads = num_heads
        self.use_pos_emb = use_pos_emb
        self.layer_scale_init_value = layer_scale_init_value
        # Number of convolved groups. NOTE(review): __call__ indexes
        # spx[self.nums], which requires scales >= 2; scales == 1 would raise.
        self.nums = 1 if scales == 1 else scales - 1
        self.pos_embd = PositionalEncodingFourier(dim=dim) if use_pos_emb else None
        self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # Layer-scale vectors for the attention branch and the MLP branch.
        # BUG FIX: the two variables previously requested the identical name.
        self.gamma_xca = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                     name=name + 'gamma_xca') if layer_scale_init_value > 0 else None
        self.gamma = tf.Variable(layer_scale_init_value * tf.ones(dim), trainable=True,
                                 name=name + 'gamma') if layer_scale_init_value > 0 else None
        self.drop_rate = drop_path
        self.drop_path = keras.layers.Dropout(drop_path)
        gamma_initializer = tf.zeros_initializer() if zero_gamma else tf.ones_initializer()
        self.norm = keras.layers.LayerNormalization(epsilon=LAYER_NORM_EPSILON, gamma_initializer=gamma_initializer,
                                                    name=name and name + "ln")
        self.norm_xca = keras.layers.LayerNormalization(epsilon=LAYER_NORM_EPSILON, gamma_initializer=gamma_initializer,
                                                        name=name and name + "norm_xca")
        self.activation = activation
        self.use_bias = use_bias

    def get_config(self):
        # BUG FIX: serialize only plain values matching __init__ parameters; the
        # previous config contained sublayers/variables and keys __init__ does
        # not accept, which breaks from_config round-tripping.
        base_config = super().get_config()
        base_config.update({"dim": self.dim, "drop_path": self.drop_rate,
                            "layer_scale_init_value": self.layer_scale_init_value,
                            "expan_ratio": self.expan_ratio,
                            "use_pos_emb": self.use_pos_emb,
                            "num_heads": self.num_heads, "scales": self.scales,
                            "activation": self.activation, "use_bias": self.use_bias})
        return base_config

    def __call__(self, inputs, *args, **kwargs):
        x = inputs
        # --- 1. Multi-scale split convolutions (Res2Net-style cascade) ---
        spx = tf.split(inputs, self.width_list, axis=-1)
        for i in range(self.nums):
            sp = spx[i] if i == 0 else sp + spx[i]  # cascade previous output in
            sp = keras.layers.Conv2D(self.width, kernel_size=3, padding='SAME')(sp)  # , groups=self.width
            out = sp if i == 0 else tf.concat([out, sp], -1)
        inputs = tf.concat([out, spx[self.nums]], -1)  # last group passes through
        # --- 2. XCA over flattened spatial positions ---
        B, H, W, C = inputs.shape
        inputs = tf.reshape(inputs, (-1, H * W, C))
        if self.pos_embd:
            pos_encoding = tf.reshape(self.pos_embd(B, H, W), (-1, H * W, C))
            inputs += pos_encoding
        # BUG FIX: gamma_xca scales only the attention branch. Previously the
        # residual trunk itself was also multiplied by gamma_xca (initialized
        # ~1e-6), which collapsed the activations.
        input_xca = self.xca(self.norm_xca(inputs))
        if self.gamma_xca is not None:
            input_xca = self.gamma_xca * input_xca
        inputs = inputs + drop_block(input_xca, drop_rate=self.drop_rate, name="SDTA_encoder_")
        inputs = tf.reshape(inputs, (-1, H, W, C))
        # --- 3. Inverted bottleneck MLP ---
        inputs = self.norm(inputs)
        inputs = keras.layers.Conv2D(self.expan_ratio * self.dim, kernel_size=1, use_bias=self.use_bias)(inputs)
        inputs = activation_by_name(inputs, activation=self.activation)
        inputs = keras.layers.Conv2D(self.dim, kernel_size=1, use_bias=self.use_bias)(inputs)
        if self.gamma is not None:
            inputs = self.gamma * inputs
        x = x + self.drop_path(inputs)
        return x
@tf.keras.utils.register_keras_serializable(package="EdgeNeXt")
class XCA(keras.layers.Layer):
    """Cross-Covariance Attention (XCA).

    Attention is computed between channel dimensions rather than spatial
    positions: q and k are L2-normalized along the spatial axis, so the
    attention map is a per-head (dims_per_head x dims_per_head) cosine
    similarity, scaled by a learnable per-head temperature.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., name=""):
        super(XCA, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        # BUG FIX: tf.ones(num_heads, 1, 1) passed 1 positionally as dtype and
        # name; the per-head temperature must be shaped (num_heads, 1, 1) so it
        # broadcasts over the [batch, heads, d, d] attention map directly.
        self.temperature = tf.Variable(tf.ones((num_heads, 1, 1)), trainable=True, name=name + 'gamma')
        self.qkv = keras.layers.Dense(dim * 3, use_bias=qkv_bias)
        self.attn_drop = keras.layers.Dropout(attn_drop)
        self.k_ini = initializers.GlorotUniform()
        self.b_ini = initializers.Zeros()
        self.proj = keras.layers.Dense(dim, name="out",
                                       kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.proj_drop = keras.layers.Dropout(proj_drop)

    def __call__(self, inputs, training=None, *args, **kwargs):
        B, N, C = inputs.shape  # [batch, hh * ww, channels]
        qkv = self.qkv(inputs)
        qkv = tf.reshape(qkv, (-1, N, 3, self.num_heads, C // self.num_heads))
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 4, 1])  # [3, batch, heads, dims_per_head, hh * ww]
        # BUG FIX: index instead of tf.split + tf.squeeze — squeeze removed
        # every size-1 axis, destroying the batch axis when batch == 1, and
        # left a stale leading 1-axis on `value`.
        query, key, value = qkv[0], qkv[1], qkv[2]
        # Normalize along the spatial axis so attn is a channel-to-channel map.
        norm_query = tf.nn.l2_normalize(query, axis=-1, epsilon=1e-6)
        norm_key = tf.nn.l2_normalize(key, axis=-1, epsilon=1e-6)
        attn = tf.matmul(norm_query, norm_key, transpose_b=True) * self.temperature  # [batch, heads, d, d]
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)
        x = tf.matmul(attn, value)  # [batch, heads, dims_per_head, hh * ww]
        # BUG FIX: restore [batch, hh * ww, heads, dims_per_head] ordering before
        # merging heads; the previous direct reshape interleaved spatial and
        # channel axes.
        x = tf.transpose(x, perm=[0, 3, 1, 2])
        x = tf.reshape(x, (-1, N, C))
        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x

    def get_config(self):
        # Serialize only plain values; sublayers and the temperature variable
        # are rebuilt in __init__.
        base_config = super().get_config()
        base_config.update({"dim": self.dim, "num_heads": self.num_heads})
        return base_config
def edgenext_xx_small(pretrained=False, **kwargs):
    """EdgeNeXt-XXS configuration.

    1.33M params & 260.58M MAdds @ 256 resolution; 71.23% top-1 accuracy.
    Training recipe: no AA, color jitter=0.4, no Mixup/CutMix, DropPath=0.0,
    BS=4096, lr=0.006, multi-scale sampler.
    Throughput: Jetson FPS=51.66 (vs 47.67 for MobileViT_XXS); A100 FPS
    @ BS=1: 212.13 and @ BS=256: 7042.06 (vs 96.68 / 4624.71 for MobileViT_XXS).

    NOTE(review): `pretrained` is accepted for API parity but not used here.
    """
    return EdgeNeXt(
        depths=[2, 2, 6, 2],
        dims=[24, 48, 88, 168],
        expan_ratio=4,
        global_block=[0, 1, 1, 1],
        global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
        use_pos_embd_xca=[False, True, False, False],
        kernel_sizes=[3, 5, 7, 9],
        heads=[4, 4, 4, 4],
        d2_scales=[2, 2, 3, 4],
        **kwargs,
    )
def edgenext_x_small(pretrained=False, **kwargs):
    """EdgeNeXt-XS configuration.

    2.34M params & 538.0M MAdds @ 256 resolution; 75.00% top-1 accuracy.
    Training recipe: no AA, no Mixup/CutMix, DropPath=0.0, BS=4096, lr=0.006,
    multi-scale sampler.
    Throughput: Jetson FPS=31.61 (vs 28.49 for MobileViT_XS); A100 FPS
    @ BS=1: 179.55 and @ BS=256: 4404.95 (vs 94.55 / 2361.53 for MobileViT_XS).

    NOTE(review): `pretrained` is accepted for API parity but not used here.
    """
    return EdgeNeXt(
        depths=[3, 3, 9, 3],
        dims=[32, 64, 100, 192],
        expan_ratio=4,
        global_block=[0, 1, 1, 1],
        global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
        use_pos_embd_xca=[False, True, False, False],
        kernel_sizes=[3, 5, 7, 9],
        heads=[4, 4, 4, 4],
        d2_scales=[2, 2, 3, 4],
        **kwargs,
    )
def edgenext_small(pretrained=False, **kwargs):
    """EdgeNeXt-S configuration.

    5.59M params & 1260.59M MAdds @ 256 resolution; 79.43% top-1 accuracy.
    Training recipe: AA=True, no Mixup/CutMix, DropPath=0.1, BS=4096, lr=0.006,
    multi-scale sampler.
    Throughput: Jetson FPS=20.47 (vs 18.86 for MobileViT_S); A100 FPS
    @ BS=1: 172.33 and @ BS=256: 3010.25 (vs 93.84 / 1785.92 for MobileViT_S).

    Uses the default head count (heads=[8, 8, 8, 8]), unlike the XXS/XS
    variants. NOTE(review): `pretrained` is accepted for API parity but unused.
    """
    return EdgeNeXt(
        depths=[3, 3, 9, 3],
        dims=[48, 96, 160, 304],
        expan_ratio=4,
        global_block=[0, 1, 1, 1],
        global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'],
        use_pos_embd_xca=[False, True, False, False],
        kernel_sizes=[3, 5, 7, 9],
        d2_scales=[2, 2, 3, 4],
        **kwargs,
    )
if __name__ == '__main__':
    # Smoke test: build the small variant and print a layer-by-layer summary.
    model = edgenext_small()
    model.summary()
    # Commented-out helper for porting the official PyTorch weights into this
    # Keras model — kept for reference. NOTE(review): the save_name
    # "adaface_ir101_webface4m.h5" looks copy-pasted from another project;
    # confirm before reusing.
    # from download_and_load import keras_reload_from_torch_model
    # keras_reload_from_torch_model(
    # 'D:\GitHub\EdgeNeXt\edgenext_small.pth',
    # keras_model=model,
    # # tail_align_dict=tail_align_dict,
    # # full_name_align_dict=full_name_align_dict,
    # # additional_transfer=additional_transfer,
    # input_shape=(256, 256),
    # do_convert=True,
    # save_name="adaface_ir101_webface4m.h5",
    # )