본문 바로가기
archive.tar

[TensorFlow] Inception-Resnet-v2 원하는 이미지 학습과 추론 해보기

by 냉동만두 2017. 8. 29.

개요


Inception-Resnet-v2 모델을 사용하여 이미지를 재학습 후 추론해본다.

Inception-Resnet-v2 image retrain classification


준비 : TF-Slim


https://github.com/tensorflow/models/tree/master/slim


tensorflow slim, inception 소스코드를 받아 놓는다

git clone https://github.com/tensorflow/models.git


준비 : 작업 공간 만들기


적당히 작업할 공간을 만들어 놓는다.

mkdir tfrecord

mkdir train



tfrecord : 변환된 데이터세트 저장

train : 학습 후 출력되는 체크포인트 파일 저장


준비 : 이미지 데이터세트 변환 tfrecord


준비된 이미지들을 tfrecord로 변환 하는 방법은 아래 주소를 참고 한다.

http://gusrb.tistory.com/12

간단하게 준비된 Flowers 데이터세트를 사용할 수도 있다.


준비 : Flowers 데이터세트


python models/slim/download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=/work/flowers/tfrecord


--dataset_name : 데이터세트 이름. flowers 라고 적어준다

--dataset_dir : 데이터세트를 저장할 경로


위에서 받은 slim 소스코드 내에 download_and_convert_data.py 스크립트를 위와 같이 실행하면

사용하기 편한 준비된 tfrecord 데이터 세트를 다운로드 한다

이번 포스팅에서는 Flowers 데이터세트를 사용 한다.


학습 : inception-resnet-v2 fine-tune retrain


inception-resnet-v2 를 사용한 학습은 아래 주소를 참고하면 된다.

http://gusrb.tistory.com/20

위 주소의 포스팅에서 '학습 : 미리 훈련된 모델에 플라워데이터를 미세조정 학습 시키기' 까지만 수행하면 된다.

오류 없이 잘 수행 했다면 train 폴더에 ~~~.ckpt 체크포인트 파일이 생성 된다.

이 파일들을 사용 한다.



변환 : inception-resnet-v2 ckpt to pb


ckpt 체크포인트 파일을 pb파일로 덤프 한다.

덤프를 하여 pb파일을 사용하게 되면 기존의 추론 스크립트를 재활용 할 수 있다.


models/slim/ 아래에 ckpt_pb.py 스크립트를 작성 한다

이번 변환 스크립트는 꼭 slim 소스 폴더에서 실행해야 한다.


아래는 스크립트 전문이다.

색으로 표시한 부분은 아래를 참고 한다.


ckpt_pb.py

import tensorflow as tf
from tensorflow.contrib import slim

from nets import inception
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from preprocessing import inception_preprocessing

checkpoints_dir = '/home/hwang/work/flowers/train'
OUTPUT_PB_FILENAME = 'minimal_graph.proto'
NUM_CLASSES = 5

# We need default size of image for a particular network.
# The network was trained on images of that size -- so we
# resize input image later in the code.
image_size = inception.inception_resnet_v2.default_image_size

with tf.Graph().as_default():
    # Inject placeholder into the graph
    input_image_t = tf.placeholder(tf.string, name='input_image')
    image = tf.image.decode_jpeg(input_image_t, channels=3)

    # Resize the input image, preserving the aspect ratio
    # and make a central crop of the resulted image.
    # The crop will be of the size of the default image size of
    # the network.
    # I use the "preprocess_for_eval()" method instead of "inception_preprocessing()"
    # because the latter crops all images to the center by 85% at
    # prediction time (training=False).
    processed_image = inception_preprocessing.preprocess_for_eval(image,
                                                                  image_size,
                                                                  image_size, central_fraction=None)

    # Networks accept images in batches.
    # The first dimension usually represents the batch size.
    # In our case the batch size is one.
    processed_images = tf.expand_dims(processed_image, 0)

    # Load the inception network structure
    with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
        logits, _ = inception.inception_resnet_v2(processed_images,
                                                  num_classes=NUM_CLASSES,
                                                  is_training=False)
    # Apply softmax function to the logits (output of the last layer of the network)
    probabilities = tf.nn.softmax(logits)

    model_path = tf.train.latest_checkpoint(checkpoints_dir)

    # Get the function that initializes the network structure (its variables) with
    # the trained values contained in the checkpoint
    init_fn = slim.assign_from_checkpoint_fn(
        model_path,
        slim.get_model_variables())

    with tf.Session() as sess:
        # Now call the initialization function within the session
        init_fn(sess)

        # Convert variables to constants and make sure the placeholder input_image is included
        # in the graph as well as the other neccesary tensors.
        constant_graph = convert_variables_to_constants(sess, sess.graph_def, ["input_image", "DecodeJpeg",
                                                                               "InceptionResnetV2/Logits/Predictions"])

        # Define the input and output layer properly
        optimized_constant_graph = optimize_for_inference(constant_graph, ["input_image"],
                                                          ["InceptionResnetV2/Logits/Predictions"],
                                                          tf.string.as_datatype_enum)
        # Write the production ready graph to file.
        tf.train.write_graph(optimized_constant_graph, '.', OUTPUT_PB_FILENAME, as_text=False)


checkpoints_dir = '위에서/만든/체크포인트 파일이 있는 경로'
OUTPUT_PB_FILENAME = '경로/출력할 덤프 파일 이름.proto'
NUM_CLASSES = 5 학습 시킨 label 갯수. class의 갯수.


input_image : 수정해도 좋지만 문자열은 통일 시켜 주어야 한다. 수정 안하는걸 추천.


$ python ckpt_pb.txt

완료 되면 위와 같은 화면이 출력 되고 아래와 같은 파일이 출력 된다.

체크포인트 파일을 덤프한 pb 파일이다.


* 출저 : https://stackoverflow.com/questions/39902596/tensorflow-inception-resnet-v2-classify-image


추론 : inception-resnet-v2 inferenc


변환한 덤프 파일, 위에서 만든 label.txt 파일을 가지고 추론을 한다.

아래에 올려진 추론 스크립트는 지난 inception-v3 추론 스크립트와 같다.

다만, 그 안에서 필요한 텐서의 이름만 수정해서 사용 한다.


label_image.py

import tensorflow as tf
import sys


# change this as you see fit
image_path = sys.argv[1]

# Read in the image_data
image_data = tf.gfile.FastGFile(image_path, 'rb').read()

# Loads label file, strips off carriage return
label_lines = [line.rstrip() for line
                   in tf.gfile.GFile("labels.txt")]

# Unpersists graph from file
with tf.gfile.FastGFile("minimal_graph.proto", 'rb') as f:
    #with tf.device('/gpu:0'):
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')

#config=tf.ConfigProto(log_device_placement=True)

with tf.Session() as sess:
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('InceptionResnetV2/Logits/Predictions:0')
    predictions = sess.run(softmax_tensor, \
                           {'input_image:0': image_data})
    # Sort to show labels of first prediction in order of confidence
    top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]
    print(top_k)

    cnt = 0
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        print''
        print('%s (score = %.5f)' % (human_string, score))
        cnt = cnt + 1
        # if(cnt == 3):
        #        break
print''


minimal_graph.proto : 위에서 변환한 덤프 파일 경로/이름.proto
labels.txt : 위에서 데이터세트 만들때 작성한 경로/labels.txt


InceptionResnetV2/Logits/Prediction : 마지막 추론에 사용할 텐서 이름.

input_image : 이미지 디코딩 텐서 이름. 변환 스크립트 작성할 때 수정을 안했다면 그대로 작성.


$python label_image.py test.jpg


추론에 성공한 화면이다. 스코어가 낮은 이유는 테스트를 위해 학습을 아주 단시간 진행 했기 때문이다.


부록 : 이미지 추론 서버 만들기


http://gusrb.tistory.com/47