-
Notifications
You must be signed in to change notification settings - Fork 9
/
pickle_to_png
48 lines (37 loc) · 1.32 KB
/
pickle_to_png
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import tensorflow as tf
import numpy as np
import os
import cv2
from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data
# Destination roots: every image ends up at <root>/<label>/<index>.png.
path_train = "/home/minguk/Documents/data/train"
path_test = "/home/minguk/Documents/data/test"


def export_images(images, labels, root):
    """Write each image to ``<root>/<label>/<index>.png``.

    Args:
        images: Sequence of image arrays. Assumed to be H x W x 3 uint8 in
            RGB channel order, which is what the Keras CIFAR-10 loader is
            documented to return -- confirm if another source is used.
        labels: Sequence parallel to ``images``; ``labels[k][0]`` is used as
            the name of the class sub-directory.
        root: Directory under which the per-class sub-directories are made.

    Raises:
        IOError: If OpenCV reports that a file could not be written.
    """
    for index, (image, label) in enumerate(zip(images, labels)):
        class_dir = os.path.join(root, str(label[0]))
        # cv2.imwrite never creates directories; on a missing directory it
        # just returns False, so the folder has to exist beforehand.
        os.makedirs(class_dir, exist_ok=True)
        print(class_dir)

        target = os.path.join(class_dir, str(index) + ".png")
        # OpenCV interprets 3-channel arrays as BGR. Convert from RGB first,
        # otherwise the red and blue channels are swapped in the saved PNG.
        bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(target, bgr_image):
            raise IOError("cv2.imwrite could not write " + target)


if __name__ == "__main__":
    (x_train, y_train), (x_test, y_test) = load_data()
    export_images(x_train, y_train, path_train)
    export_images(x_test, y_test, path_test)
import shutil
import os
from glob import *
import random
data_path = "./data"
train_path = data_path + "/train"
valid_path = data_path + "/valid"
# Total number of images to hold out, shared evenly between the classes.
num_valid = 5000


def split_validation(train_dir, valid_dir, total):
    """Move a random, class-balanced sample of images into a validation tree.

    Every sub-directory of ``train_dir`` is treated as one class. From each
    class, ``total // number_of_classes`` randomly chosen files are moved to
    ``valid_dir/<class name>/``; a class with fewer files gives all of them.

    Args:
        train_dir: Directory that contains one sub-directory per class.
        valid_dir: Directory that receives the held-out files; it and its
            class sub-directories are created when missing.
        total: Number of files to hold out across all classes.

    Returns:
        The number of files that were moved.
    """
    # Only directories are classes; a stray file next to them must not
    # shrink the per-class quota. Sorting makes the traversal order stable.
    class_names = sorted(
        name
        for name in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, name))
    )
    if not class_names:
        # Nothing to split; also avoids dividing by zero below.
        return 0

    per_class = total // len(class_names)
    moved = 0
    for class_name in class_names:
        # Sorted before shuffling so a seeded ``random`` gives the same pick.
        candidates = sorted(glob(os.path.join(train_dir, class_name) + "/*.*"))
        random.shuffle(candidates)

        # The destination is keyed by the class NAME. The position of the
        # class in the directory listing is arbitrary and would mislabel
        # the validation images.
        target_dir = os.path.join(valid_dir, class_name)
        os.makedirs(target_dir, exist_ok=True)

        for image_path in candidates[:per_class]:
            shutil.move(image_path, target_dir)
            moved += 1
    return moved


if __name__ == "__main__":
    split_validation(train_path, valid_path, num_valid)