-
Notifications
You must be signed in to change notification settings - Fork 4
/
demo_tga.py
76 lines (61 loc) · 2.06 KB
/
demo_tga.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Modified from https://github.com/dfm/pcp/blob/master/demo.py
Uses data from http://perception.i2r.a-star.edu.sg/bk_model/bk_index.html
"""
from __future__ import division, print_function
import os
import time
import numpy as np
from PIL import Image
from tga import TGA
def bitmap_to_mat(bitmap_seq):
    """Load a sequence of images into a 2-D matrix, one flattened frame per row.

    Adapted from blog.shriphani.com.

    Each file is opened with PIL, converted to 8-bit greyscale ("L") and
    flattened in row-major order.

    Parameters
    ----------
    bitmap_seq : iterable
        Image files accepted by ``PIL.Image.open``. All images must have the
        same size.

    Returns
    -------
    matrix : numpy.ndarray
        Integer array of shape ``(n_images, height * width)``.
    shape : tuple of int
        ``(height, width)`` of each frame, i.e. the shape a row reshapes to.

    Raises
    ------
    ValueError
        If ``bitmap_seq`` yields no images, or the images differ in size.
    """
    rows = []
    size = None
    for bitmap_file in bitmap_seq:
        # Close the underlying file handle; convert() loads the pixel data,
        # so the converted image stays usable after the source is closed.
        with Image.open(bitmap_file) as raw:
            img = raw.convert("L")
        if size is None:
            size = img.size
        elif img.size != size:
            # Explicit check instead of assert, which is stripped under -O.
            raise ValueError(
                "image {0!r} has size {1}, expected {2}".format(
                    bitmap_file, img.size, size))
        # Same values/order/dtype as np.array(img.getdata()), but vectorized.
        rows.append(np.asarray(img, dtype=int).ravel())
    if size is None:
        raise ValueError("bitmap_seq contains no images")
    # PIL reports size as (width, height); numpy reshape wants (height, width).
    return np.array(rows), size[::-1]
def do_plot(ax, img, shape):
    """Render a flattened greyscale frame on *ax* with tick labels hidden.

    Clears *ax*, reshapes *img* to *shape* and shows it with a grey colormap
    and nearest-neighbour interpolation.

    Parameters
    ----------
    ax : matplotlib Axes
        Axis to draw into; any previous content is cleared first.
    img : numpy.ndarray
        Flattened frame, e.g. one row of the matrix built by ``bitmap_to_mat``.
    shape : tuple of int
        Target shape passed to ``img.reshape``.
    """
    ax.cla()
    frame = img.reshape(shape)
    ax.imshow(frame, cmap="gray", interpolation="nearest")
    # Blank out the tick labels on both axes.
    for hide_labels in (ax.set_xticklabels, ax.set_yticklabels):
        hide_labels([])
if __name__ == "__main__":
    import sys
    import glob
    import matplotlib.pyplot as pl

    # Directory of .bmp frames; the first command-line argument overrides it.
    use_data = (sys.argv[1] if len(sys.argv) > 1
                else "/home/vighnesh/images/Escalator/")
    # glob.glob returns files in filesystem-dependent order; sort so the frame
    # order (and the every-other-frame subsample below) is deterministic.
    frame_files = sorted(glob.glob(os.path.join(use_data, "*.bmp")))
    if not frame_files:
        sys.exit("no .bmp files found in {0!r}".format(use_data))
    # Every second frame among the first 2000 files.
    M, shape = bitmap_to_mat(frame_files[:2000:2])
    print(M.shape)

    tga = TGA(n_components=5, random_state=1, tol=1e-3)
    start_time = time.time()
    tga.fit(M)
    print("fitted, time taken {0}s".format(time.time() - start_time))

    start_time = time.time()
    transformed = tga.transform(M)
    L = tga.inverse_transform(transformed)
    print('calculated L, time taken {0}s'.format(time.time() - start_time))
    # Residual left after subtracting the reconstruction.
    S = M - L

    # Output directory mirrors the input path under results_tga/. Plain
    # concatenation on purpose: os.path.join would discard the prefix for an
    # absolute use_data. os.makedirs also creates results_tga/ itself.
    directory = "results_tga/" + use_data
    if not os.path.exists(directory):
        os.makedirs(directory)

    fig, axes = pl.subplots(1, 3, figsize=(10, 4))
    fig.subplots_adjust(left=0, right=1, hspace=0, wspace=0.01)
    i = 0  # index of the frame to display
    panels = zip(axes, (M[i], L[i], S[i]), ("raw", "low rank", "sparse"))
    for ax, frame, title in panels:
        do_plot(ax, frame, shape)
        ax.set_title(title)
    # To save frames instead of showing one, loop i over
    # range(min(len(M), 500)) and call
    # fig.savefig(directory + "/{0:05d}.png".format(i)) after each redraw.
    pl.show()