-
Notifications
You must be signed in to change notification settings - Fork 1
/
example_window_inference.py
87 lines (79 loc) · 3.32 KB
/
example_window_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#
# Author: Georg Zitzlsberger (georg.zitzlsberger<ad>vsb.cz)
# Copyright (C) 2020-2022 Georg Zitzlsberger, IT4Innovations,
# VSB-Technical University of Ostrava, Czech Republic
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import csv
import time
import math
import datetime
import numpy as np
from datetime import datetime
import sys
sys.path.append("../lib/")
import rsdtlib
import multiprocessing
# Number of worker processes used when writing the window files; also passed
# on to rsdtlib.Window below.
n_threads = 2
# Input directory read by rsdtlib.Window (presumably the TF record stack
# produced by an earlier processing step -- confirm).
tf_record_path = "./tf_stack/"
# Output directory for the inference windows and the windows CSV listing.
tf_record_out_path = "./tf_window/infer/"
# Create the output directory, including any missing parent directories;
# no error is raised if it already exists.
os.makedirs(tf_record_out_path, exist_ok=True)

# Define the window parameters.
# Note: this only sets up the windowing; no data is processed yet.
# The Window is created at module level (not under __main__) because the
# worker function write_it() uses it as a global, and the pool in the
# __main__ section uses the "spawn" start method, which re-runs this
# module's top-level code in every worker process.
window = rsdtlib.Window(
    tf_record_path,
    60*60*24*30,            # Delta (window size); 30 days if in seconds
    1,                      # window stride
    10,                     # omega (min. window size)
    16,                     # Omega (max. window size)
    False,                  # generate triplet
    n_threads=n_threads,    # number of threads to use
    use_new_save=False)     # new TF Dataset save
def write_it(args):
    """Write the window TF files for exactly one tile.

    Used as the worker function of the multiprocessing pool.

    Args:
        args: A pair ``(location, tile)``. ``location`` is forwarded as the
            first argument of ``window.write_tf_files`` (the output
            directory in this script). ``tile`` is a ``(j, i)`` tile index
            pair, and only the tile with that index is selected.

    Returns:
        None.
    """
    out_dir = args[0]
    wanted = args[1]

    def only_this_tile(j, i):
        # Selector handed to write_tf_files: true only for the requested tile.
        return j == wanted[0] and i == wanted[1]

    window.write_tf_files(out_dir, only_this_tile)
if __name__ == "__main__":
    # Write the identified windows to a CSV file.
    window_list = window.windows_list()
    # newline="" lets the csv module control line endings itself, as the
    # csv documentation requires. Without it, extra blank rows are written
    # on Windows.
    with open(os.path.join(tf_record_out_path, "windows_inference.csv"),
              mode="w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file,
                                delimiter=",",
                                quotechar="\"",
                                quoting=csv.QUOTE_MINIMAL)
        for item in window_list:
            # Columns 1 and 2 are converted from (presumably UNIX epoch)
            # timestamps to naive UTC datetimes; columns 0 and 3 are
            # written unchanged.
            csv_writer.writerow([item[0],
                                 datetime.utcfromtimestamp(item[1]),
                                 datetime.utcfromtimestamp(item[2]),
                                 item[3]])

    # Write the final inference samples (windows without labels). The
    # selector function specifies all tiles to consider for inference samples.
    def selector(j, i):
        # Every tile is selected.
        return True

    num_tiles_y, num_tiles_x = window.get_num_tiles()
    # One work item per selected tile, (output directory, (j, i)), in
    # row-major order; this is the argument shape write_it() expects.
    list_tiles = [(tf_record_out_path, (j, i))
                  for j in range(num_tiles_y)
                  for i in range(num_tiles_x)
                  if selector(j, i)]
    num_tiles = len(list_tiles)

    with multiprocessing.get_context("spawn").Pool(processes=n_threads) as p:
        # Results arrive in completion order; only the count is used here.
        # Counting from 1 makes the final update read 100.0%. If there are
        # no tiles, the loop body never runs, so there is no division by 0.
        for done, _ in enumerate(p.imap_unordered(write_it, list_tiles),
                                 start=1):
            sys.stdout.write("\r Progress: {0:.1%}".format(done / num_tiles))
            sys.stdout.flush()
    print("\n")