Comments (2)
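Three key points for implementing ResNet50 (the implementation should live in a single class named ResNet50)
- How to feed in image data
  - Unlike MNIST, which comes already preprocessed, how do I feed images such as those in the Programmers assignment into a ResNet I implemented myself?
  - Would it work to use ImageGenerator the same way and pass the images in? After some googling I have a rough feel for the first two items, but feeding images into code I wrote myself is still confusing.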
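One possible way to feed your own image files into a Keras model is sketched below. It assumes the images sit in one sub-folder per class and that TensorFlow 2.9 or later is installed, where `tf.keras.utils.image_dataset_from_directory` is available. The helper name `make_dataset` and the path `data/train` are hypothetical and do not come from the assignment.

```python
import tensorflow as tf


def make_dataset(image_dir, image_size=(32, 32), batch_size=32):
    """Build a tf.data.Dataset of (image, one-hot label) batches from a folder.

    Assumes a layout like image_dir/<class_name>/<file>.png (hypothetical).
    """
    ds = tf.keras.utils.image_dataset_from_directory(
        image_dir,
        label_mode="categorical",  # one-hot labels, matches categorical_crossentropy
        image_size=image_size,     # every image is resized to this size
        batch_size=batch_size,
        shuffle=True,
        seed=0,
    )
    # Scale pixels from [0, 255] to [0, 1], as is done for CIFAR-10 in this thread.
    ds = ds.map(lambda x, y: (x / 255.0, y))
    return ds.prefetch(tf.data.AUTOTUNE)


# Usage (hypothetical path and model):
# train_ds = make_dataset("data/train")
# model.fit(train_ds, epochs=20)
```

The class names are inferred from the sub-folder names, so the number of sub-folders has to match the `classes` argument of the model.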
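I'm almost done implementing ResNet50. I have a feel for writing each piece of code as a function, but wrapping them into a single model as a class is the crux of this first paper-to-code implementation.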
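A minimal sketch of how the function-based blocks could be wrapped into one class, assuming the Keras subclassing API. The names `Bottleneck` and `project` are illustrative, and `GlobalAveragePooling2D` is swapped in for a pooling-plus-flatten head; this is one possible layout, not the final implementation of this thread.

```python
import tensorflow as tf
from tensorflow.keras import layers


class Bottleneck(layers.Layer):
    """1x1 -> 3x3 -> 1x1 bottleneck block with a residual shortcut."""

    def __init__(self, filters, stride=1, project=False, **kwargs):
        super().__init__(**kwargs)
        f1, f2, f3 = filters
        self.conv1 = layers.Conv2D(f1, 1, strides=stride)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(f2, 3, padding="same")
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(f3, 1)
        self.bn3 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        # A projection shortcut is used when the output shape changes
        # (the "convolutional block"); otherwise the input is added as-is.
        self.proj = layers.Conv2D(f3, 1, strides=stride) if project else None
        self.proj_bn = layers.BatchNormalization() if project else None

    def call(self, x, training=None):
        shortcut = x
        y = self.relu(self.bn1(self.conv1(x), training=training))
        y = self.relu(self.bn2(self.conv2(y), training=training))
        y = self.bn3(self.conv3(y), training=training)
        if self.proj is not None:
            shortcut = self.proj_bn(self.proj(x), training=training)
        return self.relu(y + shortcut)


class ResNet50(tf.keras.Model):
    """ResNet50 as a single class: stem, 16 bottleneck blocks, classifier."""

    def __init__(self, classes=10, **kwargs):
        super().__init__(**kwargs)
        self.stem_pad = layers.ZeroPadding2D(3)
        self.stem_conv = layers.Conv2D(64, 7, strides=2)
        self.stem_bn = layers.BatchNormalization()
        self.relu = layers.ReLU()
        self.pool = layers.MaxPooling2D(3, strides=2)
        # (filters, number of blocks, stride of the first block) per stage.
        stages = [
            ([64, 64, 256], 3, 1),
            ([128, 128, 512], 4, 2),
            ([256, 256, 1024], 6, 2),
            ([512, 512, 2048], 3, 2),
        ]
        blocks = []
        for filters, count, stride in stages:
            blocks.append(Bottleneck(filters, stride=stride, project=True))
            blocks.extend(Bottleneck(filters) for _ in range(count - 1))
        self.blocks = blocks
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(classes, activation="softmax")

    def call(self, x, training=None):
        x = self.stem_pad(x)
        x = self.relu(self.stem_bn(self.stem_conv(x), training=training))
        x = self.pool(x)
        for block in self.blocks:
            x = block(x, training=training)
        return self.fc(self.gap(x))


model = ResNet50(classes=10)
probs = model(tf.zeros((2, 32, 32, 3)), training=False)
print(probs.shape)
```

Passing `training` through `call` lets Keras switch batch normalization between training and inference behavior on its own during `fit` and `evaluate`.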
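Reading and understanding a single paper took longer than I expected, and because I'm trying to cover all of the math while applying it, there are still quite a few parts I don't understand.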
from dl-paper-review-and-code-practice.
Draft code applying ResNet50 to the CIFAR-10 image dataset
However, training is far too slow. There are only 60,000 images, yet a single epoch takes 45 minutes.
What is the problem?
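One thing worth ruling out first is that TensorFlow is running on the CPU only, since a network this deep usually trains much more slowly without a GPU. The check below is a minimal sketch assuming TensorFlow 2.x; it is a guess at the cause, not a confirmed diagnosis.

```python
import tensorflow as tf

# List the GPUs TensorFlow can see; an empty list means training runs on the CPU.
gpus = tf.config.list_physical_devices("GPU")
print("TensorFlow version:", tf.__version__)
print("Visible GPUs:", gpus)
if not gpus:
    print("No GPU detected: training will run on the CPU.")
```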
#!/usr/bin/env python
# coding: utf-8
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import (
    Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization,
    Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D,
)
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.initializers import random_uniform, glorot_uniform, constant, identity
# keras.utils.np_utils is not available in recent Keras releases; aliasing
# tf.keras.utils keeps the np_utils.to_categorical calls below working.
from tensorflow.keras import utils as np_utils
def identity_block(X, f, filters, training=None, initializer=random_uniform):
    """Bottleneck block whose shortcut is the unchanged input.

    training=None lets Keras pick training/inference mode for batch norm;
    hard-coding True would use batch statistics during evaluation too.
    """
    F1, F2, F3 = filters
    X_shortcut = X
    # First component: 1x1 conv
    X = Conv2D(filters = F1, kernel_size = 1, strides = (1,1), padding = 'valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training) # Default axis
    X = Activation('relu')(X)
    # Second component: f x f conv
    X = Conv2D(filters = F2, kernel_size = f, strides = (1,1), padding = 'same', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training)
    X = Activation('relu')(X)
    # Third component: 1x1 conv, no activation before the addition
    X = Conv2D(filters = F3, kernel_size = 1, strides = (1,1), padding = 'valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training)
    # Add the shortcut, then apply ReLU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X
def convolutional_block(X, f, filters, s = 2, training=None, initializer=glorot_uniform):
    """Bottleneck block whose shortcut is a strided 1x1 conv projection.

    training=None lets Keras pick training/inference mode for batch norm.
    """
    F1, F2, F3 = filters
    X_shortcut = X
    # Main path: 1x1 conv with stride s, f x f conv, 1x1 conv
    X = Conv2D(filters = F1, kernel_size = 1, strides = (s, s), padding='valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    X = Activation('relu')(X)
    X = Conv2D(filters = F2, kernel_size = f, strides = (1, 1), padding='same', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    X = Activation('relu')(X)
    X = Conv2D(filters = F3, kernel_size = 1, strides = (1, 1), padding='valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    # Shortcut path: project the input to the same shape as the main path
    X_shortcut = Conv2D(filters = F3, kernel_size = 1, strides = (s, s), padding='valid', kernel_initializer = initializer(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis = 3)(X_shortcut, training=training)
    # Add the shortcut, then apply ReLU
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X
def ResNet50(input_shape = (32, 32, 3), classes = 10):
    X_input = Input(input_shape)
    # Stage 1: zero padding, 7x7 conv with stride 2, 3x3 max pool with stride 2
    X = ZeroPadding2D((3, 3))(X_input)
    X = Conv2D(64, (7, 7), strides = (2, 2), kernel_initializer = glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis = 3)(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)
    # Stage 2: 3 blocks
    X = convolutional_block(X, f = 3, filters = [64, 64, 256], s = 1)
    X = identity_block(X, 3, [64, 64, 256])
    X = identity_block(X, 3, [64, 64, 256])
    # Stage 3: 4 blocks
    X = convolutional_block(X, f = 3, filters = [128, 128, 512], s = 2)
    X = identity_block(X, 3, [128, 128, 512])
    X = identity_block(X, 3, [128, 128, 512])
    X = identity_block(X, 3, [128, 128, 512])
    # Stage 4: 6 blocks
    X = convolutional_block(X, f = 3, filters = [256, 256, 1024], s = 2)
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    # Stage 5: 3 blocks
    X = convolutional_block(X, f = 3, filters = [512, 512, 2048], s = 2)
    X = identity_block(X, 3, [512, 512, 2048])
    X = identity_block(X, 3, [512, 512, 2048])
    # Classifier head: average pool, flatten, softmax
    X = AveragePooling2D(pool_size=(2,2), padding="same")(X)
    X = Flatten()(X)
    X = Dense(classes, activation='softmax', kernel_initializer = glorot_uniform(seed=0))(X)
    model = Model(inputs = X_input, outputs = X)
    return model
# Build and compile the model for 10-class CIFAR-10 classification
m_model = ResNet50(input_shape = (32, 32, 3), classes = 10)
m_model.compile(loss='categorical_crossentropy', optimizer="adam", metrics=['accuracy'])

# Load CIFAR-10 and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train / 255.
x_test = x_test / 255.

# One-hot encode the integer labels
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)

print(f"number of training examples = {x_train.shape[0]}")
print(f"number of test examples = {x_test.shape[0]}")
print(f"X_train shape: {x_train.shape}")
print(f"Y_train shape: {y_train.shape}")
print(f"X_test shape: {x_test.shape}")
print(f"Y_test shape: {y_test.shape}")

# Train, using the test split as validation data, then evaluate
history = m_model.fit(x_train, y_train, epochs=20, batch_size=512, validation_data=(x_test, y_test), verbose=1, shuffle=True)
m_model.evaluate(x_test, y_test, batch_size=256, verbose=1)
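To see what the ImageNet-style stem does to 32x32 CIFAR-10 inputs, the spatial size after each downsampling step of the ResNet50 function above can be traced by hand. The helper names `conv_out` and `trace_resnet50_sizes` are hypothetical; the sketch only applies the standard output-size formula for 'valid' padding.

```python
def conv_out(size, kernel, stride):
    """Output size of a conv or pooling layer with 'valid' padding."""
    return (size - kernel) // stride + 1


def trace_resnet50_sizes(size=32):
    """Follow the feature-map height/width through the downsampling layers."""
    sizes = {}
    size += 2 * 3                  # ZeroPadding2D((3, 3)) adds 3 pixels per side
    size = conv_out(size, 7, 2)    # 7x7 conv, stride 2
    sizes["stem conv"] = size
    size = conv_out(size, 3, 2)    # 3x3 max pool, stride 2
    sizes["max pool"] = size
    # Each stage opens with a 1x1 conv of stride s in convolutional_block.
    for name, stride in [("stage 2", 1), ("stage 3", 2), ("stage 4", 2), ("stage 5", 2)]:
        size = conv_out(size, 1, stride)
        sizes[name] = size
    return sizes


sizes = trace_resnet50_sizes(32)
print(sizes)
# {'stem conv': 16, 'max pool': 7, 'stage 2': 7, 'stage 3': 4, 'stage 4': 2, 'stage 5': 1}
```

With a 32x32 input the feature map is already 1x1 by stage 5, so the final AveragePooling2D has nothing left to average.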