-
Notifications
You must be signed in to change notification settings - Fork 32
/
create_pascal_tf_record.py
executable file
·160 lines (120 loc) · 5.94 KB
/
create_pascal_tf_record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Converts PASCAL dataset to TFRecords file format."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import argparse
import io
import os
import sys
import PIL.Image
import tensorflow as tf
from utils import dataset_util
# Command-line interface. Every flag is a string path/name with a default,
# so the specs are kept as data and registered in one loop.
parser = argparse.ArgumentParser()

_FLAG_SPECS = (
    ('--data_dir', './dataset/',
     'Path to the directory containing the PASCAL VOC data.'),
    ('--output_path', './dataset',
     'Path to the directory to create TFRecords outputs.'),
    ('--train_data_list', './dataset/train.txt',
     'Path to the file listing the training data.'),
    ('--valid_data_list', './dataset/val.txt',
     'Path to the file listing the validation data.'),
    ('--image_data_dir', 'land_train',
     'The directory containing the image data.'),
    ('--label_data_dir', 'onechannel_label',
     'The directory containing the augmented label data.'),
)
for _flag, _default, _help in _FLAG_SPECS:
    parser.add_argument(_flag, type=str, default=_default, help=_help)
def dict_to_tf_example(image_path,
                       label_path):
    """Convert one image/label pair on disk into a tf.train.Example proto.

    Args:
        image_path: Path to a single PASCAL image (must decode as JPEG).
        label_path: Path to its corresponding label (must decode as PNG).

    Returns:
        The populated tf.train.Example.

    Raises:
        ValueError: if the image is not a valid JPEG, if the label is not a
            valid PNG, or if the image and label sizes differ.
    """
    def _read_and_check(path, expected_format, error_message):
        # Read the raw encoded bytes (they go into the proto verbatim) and
        # decode them once with PIL only to validate format and size.
        with tf.gfile.GFile(path, 'rb') as f:
            raw = f.read()
        decoded = PIL.Image.open(io.BytesIO(raw))
        if decoded.format != expected_format:
            raise ValueError(error_message)
        return raw, decoded

    encoded_jpg, image = _read_and_check(image_path, 'JPEG',
                                         'Image format not JPEG')
    encoded_label, label = _read_and_check(label_path, 'PNG',
                                           'Label format not PNG')

    if image.size != label.size:
        raise ValueError('The size of image does not match with that of label.')
    width, height = image.size

    feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpg'.encode('utf8')),
        'label/encoded': dataset_util.bytes_feature(encoded_label),
        'label/format': dataset_util.bytes_feature('png'.encode('utf8')),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def create_tf_record(output_filename,
                     image_dir,
                     label_dir,
                     examples):
    """Creates a TFRecord file from examples.

    Args:
        output_filename: Path to where output file is saved.
        image_dir: Directory where image files are stored; each example name
            is expanded to '<name>_sat.jpg' there.
        label_dir: Directory where label files are stored; each example name
            is expanded to '<name>_label.png' there.
        examples: Sequence of example basenames to convert. Examples whose
            image or label file is missing, or whose contents fail validation
            in dict_to_tf_example, are logged and skipped.
    """
    num_examples = len(examples)  # hoisted: re-used on every 50th log line
    # TFRecordWriter supports the context-manager protocol; the original code
    # leaked the writer if an exception escaped the loop before close().
    with tf.python_io.TFRecordWriter(output_filename) as writer:
        for idx, example in enumerate(examples):
            if idx % 50 == 0:
                tf.logging.info('On image %d of %d', idx, num_examples)
            image_path = os.path.join(image_dir, example + '_sat.jpg')
            label_path = os.path.join(label_dir, example + '_label.png')
            if not os.path.exists(image_path):
                tf.logging.warning('Could not find %s, ignoring example.', image_path)
                continue
            if not os.path.exists(label_path):
                tf.logging.warning('Could not find %s, ignoring example.', label_path)
                continue
            # Keep the try body minimal: only the conversion can raise the
            # ValueError we deliberately tolerate per-example.
            try:
                tf_example = dict_to_tf_example(image_path, label_path)
            except ValueError:
                tf.logging.warning('Invalid example: %s, ignoring.', example)
            else:
                writer.write(tf_example.SerializeToString())
def main(unused_argv):
    """Derive a 90/10 train/val split from the label files and write both
    TFRecord files under FLAGS.output_path."""
    if not os.path.exists(FLAGS.output_path):
        os.makedirs(FLAGS.output_path)

    tf.logging.info("Reading from deepglobe dataset")
    image_dir = os.path.join(FLAGS.data_dir, FLAGS.image_data_dir)
    label_dir = os.path.join(FLAGS.data_dir, FLAGS.label_data_dir)

    if not os.path.isdir(label_dir):
        raise ValueError("Missing Augmentation label directory. "
                         "You may download the augmented labels from the link (Thanks to DrSleep): "
                         "https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip")

    # If you keep explicit train.txt / val.txt split lists, read them instead:
    #   train_examples = dataset_util.read_examples_list(FLAGS.train_data_list)
    #   val_examples = dataset_util.read_examples_list(FLAGS.valid_data_list)
    # Otherwise derive example basenames from the label files on disk
    # ('<name>_label.png' -> '<name>').
    # (The original used a no-op triple-quoted string as this comment and a
    # loop variable named `file`, shadowing the builtin.)
    label_files = os.listdir(label_dir)
    file_names = np.array([fname.split('_')[0] for fname in label_files
                           if fname.endswith('.png')], dtype=object)

    # NOTE(review): np.random.choice is unseeded, so the train/val partition
    # differs between runs; seed NumPy first if reproducibility matters.
    val_idx = np.random.choice(len(file_names), int(len(file_names) * 0.1),
                               replace=False)
    val_examples = file_names[val_idx]
    train_examples = np.delete(file_names, val_idx)

    train_output_path = os.path.join(FLAGS.output_path, 'voc_train.record')
    val_output_path = os.path.join(FLAGS.output_path, 'voc_val.record')

    create_tf_record(train_output_path, image_dir, label_dir, train_examples)
    create_tf_record(val_output_path, image_dir, label_dir, val_examples)
if __name__ == '__main__':
    # Emit INFO-level progress/warning logs from create_tf_record.
    tf.logging.set_verbosity(tf.logging.INFO)
    # Bind the module-level FLAGS that main/create_tf_record read; unknown
    # arguments are forwarded to tf.app.run untouched.
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)