Comments (2)

RyanKor commented on September 27, 2024

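Three key implementation points for ResNet50 (everything is to be implemented as a single class named ResNet50)

  • Implement the identity block using a shortcut

  • Implement the layers of ResNet-50

  • How to feed in image data

    • Rather than already-preprocessed data such as MNIST, how do I feed images like the ones in the Programmers assignment into the ResNet I implemented myself?
    • Can I just use ImageGenerator in the same way? After some googling I have a feel for the first two items, but feeding images into the code I wrote still confuses me.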
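One possible way to feed your own image files into the model is sketched below. It assumes the images are stored in one subfolder per class and uses `tf.keras.utils.image_dataset_from_directory`; the folder names `cat` and `dog`, the random pixel data, and the 32x32 target size are placeholders for illustration, not part of the original assignment.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Build a tiny throwaway dataset: one subfolder per class (hypothetical names).
root = tempfile.mkdtemp()
for class_name in ("cat", "dog"):
    os.makedirs(os.path.join(root, class_name))
    for i in range(4):
        pixels = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
        path = os.path.join(root, class_name, f"{i}.png")
        tf.io.write_file(path, tf.io.encode_png(pixels))

# Labels are inferred from the subfolder names; images are resized on load.
dataset = tf.keras.utils.image_dataset_from_directory(
    root,
    label_mode="categorical",  # one-hot labels, matching categorical_crossentropy
    image_size=(32, 32),
    batch_size=4,
    shuffle=True,
    seed=0,
)

# Scale pixel values to [0, 1], as is usually done before training.
dataset = dataset.map(lambda x, y: (x / 255.0, y))

images, labels = next(iter(dataset))
print(images.shape, labels.shape)
```

A dataset built this way can be passed directly to `model.fit(dataset, epochs=...)` in place of NumPy arrays.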

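I am almost done implementing ResNet50. I have a feel for writing each piece of code as a function, but wrapping them into a single model as a class is the crux of this first paper-to-code implementation.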
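A minimal sketch of how one building block could be written as a class, assuming the Keras subclassing API (`tf.keras.layers.Layer`). The class name `IdentityBlock` and its arguments are hypothetical; a full model would stack such blocks inside a `tf.keras.Model` subclass in the same way.

```python
import tensorflow as tf
from tensorflow.keras import layers


class IdentityBlock(layers.Layer):
    """Bottleneck block whose shortcut is the unchanged input (hypothetical name)."""

    def __init__(self, filters, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        f1, f2, f3 = filters
        # Sub-layers are created once here and reused on every call.
        self.conv1 = layers.Conv2D(f1, 1, padding="valid")
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(f2, kernel_size, padding="same")
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(f3, 1, padding="valid")
        self.bn3 = layers.BatchNormalization()

    def call(self, x, training=None):
        shortcut = x  # identity shortcut: f3 must equal the input channel count
        y = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        y = tf.nn.relu(self.bn2(self.conv2(y), training=training))
        y = self.bn3(self.conv3(y), training=training)
        # Add the shortcut before the final activation.
        return tf.nn.relu(y + shortcut)


block = IdentityBlock(filters=(64, 64, 256))
features = tf.random.normal((2, 8, 8, 256))
out = block(features, training=False)
print(out.shape)  # same shape as the input
```

Keeping the layers as attributes and forwarding `training` in `call` lets Keras switch batch normalization between training and inference behaviour automatically.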

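Reading and understanding a single paper took much longer than I expected. Trying to apply it while also covering all of the math, there are still quite a few parts I do not understand.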

from dl-paper-review-and-code-practice.

RyanKor commented on September 27, 2024

First draft of the code applying ResNet50 to the CIFAR-10 image dataset

However, training is far too slow. There are only 60,000 images, yet a single epoch takes 45 minutes.

What is the problem?
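The post does not say what hardware was used, so the cause cannot be confirmed from here. One common reason for epochs this slow is that TensorFlow is running on the CPU, and a quick check, offered only as a first diagnostic step, looks like this:

```python
import tensorflow as tf

# List the GPUs that TensorFlow can actually use in this environment.
gpus = tf.config.list_physical_devices("GPU")

print("TensorFlow version:", tf.__version__)
print("GPUs visible to TensorFlow:", gpus)
if not gpus:
    print("No GPU detected: training will run on the CPU.")
```

If the list is empty, the model trains on the CPU, where a network with roughly 50 convolutional layers can easily take tens of minutes per epoch.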

#!/usr/bin/env python
# coding: utf-8

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.initializers import random_uniform, glorot_uniform, constant, identity
# keras.utils.np_utils is no longer available in recent Keras releases; tf.keras.utils provides the same to_categorical.
from tensorflow.keras import utils as np_utils

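def identity_block(X, f, filters, training=None, initializer=random_uniform):
    # training=None lets Keras use batch statistics during fit() and moving averages during evaluate()/predict().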
    
    F1, F2, F3 = filters
    
    X_shortcut = X
    
    X = Conv2D(filters = F1, kernel_size = 1, strides = (1,1), padding = 'valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training) # Default axis
    X = Activation('relu')(X)

    X = Conv2D(filters = F2, kernel_size = f, strides = (1,1), padding = 'same', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training)
    X = Activation('relu')(X)

    X = Conv2D(filters = F3, kernel_size = 1, strides = (1,1), padding = 'valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training = training)
    
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)

    return X

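def convolutional_block(X, f, filters, s = 2, training=None, initializer=glorot_uniform):
    # training=None lets Keras use batch statistics during fit() and moving averages during evaluate()/predict().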

    F1, F2, F3 = filters
   
    X_shortcut = X

    X = Conv2D(filters = F1, kernel_size = 1, strides = (s, s), padding='valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    X = Activation('relu')(X)

    X = Conv2D(filters = F2, kernel_size = f, strides = (1, 1), padding='same', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    X = Activation('relu')(X)

    X = Conv2D(filters = F3, kernel_size = 1, strides = (1, 1), padding='valid', kernel_initializer = initializer(seed=0))(X)
    X = BatchNormalization(axis = 3)(X, training=training)
    
    X_shortcut = Conv2D(filters = F3, kernel_size = 1, strides = (s, s), padding='valid', kernel_initializer = initializer(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis = 3)(X_shortcut, training=training)

    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    
    return X

def ResNet50(input_shape = (32, 32, 3), classes = 10):
    
    X_input = Input(input_shape)
    
    X = ZeroPadding2D((3, 3))(X_input)

    X = Conv2D(64, (7, 7), strides = (2, 2), kernel_initializer = glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis = 3)(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)

    # Stage 2: 1 convolutional block + 2 identity blocks
    X = convolutional_block(X, f = 3, filters = [64, 64, 256], s = 1)
    X = identity_block(X, 3, [64, 64, 256])
    X = identity_block(X, 3, [64, 64, 256])

    # Stage 3: 1 convolutional block + 3 identity blocks
    X = convolutional_block(X, f = 3, filters = [128, 128, 512], s = 2)
    X = identity_block(X, 3, [128, 128, 512])
    X = identity_block(X, 3, [128, 128, 512])
    X = identity_block(X, 3, [128, 128, 512])

    # Stage 4: 1 convolutional block + 5 identity blocks
    X = convolutional_block(X, f = 3, filters = [256, 256, 1024], s = 2)
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])
    X = identity_block(X, 3, [256, 256, 1024])

    # Stage 5: 1 convolutional block + 2 identity blocks
    X = convolutional_block(X, f = 3, filters = [512, 512, 2048], s = 2)
    X = identity_block(X, 3, [512, 512, 2048])
    X = identity_block(X, 3, [512, 512, 2048])

    X = AveragePooling2D(pool_size=(2,2), padding="same")(X)
    

    X = Flatten()(X)
    X = Dense(classes, activation='softmax', kernel_initializer = glorot_uniform(seed=0))(X)    
    
    model = Model(inputs = X_input, outputs = X)

    return model

m_model = ResNet50(input_shape = (32, 32, 3), classes = 10)

m_model.compile(loss='categorical_crossentropy', optimizer="adam", metrics=['accuracy'])

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

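# Cast to float32 before scaling; dividing the uint8 arrays directly would create float64 copies.
x_train = x_train.astype("float32") / 255.
x_test = x_test.astype("float32") / 255.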

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
print(f"number of training examples = {x_train.shape[0]}")
print(f"number of test examples = {x_test.shape[0]}")
print(f"X_train shape: {x_train.shape}")
print(f"Y_train shape: {y_train.shape}")
print(f"X_test shape: {x_test.shape}")
print(f"Y_test shape: {y_test.shape}")

history = m_model.fit(x_train, y_train, epochs=20, batch_size=512, validation_data=(x_test, y_test), verbose=1, shuffle=True)

m_model.evaluate(x_test,y_test,batch_size=256, verbose=1)
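
To see what the layers above do to a 32x32 CIFAR-10 image, the following pure-Python sketch traces the spatial size through the network. It assumes the standard Keras output-size rules for "valid" and "same" padding; the helper names `conv_out` and `trace_resnet50` are made up for this illustration.

```python
import math


def conv_out(size, kernel, stride, padding):
    """Output length along one spatial axis for a Keras-style conv or pooling layer."""
    if padding == "same":
        return math.ceil(size / stride)
    return (size - kernel) // stride + 1  # "valid" padding


def trace_resnet50(size):
    """Return (layer name, spatial size) pairs for a square input of the given side."""
    trace = [("input", size)]
    size += 2 * 3  # ZeroPadding2D((3, 3)) adds 3 pixels on each side
    trace.append(("zero_padding", size))
    size = conv_out(size, 7, 2, "valid")  # 7x7 conv, stride 2
    trace.append(("conv1", size))
    size = conv_out(size, 3, 2, "valid")  # 3x3 max pool, stride 2
    trace.append(("max_pool", size))
    for stage, stride in (("stage2", 1), ("stage3", 2), ("stage4", 2), ("stage5", 2)):
        # Only the first 1x1 conv of each convolutional_block changes the size.
        size = conv_out(size, 1, stride, "valid")
        trace.append((stage, size))
    size = conv_out(size, 2, 2, "same")  # AveragePooling2D: stride defaults to pool size
    trace.append(("avg_pool", size))
    return trace


for name, side in trace_resnet50(32):
    print(f"{name:>12}: {side} x {side}")
```

Under these assumptions the feature map is already 7x7 after the stem and shrinks to 1x1 by stage 5. The stem was designed for 224x224 ImageNet inputs, so for CIFAR-10 it discards most of the spatial detail early on.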
