개요
Inception-Resnet-v2 모델을 사용하여 이미지를 재학습 후 추론해본다.
Inception-Resnet-v2 image retrain classification
준비 : TF-Slim
https://github.com/tensorflow/models/tree/master/slim
tensorflow slim, inception 소스코드를 받아 놓는다
git clone https://github.com/tensorflow/models.git |
준비 : 작업 공간 만들기
적당히 작업할 공간을 만들어 놓는다.
mkdir tfrecord
mkdir train
tfrecord : 변환된 데이터세트 저장
train : 학습 후 출력되는 체크포인트 파일 저장
준비 : 이미지 데이터세트 변환 tfrecord
준비된 이미지들을 tfrecord로 변환 하는 방법은 아래 주소를 참고 한다.
http://gusrb.tistory.com/12
간단하게 준비된 Flowers 데이터세트를 사용할 수도 있다.
준비 : Flowers 데이터세트
python models/slim/download_and_convert_data.py \ --dataset_name=flowers \ --dataset_dir=/work/flowers/tfrecord
|
--dataset_name : 데이터세트 이름. flowers 라고 적어준다
--dataset_dir : 데이터세트를 저장할 경로
위에서 받은 slim 소스코드 내에 download_and_convert_data.py 스크립트를 위와 같이 실행하면
사용하기 편한 준비된 tfrecord 데이터 세트를 다운로드 한다
이번 포스팅에서는 Flowers 데이터세트를 사용 한다.
학습 : inception-resnet-v2 fine-tune retrain
inception-resnet-v2 를 사용한 학습은 아래 주소를 참고하면 된다.
http://gusrb.tistory.com/20
위 주소의 포스팅에서 '학습 : 미리 훈련된 모델에 플라워데이터를 미세조정 학습 시키기' 까지만 수행하면 된다.
오류 없이 잘 수행 했다면 train 폴더에 ~~~.ckpt 체크포인트 파일이 생성 된다.
이 파일들을 사용 한다.
변환 : inception-resnet-v2 ckpt to pb
ckpt 체크포인트 파일을 pb파일로 덤프 한다.
덤프를 하여 pb파일을 사용하게 되면 기존의 추론 스크립트를 재활용 할 수 있다.
models/slim/ 아래에 ckpt_pb.py 스크립트를 작성 한다
이번 변환 스크립트는 꼭 slim 소스 폴더에서 실행해야 한다.
아래는 스크립트 전문이다.
색으로 표시한 부분은 아래를 참고 한다.
ckpt_pb.py
import tensorflow as tf from tensorflow.contrib import slim
from nets import inception from tensorflow.python.framework.graph_util import convert_variables_to_constants from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference from preprocessing import inception_preprocessing
checkpoints_dir = '/home/hwang/work/flowers/train' OUTPUT_PB_FILENAME = 'minimal_graph.proto' NUM_CLASSES = 5
# We need default size of image for a particular network. # The network was trained on images of that size -- so we # resize input image later in the code. image_size = inception.inception_resnet_v2.default_image_size
with tf.Graph().as_default(): # Inject placeholder into the graph input_image_t = tf.placeholder(tf.string, name='input_image') image = tf.image.decode_jpeg(input_image_t, channels=3)
# Resize the input image, preserving the aspect ratio # and make a central crop of the resulted image. # The crop will be of the size of the default image size of # the network. # I use the "preprocess_for_eval()" method instead of "inception_preprocessing()" # because the latter crops all images to the center by 85% at # prediction time (training=False). processed_image = inception_preprocessing.preprocess_for_eval(image, image_size, image_size, central_fraction=None)
# Networks accept images in batches. # The first dimension usually represents the batch size. # In our case the batch size is one. processed_images = tf.expand_dims(processed_image, 0)
# Load the inception network structure with slim.arg_scope(inception.inception_resnet_v2_arg_scope()): logits, _ = inception.inception_resnet_v2(processed_images, num_classes=NUM_CLASSES, is_training=False) # Apply softmax function to the logits (output of the last layer of the network) probabilities = tf.nn.softmax(logits)
model_path = tf.train.latest_checkpoint(checkpoints_dir)
# Get the function that initializes the network structure (its variables) with # the trained values contained in the checkpoint init_fn = slim.assign_from_checkpoint_fn( model_path, slim.get_model_variables())
with tf.Session() as sess: # Now call the initialization function within the session init_fn(sess)
# Convert variables to constants and make sure the placeholder input_image is included # in the graph as well as the other neccesary tensors. constant_graph = convert_variables_to_constants(sess, sess.graph_def, ["input_image", "DecodeJpeg", "InceptionResnetV2/Logits/Predictions"])
# Define the input and output layer properly optimized_constant_graph = optimize_for_inference(constant_graph, ["input_image"], ["InceptionResnetV2/Logits/Predictions"], tf.string.as_datatype_enum) # Write the production ready graph to file. tf.train.write_graph(optimized_constant_graph, '.', OUTPUT_PB_FILENAME, as_text=False) |
checkpoints_dir = '위에서/만든/체크포인트 파일이 있는 경로'
OUTPUT_PB_FILENAME = '경로/출력할 덤프 파일 이름.proto'
NUM_CLASSES = 5 학습 시킨 label 갯수. class의 갯수.
input_image : 수정해도 좋지만 문자열은 통일 시켜 주어야 한다. 수정 안하는걸 추천.
$ python ckpt_pb.txt
완료 되면 위와 같은 화면이 출력 되고 아래와 같은 파일이 출력 된다.
체크포인트 파일을 덤프한 pb 파일이다.
* 출저 : https://stackoverflow.com/questions/39902596/tensorflow-inception-resnet-v2-classify-image
추론 : inception-resnet-v2 inferenc
변환한 덤프 파일, 위에서 만든 label.txt 파일을 가지고 추론을 한다.
아래에 올려진 추론 스크립트는 지난 inception-v3 추론 스크립트와 같다.
다만, 그 안에서 필요한 텐서의 이름만 수정해서 사용 한다.
label_image.py
import tensorflow as tf import sys
# change this as you see fit image_path = sys.argv[1]
# Read in the image_data image_data = tf.gfile.FastGFile(image_path, 'rb').read()
# Loads label file, strips off carriage return label_lines = [line.rstrip() for line in tf.gfile.GFile("labels.txt")]
# Unpersists graph from file with tf.gfile.FastGFile("minimal_graph.proto", 'rb') as f: #with tf.device('/gpu:0'): graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name='')
#config=tf.ConfigProto(log_device_placement=True)
with tf.Session() as sess: # Feed the image_data as input to the graph and get first prediction softmax_tensor = sess.graph.get_tensor_by_name('InceptionResnetV2/Logits/Predictions:0') predictions = sess.run(softmax_tensor, \ {'input_image:0': image_data}) # Sort to show labels of first prediction in order of confidence top_k = predictions[0].argsort()[-len(predictions[0]):][::-1] print(top_k)
cnt = 0 for node_id in top_k: human_string = label_lines[node_id] score = predictions[0][node_id] print'' print('%s (score = %.5f)' % (human_string, score)) cnt = cnt + 1 # if(cnt == 3): # break print'' |
minimal_graph.proto : 위에서 변환한 덤프 파일 경로/이름.proto
labels.txt : 위에서 데이터세트 만들때 작성한 경로/labels.txt
InceptionResnetV2/Logits/Prediction : 마지막 추론에 사용할 텐서 이름.
input_image : 이미지 디코딩 텐서 이름. 변환 스크립트 작성할 때 수정을 안했다면 그대로 작성.
$python label_image.py test.jpg
추론에 성공한 화면이다. 스코어가 낮은 이유는 테스트를 위해 학습을 아주 단시간 진행 했기 때문이다.
부록 : 이미지 추론 서버 만들기
http://gusrb.tistory.com/47