
Commit

[Add] add online data augmentation
jeasinema committed Jan 19, 2018
1 parent 851e75a commit c928317
Showing 4 changed files with 50 additions and 32 deletions.
10 changes: 5 additions & 5 deletions train.py
@@ -4,7 +4,7 @@
 # File Name : train.py
 # Purpose :
 # Creation Date : 09-12-2017
-# Last Modified : Sun 31 Dec 2017 06:13:27 PM CST
+# Last Modified : Fri 19 Jan 2018 10:38:47 AM CST
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 import glob
@@ -43,9 +43,9 @@ def main(_):
     with tf.Graph().as_default():
         global save_model_dir
         with KittiLoader(object_dir=os.path.join(dataset_dir, 'training'), queue_size=50, require_shuffle=True,
-                         is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT) as train_loader, \
+                         is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT, aug=True) as train_loader, \
             KittiLoader(object_dir=os.path.join(dataset_dir, 'testing'), queue_size=50, require_shuffle=True,
-                        is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT) as valid_loader:
+                        is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT, aug=False) as valid_loader:
 
             gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION,
                                         visible_device_list=cfg.GPU_AVAILABLE,
@@ -115,12 +115,12 @@ def main(_):
 
                 if is_summary_image:
                     ret = model.predict_step(
-                        sess, valid_loader.load(), summary=True)
+                        sess, valid_loader.load(), summary=True)
                     summary_writer.add_summary(ret[-1], iter)
 
                 if is_validate:
                     ret = model.validate_step(
-                        sess, valid_loader.load(), summary=True)
+                        sess, valid_loader.load(), summary=True)
                     summary_writer.add_summary(ret[-1], iter)
 
                 if check_if_should_pause(args.tag):

3 changes: 2 additions & 1 deletion utils/__init__.py
@@ -4,11 +4,12 @@
 # File Name : __init__.py
 # Purpose :
 # Creation Date : 21-12-2017
-# Last Modified : Thu 21 Dec 2017 11:16:46 PM CST
+# Last Modified : Fri 19 Jan 2018 10:15:06 AM CST
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 from utils.box_overlaps import *
 from utils.colorize import *
 from utils.kitti_loader import *
 from utils.utils import *
 from utils.preprocess import *
+from utils.data_aug import *
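
With utils/__init__.py now re-exporting utils.data_aug, the augmentation entry point is reachable from the package itself (assuming the package's other modules import cleanly); a one-line sketch:

from utils import aug_data  # re-exported via 'from utils.data_aug import *' in utils/__init__.py
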
35 changes: 21 additions & 14 deletions data_aug.py → utils/data_aug.py
@@ -4,7 +4,7 @@
 # File Name : data_aug.py
 # Purpose :
 # Creation Date : 21-12-2017
-# Last Modified : Mon 01 Jan 2018 09:26:32 PM CST
+# Last Modified : Fri 19 Jan 2018 10:36:19 AM CST
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 import numpy as np
@@ -14,25 +14,20 @@
 import argparse
 import glob
 
-from utils import *
+from utils.utils import *
+from utils.preprocess import *
 
 object_dir = './data/object'
-output_path = os.path.join(object_dir, 'training_aug')
-
-parser = argparse.ArgumentParser(description='')
-parser.add_argument('-i', '--aug-amount', type=int, nargs='?', default=1000)
-parser.add_argument('-n', '--num-workers', type=int, nargs='?', default=10)
-args = parser.parse_args()
 
 
-def worker(tag):
+def aug_data(tag, object_dir):
     np.random.seed()
-    rgb = cv2.resize(cv2.imread(os.path.join(object_dir, 'training',
+    rgb = cv2.resize(cv2.imread(os.path.join(object_dir,
         'image_2', tag + '.png')), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT))
-    lidar = np.fromfile(os.path.join(object_dir, 'training',
+    lidar = np.fromfile(os.path.join(object_dir,
         'velodyne', tag + '.bin'), dtype=np.float32).reshape(-1, 4)
     label = np.array([line for line in open(os.path.join(
-        object_dir, 'training', 'label_2', tag + '.txt'), 'r').readlines()])  # (N')
+        object_dir, 'label_2', tag + '.txt'), 'r').readlines()])  # (N')
     cls = np.array([line.split()[0] for line in label])  # (N')
     gt_box3d = label_to_gt_box3d(np.array(label)[np.newaxis, :], cls='', coordinate='camera')[
         0]  # (N', 7) x, y, z, h, w, l, r
@@ -88,7 +83,7 @@ def worker(tag):
 
         gt_box3d = lidar_to_camera_box(lidar_center_gt_box3d)
         newtag = 'aug_{}_1_{}'.format(
-            tag, np.random.randint(1, args.aug_amount))
+            tag, np.random.randint(1, 1024))
     elif choice < 7 and choice >= 4:
         # global rotation
         angle = np.random.uniform(-np.pi / 4, np.pi / 4)
@@ -107,10 +102,17 @@ def worker(tag):
         newtag = 'aug_{}_3_{:.4f}'.format(tag, factor).replace('.', '_')
 
     label = box3d_to_label(gt_box3d[np.newaxis, ...], cls[np.newaxis, ...], coordinate='camera')[0]  # (N')
+    voxel_dict = process_pointcloud(lidar)
+    return newtag, rgb, lidar, voxel_dict, label
+
+
+def worker(tag):
+    newtag, rgb, lidar, voxel_dict, label = aug_data(tag, object_dir)
+    output_path = os.path.join(object_dir, 'training_aug')
+
     cv2.imwrite(os.path.join(output_path, 'image_2', newtag + '.png'), rgb)
     lidar.reshape(-1).tofile(os.path.join(output_path,
                                           'velodyne', newtag + '.bin'))
-    voxel_dict = process_pointcloud(lidar)
     np.savez_compressed(os.path.join(
         output_path, 'voxel' if cfg.DETECT_OBJ == 'Car' else 'voxel_ped', newtag), **voxel_dict)
     with open(os.path.join(output_path, 'label_2', newtag + '.txt'), 'w+') as f:
@@ -131,4 +133,9 @@ def main():
 
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('-i', '--aug-amount', type=int, nargs='?', default=1000)
+    parser.add_argument('-n', '--num-workers', type=int, nargs='?', default=10)
+    args = parser.parse_args()
+
     main()
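
After this move, aug_data returns one augmented sample in memory instead of writing it to disk, which is what lets the loader call it online. A minimal sketch of driving it directly; the tag '000010' and the directory below are illustrative assumptions, not values from this commit:

from utils.data_aug import aug_data

# Hypothetical one-off call against a standard KITTI training layout.
newtag, rgb, lidar, voxel_dict, label = aug_data('000010', './data/object/training')

print(newtag)                     # e.g. an 'aug_000010_*' tag encoding which augmentation ran
print(rgb.shape)                  # camera image resized to (cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, 3)
print(lidar.shape)                # (N, 4) points: x, y, z, reflectance
print(sorted(voxel_dict.keys()))  # voxel features produced by process_pointcloud
print(len(label))                 # number of KITTI-format label strings for the augmented boxes
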
34 changes: 22 additions & 12 deletions utils/kitti_loader.py
@@ -4,7 +4,7 @@
 # File Name : kitti_loader.py
 # Purpose :
 # Creation Date : 09-12-2017
-# Last Modified : Fri 05 Jan 2018 09:32:43 PM CST
+# Last Modified : Fri 19 Jan 2018 10:34:56 AM CST
 # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com]
 
 import cv2
@@ -20,6 +20,7 @@
 from multiprocessing import Lock, Process, Queue as Queue, Value, Array, cpu_count
 
 from config import cfg
+from utils.data_aug import aug_data
 
 # for non-raw dataset

@@ -35,7 +36,7 @@ class KittiLoader(object):
     # vox_number
     # vox_coordinate
 
-    def __init__(self, object_dir='.', queue_size=20, require_shuffle=False, is_testset=True, batch_size=1, use_multi_process_num=0, split_file='', multi_gpu_sum=1):
+    def __init__(self, object_dir='.', queue_size=20, require_shuffle=False, is_testset=True, batch_size=1, use_multi_process_num=0, split_file='', multi_gpu_sum=1, aug=False):
         assert(use_multi_process_num >= 0)
         self.object_dir = object_dir
         self.is_testset = is_testset
Expand All @@ -44,6 +45,7 @@ def __init__(self, object_dir='.', queue_size=20, require_shuffle=False, is_test
self.batch_size = batch_size
self.split_file = split_file
self.multi_gpu_sum = multi_gpu_sum
self.aug = aug

if self.split_file != '':
# use split file
@@ -129,17 +131,25 @@ def fill_queue(self, batch_size=0):
         labels, tag, voxel, rgb, raw_lidar = [], [], [], [], []
         for _ in range(batch_size):
             try:
-                rgb.append(cv2.resize(cv2.imread(
-                    self.f_rgb[load_index]), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT)))
-                raw_lidar.append(np.fromfile(
-                    self.f_lidar[load_index], dtype=np.float32).reshape((-1, 4)))
-                if not self.is_testset:
-                    labels.append([line for line in open(
-                        self.f_label[load_index], 'r').readlines()])
+                if self.aug:
+                    ret = aug_data(self.data_tag[load_index], self.object_dir)
+                    tag.append(ret[0])
+                    rgb.append(ret[1])
+                    raw_lidar.append(ret[2])
+                    voxel.append(ret[3])
+                    labels.append(ret[4])
                 else:
-                    labels.append([''])
-                tag.append(self.data_tag[load_index])
-                voxel.append(np.load(self.f_voxel[load_index]))
+                    rgb.append(cv2.resize(cv2.imread(
+                        self.f_rgb[load_index]), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT)))
+                    raw_lidar.append(np.fromfile(
+                        self.f_lidar[load_index], dtype=np.float32).reshape((-1, 4)))
+                    if not self.is_testset:
+                        labels.append([line for line in open(
+                            self.f_label[load_index], 'r').readlines()])
+                    else:
+                        labels.append([''])
+                    tag.append(self.data_tag[load_index])
+                    voxel.append(np.load(self.f_voxel[load_index]))
 
                 load_index += 1
             except:
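
With aug=True, fill_queue regenerates each queued sample through aug_data instead of reading the pre-voxelised .npz from disk, so the offline training_aug generation step is not needed when this flag is on. A minimal usage sketch, assuming the default KITTI directory layout; the path and batch size are placeholders:

from utils.kitti_loader import KittiLoader

# Hypothetical training-side loader: samples are augmented as they are queued.
with KittiLoader(object_dir='./data/object/training', queue_size=50,
                 require_shuffle=True, is_testset=False, batch_size=2,
                 use_multi_process_num=8, multi_gpu_sum=1, aug=True) as train_loader:
    batch = train_loader.load()  # tags, labels, voxel features, images and raw lidar for one batch
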
