728x90

 대학원 생활에 제가 직접 사용하였던 EfficientNet의 학습코드를 설명하기에 앞서 간단하게 설명을 드리면 일반적으로 이미지 분류 분야의 인공지능 모델의 정확도를 높일 때 사용되는 조건은 아래의 사진 처럼 모델의 깊이, 너비, 입력 이미지라는 값을 변화시키는데 EfficientNet의 경우 이 3가지의 조건을 효율적으로 조절할 수 있는 compound scaling 방법을 제안해서 모델의 정확도를 높입니다.

모델의 정확도를 높이는 조건

 

 또 EfficientNet의 경우 모델이 사용하는 자원이 제한된 상태에서 모델의 정확도를 최대화하는 문제를 해결하고자 하는 목적성을 가지고 있습니다. 자원이 제한된 상황은 다음과 같은 수식으로 표현됩니다.

EfficientNet의 경우 MnasNet에 기반한 baseline network를 사용하고 있으며 아래의 사진은 EfficientNet의 아키텍처입니다. 모델이 b0~b7까지로 구성이 되어있고 b0에서 b7으로 점점 높은 모델을 사용할 때 마다 모델의 파라미터 개수가 늘어나고 모델의 크기가 상승하게 됩니다.

EfficientNet 아키텍처
 

 논문 원작자가 모델 검증을 할 때 사용했던 데이터 세트는 ImageNet을 사용하였으며 여러가지 CNN(합성곱 신경망) 모델들과 비교했을 때 매개 변수의 크기와 컴퓨터의 성능을 수치로 나타내는 단위인 FLOPS를 다른 모델의 비해 10배정도까지 줄였음에도 불구하고 다른 모델보다 정확도 면에서 우수하게 나온 모델이다. 자세한 사항은 아래의 논문 링크를 첨부하였으니 읽어보시는 것을 추천드립니다.

EfficientNet acc
 

논문 링크 : https://arxiv.org/abs/1905.11946

 

EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks

Convolutional Neural Networks (ConvNets) are commonly developed at a fixed resource budget, and then scaled up for better accuracy if more resources are available. In this paper, we systematically study model scaling and identify that carefully balancing n

arxiv.org

 

이제 본격적으로 코딩을 해보자면 먼저 학습 코드를 진행하기전에 데이터를 충분히 확보하셨다면 아래의 labeling 코드를 이용해서 데이터 세트를 구성하시면 됩니다.

 

먼저 필요한 라이브러리들을 import 시켜줍니다.

import os, shutil
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns

 

오류가 없이 정상적으로 라이브러리들이 import 되셨으면 사진 파일들이 있는 root 경로를 설정해줍니다. 또한 class 개수 만큼 path들을 만들어 줍니다.

저의 경우 대학원 과정에서 진행했던 폭발물 데이터를 이용했습니다.

root = "/home/user/effcientNet/explosive/RP_explosive_test"
A_path = [] #NG_10
B_path = [] #NG_50
C_path = [] #NG_100
D_path = [] #NG_200
E_path = [] #Normal
F_path = [] #PETN_10
G_path = [] #PETN_50
H_path = [] #PETN_100
I_path = [] #PETN_200
J_path = [] #RDX_10
K_path = [] #RDX_50
L_path = [] #RDX_100
M_path = [] #RDX_200
N_path = [] #TNT_10
O_path = [] #TNT_50
P_path = [] #TNT_100
Q_path = [] #TNT_200

이제 path마다 class의 파일들을 넣는 과정을 진행해봅시다.

def fast_scandir(dirname):
    subfolders= [f.path for f in os.scandir(dirname) if f.is_dir()]
    for dirname in list(subfolders):
        subfolders.extend(fast_scandir(dirname))
    return subfolders

bombTypes = ["NG_10ng","NG_50ng","NG_100ng","NG_200ng","Normal","PETN_10ng","PETN_50ng","PETN_100ng","PETN_200ng","RDX_10ng","RDX_50ng","RDX_100ng","RDX_200ng","TNT_10ng","TNT_50ng","TNT_100ng","TNT_200ng"]
bombPaths = [A_path,B_path,C_path,D_path,E_path,F_path,G_path,H_path,I_path,J_path,K_path,L_path,M_path,N_path,O_path,P_path,Q_path]

#여기서 bomTypes는 label이름을 bombPaths는 class마다 지정한 path를 지정해주시면 됩니다.

total_path = [s for s in fast_scandir(root) if any(xs for xs in bombTypes)]

for types, pathlist in zip(bombTypes,bombPaths):
    for path in total_path:
        if (types in path):
            imagePath = glob(path+"/*.png")
            for i in imagePath:
                if("textured" not in i):
                    pathlist.append(i)
                    
# 이 과정이 끝나면 각 path에 파일들의 경로들이 배열에 저장될겁니다.

 

이제 데이터가 배열에 잘저장되었는지 matplotlib 라이브러리를 이용해서 그래프로 띄워봅시다.

index = ["A","B","C","D","E","F","G","H","I","J","K","L","M","N","O","P","Q"]
bardata = [len(i) for i in bombPaths ]

plt.title("Bomb total count", fontsize = 20)
plt.bar(index,bardata, color=['tab:blue','tab:orange','tab:green','tab:red',"tab:purple"])
for i, v in enumerate(index):
    plt.text(v, bardata[i], bardata[i], fontsize = 7, color='black',horizontalalignment='center',verticalalignment='bottom')
plt.show()

정상적으로 배열에 labeling이 된 것 같으니 이제 EfficientNet 학습에 사용될 class마다 label이 지정된 csv 파일 생성합시다.

 

sumdf = []
for num,pecies in enumerate(bombPaths):
    pecies = list(pecies)
    if len(pecies) > 10000:
        pecies = pecies
    label = np.empty_like(pecies)
    label.fill(num)
    a = np.stack([pecies,label],axis = 1)
    df = np.array(a)
    sumdf.append(df)

df = np.vstack(sumdf)
train_df = pd.DataFrame(df,columns = ["file_name","label"])
print(train_df)

#csv 폴더의 경로는 본인의 경로에 맞게 수정해 주셔야됩니다.
train_df.to_csv('/home/user/effcientNet/explosive/RP_explosive_test/test_df.csv',index=False)

정상적으로 잘 들어간것을 확인 할 수 있습니다.

위의 예시는 제가 R&D 과제를 하면서 진행했던 코드입니다. 본인의 경로에 꼭 맞추셔서 진행해주셔야 합니다.

다음 코드는 학습시키는 코드로 찾아뵙겠습니다.

728x90

+ Recent posts