-
Notifications
You must be signed in to change notification settings - Fork 1
/
func_test.py
272 lines (229 loc) · 6.83 KB
/
func_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import time
import matplotlib.pyplot as plt
import numpy as np
from numpy.lib.function_base import iterable
from numpy.lib.shape_base import row_stack
import pandas as pd
import os
from multiprocessing import Process,Array
import pynvml
from scipy.io import savemat
import datetime
def MatShift(mat, d):
    """
    ### Zero-filled shift of a 2-D matrix along its second axis (columns)
    #### mat : 2-D array, (w, h)
    #### d   : shift distance; >0 shifts right, <0 shifts left, 0 copies
    ### return: new array with the same shape and dtype as ``mat``; vacated
    ###         columns are zero. If |d| >= h the result is all zeros.
    """
    _, n_cols = mat.shape  # unpacking enforces a 2-D input
    shifted = np.zeros_like(mat)  # zeros_like already keeps mat's dtype
    if abs(d) >= n_cols:
        # Everything is shifted out of the matrix.
        return shifted
    if d > 0:
        shifted[:, d:] = mat[:, :n_cols - d]
    elif d < 0:
        shifted[:, :n_cols + d] = mat[:, -d:]
    else:
        shifted[:] = mat
    return shifted
def ReadMat(path):
    """
    ### Read a single-variable MATLAB .mat file
    ### path : path to the .mat file
    ### return: the first non-metadata variable stored in the file
    ###         (np.array), or None if the file holds no variables.
    Errors raised by scipy.io.loadmat (e.g. a missing file) propagate.
    """
    import scipy.io as sio

    contents = sio.loadmat(path)
    for name, value in contents.items():
        # loadmat adds metadata entries (__header__, __version__, __globals__);
        # MATLAB variable names cannot start with '_', so a prefix test is exact.
        if not name.startswith('__'):
            return value
    return None
def WriteInfo_zl(path, **args):
    """
    ### Append one row of results to a CSV file, creating its directory
    ### path : CSV file path; missing parent directories are created
    ### **args : column values as scalars or one-element lists, e.g.
        WriteInfo_zl('./out/result.csv', psnr=[32.2], mse=[1.54])
      A ready-made dict may instead be passed as ``args=...``; it is copied,
      not modified.
    A ``Time`` column holding the current local time (YYYY-MM-DD HH:MM:SS)
    is always added. A missing or empty file starts a new table; otherwise
    the row is appended to the existing rows.
    """
    parent = os.path.dirname(path)
    # dirname is '' for a bare filename, and os.makedirs('') would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Copy so the caller's dict never gains the 'Time' key.
    record = dict(args['args']) if 'args' in args else dict(args)
    # in any case, don't delete this code ↓
    # strftime keeps the seconds even when microsecond == 0 (str()[:-7] did not).
    record['Time'] = [datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')]
    try:
        df = pd.read_csv(path, encoding='utf-8', engine='python')
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Only a missing/empty file starts fresh; other read errors propagate
        # instead of silently overwriting existing results.
        df = pd.DataFrame()
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    df = pd.concat([df, pd.DataFrame(record)])
    df.to_csv(path, index=False)
def WriteInfo(path, **args):
    """
    ### Append one row of results to a CSV file
    ### path : CSV file path (its directory must already exist)
    ### **args : column values as scalars or one-element lists, e.g.
        WriteInfo('./raki_result.csv', psnr=[32.2], mse=[1.54], ssim=[0.9756], mae=[0.12])
      A ready-made dict may instead be passed as ``args=...``; it is copied,
      not modified.
    A ``Time`` column holding the current local time (YYYY-MM-DD HH:MM:SS)
    is always added. A missing or empty file starts a new table; otherwise
    the row is appended to the existing rows.
    """
    # Copy so the caller's dict never gains the 'Time' key.
    record = dict(args['args']) if 'args' in args else dict(args)
    # strftime keeps the seconds even when microsecond == 0 (str()[:-7] did not).
    record['Time'] = [datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')]
    try:
        df = pd.read_csv(path, encoding='utf-8', engine='python')
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Only a missing/empty file starts fresh; other read errors propagate
        # instead of silently overwriting existing results.
        df = pd.DataFrame()
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    df = pd.concat([df, pd.DataFrame(record)])
    df.to_csv(path, index=False)
def GPUScan(memory=4000, multi=False):
    """
    ### Find GPUs with enough free memory (queried through NVML)
    ### Input:
    memory: required free memory in MB; a GPU qualifies if free > memory
    multi : False -> return only the first qualifying index;
            True  -> return all qualifying indices
    ### Return:
    str: a single index such as '3', or a ', '-separated list such as '3, 1'.
         Devices are scanned from the highest index down, so the highest
         qualifying index comes first.
    ### Raises:
    AssertionError when no device qualifies (after printing a banner).
    """
    candidates = []
    pynvml.nvmlInit()
    try:
        device_count = pynvml.nvmlDeviceGetCount()
        for idx in reversed(range(device_count)):
            handle = pynvml.nvmlDeviceGetHandleByIndex(idx)
            free_mb = pynvml.nvmlDeviceGetMemoryInfo(handle).free / 1024 ** 2
            if free_mb > memory:
                candidates.append(idx)
    finally:
        # Balance nvmlInit so the NVML reference count does not leak.
        pynvml.nvmlShutdown()
    if candidates:
        if not multi:
            return str(candidates[0])
        return ', '.join(str(idx) for idx in candidates)
    print('*' * 30, 'No satisfied devices!', '*' * 30, sep='\n')
    # Explicit raise (not `assert False`) so the failure survives `python -O`.
    raise AssertionError('No satisfied devices!')
def WeightMask(shape, Rmax=0.3, sharp=0.1):
    """
    ### Build a 2-D radial high-pass weighting mask
    ### Input:
    shape: (ny, nx) size of the filter
    Rmax : amplitude parameter; the normalized coordinate at the edge is Rmax
    sharp: exponent applied to the squared radius; smaller means sharper
    ### Return:
    R : float32 array, R = (wx**2 + wy**2) ** sharp, intended to be
        multiplied with k-space
    NOTE(review): odd sizes are rounded up to the next even number, so for
    an odd dimension the output is one larger than ``shape`` — confirm this
    is intended.
    """
    n_rows, n_cols = shape
    # Offsets from the centre, running from -n/2 over an even-length range.
    col_offsets = np.arange(n_cols + n_cols % 2) - n_cols / 2
    row_offsets = np.arange(n_rows + n_rows % 2) - n_rows / 2
    grid_x, grid_y = np.meshgrid(
        Rmax * col_offsets / (n_cols / 2),
        Rmax * row_offsets / (n_rows / 2),
    )
    radius_sq = grid_x ** 2 + grid_y ** 2
    return (radius_sq ** sharp).astype(np.float32)
def MultiWeightMask(shape, RmaxList, SharpList):
    """
    ### Build one weighting mask per channel
    ### Input:
    shape    : (x, y, c). Anything that does not unpack into three values is
               passed straight to WeightMask(shape, RmaxList, SharpList).
    RmaxList : per-channel Rmax values, length c
    SharpList: per-channel sharpness values, at least c entries; a value of 1
               yields an all-ones (unweighted) channel
    ### Return:
    float32 array of shape (x, y, c)
    ### Raises:
    AssertionError (after printing a message) if the lists are too short /
    do not match c.
    """
    try:
        x, y, c = shape
    except (TypeError, ValueError):
        # Not a 3-element shape: fall back to a single 2-D mask.
        return WeightMask(shape, RmaxList, SharpList)
    if c != len(RmaxList):
        msg = "{} not match {}!".format(c, len(RmaxList))
        print(msg)
        # Explicit raise (not `assert False`) so validation survives `python -O`.
        raise AssertionError(msg)
    if len(SharpList) < c:
        # zip() would otherwise silently leave trailing channels all-zero.
        msg = "{} not match {}!".format(c, len(SharpList))
        print(msg)
        raise AssertionError(msg)
    R = np.zeros(shape)
    for i, (Rmax, Sharp) in enumerate(zip(RmaxList, SharpList)):
        if Sharp == 1:
            R[..., i] = np.ones([x, y])
        else:
            R[..., i] = WeightMask([x, y], Rmax, Sharp)
    return R.astype(np.float32)
def DisplayBlack(data, t=3, show=True):
    """
    Visualise a k-space sampling pattern as a black/white image.
    Entries with non-zero magnitude become 255, zero entries stay 0.
    ### Input:
    data: 2-D k-space data (real or complex)
    t   : seconds to keep the figure on screen (default 3)
    show: True -> display with matplotlib; False -> return the 0/255 array
    ### Return:
    None when show is True, otherwise the 0/255 array (dtype of np.abs(data))
    """
    mask = np.abs(data)
    mask[mask > 0] = 255
    if not show:
        return mask
    plt.imshow(mask, cmap='gray', vmin=0, vmax=255)
    plt.pause(t)
## start VCC
def circshift(matrix_ori, shiftnum1, shiftnum2):
    """
    ### Circularly shift every slice of a 3-D array
    ### matrix_ori : array of shape (c, h, w); each of the c slices is shifted
    ### shiftnum1  : shift along axis 1; >0 moves rows down, <0 moves them up
    ### shiftnum2  : shift along axis 2; >0 moves columns right, <0 left
    ### return: new array of the same shape/dtype. Shifts of any magnitude
    ###         wrap modulo the axis length (the previous slicing version
    ###         mis-handled shifts beyond one axis length).
    """
    if np.ndim(matrix_ori) != 3:
        raise ValueError(
            "circshift expects a 3-D array, got {} dims".format(np.ndim(matrix_ori))
        )
    return np.roll(matrix_ori, (shiftnum1, shiftnum2), axis=(1, 2))
def self_floor1(data1):
    """
    ### Flip every channel of a 3-D array upside-down and left-right
    ### data1 : array of shape (I, J, K); each data1[:, :, k] is reversed
    ###         along both of its axes
    ### return: a new array; data1 itself is left untouched
    """
    flipped = np.copy(data1)
    _rows, _cols, _chans = flipped.shape  # unpacking rejects non-3-D input
    # Reverse axes 0 and 1 for all channels at once (explicit copy of the
    # source avoids any aliasing between source and destination).
    flipped[...] = flipped[::-1, ::-1, :].copy()
    return flipped
def VCC_siganal_creation(kspace):
    """
    ### Build virtual-conjugate-coil (VCC) signals
    ### Input:
    kspace: 3-D k-space data laid out as (nRO, nPE, nc)
    ### Return:
    (kspace, vcc): the input array itself, and the VCC signals of the same
    shape, where vcc[i, j, c] = conj(kspace[i', j', c]) with i' and j' the
    point reflections of i and j about index n//2 of their axis (taken
    modulo the axis length).
    Concatenating them along the last axis (original first, conjugate
    second) is left to the caller.
    """
    n_ro, n_pe, _ = kspace.shape
    # Reverse the RO and PE axes of every coil, then conjugate.
    vcc = np.conj(kspace[::-1, ::-1, :])
    # For an even-length axis the reversal maps index n//2 to n//2 - 1;
    # a +1 circular shift on that same axis restores the alignment.
    # (The former circshift call shifted the coil axis when nRO was even.)
    if n_pe % 2 == 0:
        vcc = np.roll(vcc, 1, axis=1)
    if n_ro % 2 == 0:
        vcc = np.roll(vcc, 1, axis=0)
    return kspace, vcc
## end VCC
def multi_run(func, args, num_works=5):
    """
    ### Run ``func(*args)`` in ``num_works`` separate processes
    Processes are started one after another; after each start the caller
    sleeps for a random whole number of seconds drawn from [0, num_works).
    Blocks until every process has finished.
    ### return: None
    """
    workers = []
    for _ in range(num_works):
        worker = Process(target=func, args=args)
        workers.append(worker)
        worker.start()
        pause = np.random.randint(0, num_works, 1)[0]
        time.sleep(pause)
    for worker in workers:
        worker.join()
def patch_rescale(img, patch_max=None, patch_min=None):
    """
    ### Linearly map each patch of ``img`` onto [-1, 1]
    ### img       : [patch, w, h] (a single 2-D patch also works)
    ### patch_max : per-patch maxima; None -> computed from img
    ### patch_min : per-patch minima; used only when patch_max is given
    ### return:
    ###   patch_max is None -> (scaled, patch_max, patch_min), where the
    ###                        bounds have shape [patch, 1, 1] (or (1, 1))
    ###   otherwise         -> scaled only
    ### The returned bounds can be passed back in to rescale other data
    ### with the same per-patch limits. A constant patch divides by zero.
    """
    # `is None` rather than truthiness: the arrays returned above have many
    # elements (ambiguous truth value), and a legitimate bound of 0 is falsy.
    compute_bounds = patch_max is None
    if compute_bounds:
        patch_max = np.max(img, axis=(-2, -1), keepdims=True)
        patch_min = np.min(img, axis=(-2, -1), keepdims=True)
    scaled = (img - patch_min) / (patch_max - patch_min)
    scaled = scaled * 2. - 1.
    if compute_bounds:
        return scaled, patch_max, patch_min
    return scaled
def patch_unrescale(img, max_deg, min_deg):
    """
    ### Invert patch_rescale: map values from [-1, 1] back to [min, max]
    ### img     : rescaled data, e.g. [patch, w, h]
    ### max_deg : per-patch maxima; 1-D [patch] arrays are expanded to
    ###           [patch, 1, 1], scalars / [patch, 1, 1] are used as is
    ### min_deg : per-patch minima, same rules as max_deg
    ### return  : (max_deg - min_deg) * (img + 1) / 2 + min_deg
    """
    # asarray lets plain Python floats through (they have no .shape).
    max_deg = np.asarray(max_deg)
    min_deg = np.asarray(min_deg)
    if max_deg.ndim == 1:
        max_deg = max_deg[..., None, None]
    if min_deg.ndim == 1:
        min_deg = min_deg[..., None, None]
    unit = (img + 1.) / 2.
    return (max_deg - min_deg) * unit + min_deg
# R = MultiWeightMask([188,236,32],np.array([1,2,3]),np.array([0.4,0.5,0.6]))
# write_info('./raki_result.csv',psnr =32.2,mse = 1.54,ssim= 0.9756,mae=0.12)