Skip to content

Commit

Permalink
feat: adapted for testing estimated resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
mariajmellado committed Aug 6, 2024
1 parent 47dac1d commit 902a85f
Show file tree
Hide file tree
Showing 4 changed files with 899 additions and 47 deletions.
821 changes: 821 additions & 0 deletions frank/Examples/13_July_ResolutionTest.ipynb

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions frank/fourier2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def __init__(self, Rmax, N, nu=0):
self.dy = 2*self.Ymax/self.Ny

# Frequency space collocation points.
# The [1:] is because to not consider the 0 baseline. But we're missing points.
q1n = np.fft.fftfreq(self.Nx, d = self.dx)
q2n = np.fft.fftfreq(self.Ny, d = self.dy)
q1n, q2n = np.meshgrid(q1n, q2n, indexing='ij')
Expand Down Expand Up @@ -53,11 +52,10 @@ def coefficients(self, u = None, v = None, x = None, y = None, direction="forwar
norm = 1 / (4*self.Xmax*self.Ymax)
factor = 2j*np.pi

X, Y = u, v
X, Y = self.Un, self.Vn
if u is None:
X, Y = self.Un, self.Vn
u = self.Xn
v = self.Yn
u = self.Xn
v = self.Yn
else:
raise AttributeError("direction must be one of {}"
"".format(['forward', 'backward']))
Expand Down Expand Up @@ -94,4 +92,9 @@ def Rmax(self):
@property
def resolution(self):
""" Resolution of the grid in the x coordinate in rad"""
return self.dx
return self.dx

@property
def xy_points(self):
""" Collocation points in the image plane"""
return self.Xn, self.Yn
6 changes: 3 additions & 3 deletions frank/radial_fitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def interpolate_brightness(self, Rpts, I=None):
-----
The resulting brightness will be consistent with higher-resolution fits
as long as the original fit has sufficient resolution. By sufficient
resolution we simply mean that the missing terms in the Fourier-Bessel
we simply mean that the missing terms in the Fourier-Bessel
series are negligible, which will typically be the case if the
brightness was obtained from a frank fit with 100 points or more.
"""
Expand Down Expand Up @@ -437,7 +437,7 @@ class FourierBesselFitter(object):

def __init__(self, Rmax, N, geometry=None, nu=0, block_data=True,
assume_optically_thick=True, scale_height=None,
block_size=10 ** 5, verbose=True):
block_size=10 ** 5, verbose=True, geometry_on = True):

Rmax /= rad_to_arcsec

Expand All @@ -457,7 +457,7 @@ def __init__(self, Rmax, N, geometry=None, nu=0, block_data=True,
model = 'opt_thin'

self._vis_map = VisibilityMapping(self._DHT, geometry,
model, scale_height=scale_height,
model, geometry_on = geometry_on ,scale_height=scale_height,
block_data=block_data, block_size=block_size,
check_qbounds=False, verbose=verbose,
DFT = self._DFT)
Expand Down
104 changes: 66 additions & 38 deletions frank/statistical_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class VisibilityMapping:
Whether to print notification messages
"""
def __init__(self, DHT, geometry,
vis_model='opt_thick', scale_height=None, block_data=True,
vis_model='opt_thick', geometry_on = True, scale_height=None, block_data=True,
block_size=10 ** 5, check_qbounds=True, verbose=True,
DFT = None):

Expand All @@ -78,6 +78,7 @@ def __init__(self, DHT, geometry,
self._vis_model = vis_model
self.check_qbounds = check_qbounds
self._verbose = verbose
self.geometry_on = geometry_on

self._chunking = block_data
self._chunk_size = block_size
Expand Down Expand Up @@ -165,7 +166,8 @@ def map_visibilities(self, u, v, V, weights, frequencies=None, geometry=None):
logging.info(' Building visibility matrices M and j')

# Deproject the visibilities
#u, v, k, V = self._geometry.apply_correction(u, v, V, use3D=True)
if self.geometry_on:
u, v, k, V = self._geometry.apply_correction(u, v, V, use3D=True)
q = np.hypot(u, v)

# Check consistency of the uv points with the model
Expand All @@ -179,7 +181,7 @@ def map_visibilities(self, u, v, V, weights, frequencies=None, geometry=None):
if frequencies is None:
multi_freq = False
frequencies = np.ones_like(V)

start_time = time.time()
channels = np.unique(frequencies)
Ms = np.zeros([len(channels), self.size, self.size], dtype='c8')
js = np.zeros([len(channels), self.size], dtype='c8')
Expand Down Expand Up @@ -209,15 +211,34 @@ def map_visibilities(self, u, v, V, weights, frequencies=None, geometry=None):
Vs = Vi[start:end]

X = self._get_mapping_coefficients(qs, ks, us, vs)
X_CT = self._get_mapping_coefficients(qs, ks, us, vs, inverse=True)
wXT = np.matmul(X_CT, np.diag(ws)) # this line is the same as below.
#wXT = np.matmul(np.transpose(np.conjugate(X)), np.diag(ws), dtype = "complex64")
Ms[i] += np.matmul(wXT, X, dtype="complex128")
js[i] += np.matmul(wXT, Vs, dtype="complex128")
wXT = np.matmul(np.transpose(np.conjugate(X)), np.diag(ws), dtype = "complex128")
Ms[i] += np.matmul(wXT, X, dtype="complex128").real
js[i] += np.matmul(wXT, Vs, dtype="complex128").real

start = end
end = min(Ndata, end + Nstep)


import matplotlib.pyplot as plt


N = int(np.sqrt(self._DFT.size))
r"""FRANK 2D: TESTING M
sparcity = ((np.sum(np.abs(Ms[0]) < 0.5e-17))/N**4)*100
print("M is sparse? ", sparcity)
print(Ms[0])
plt.imshow(Ms[0].real, cmap="magma", vmax = np.max(Ms[0].real), vmin = np.mean(Ms[0].real))
plt.colorbar()
plt.show()
M = np.diag(np.diag(Ms[0].real))
num_zeros = ((np.sum(np.abs(M) == 0.0))/N**4)*100
print("M is sparse? ", num_zeros)
IFT2 = np.linalg.solve(M, js[0]).real
plt.imshow(IFT2.reshape(N,N).T, cmap="magma")
plt.colorbar()
print("M type", Ms[0].dtype)
print("j type", js[0].dtype)
print("M imag-> ", " max: ", np.max(Ms[0].imag), ", min: ", np.min(Ms[0].imag) , ", mean: ", np.mean(Ms[0].imag) ,", median: ", np.median(Ms[0].imag), ", std: ", np.std(Ms[0].imag))
Expand All @@ -228,19 +249,22 @@ def map_visibilities(self, u, v, V, weights, frequencies=None, geometry=None):
# from scipy.linalg import issymmetric
# print(issymmetric(Ms[0]), "that M is a Symmetric Matrix")


import matplotlib.pyplot as plt
plt.matshow(Ms[0].real, cmap="magma", vmax = np.max(Ms[0].real), vmin = np.mean(Ms[0].real))
plt.colorbar()
plt.title("M matrix, real part")
plt.show()
#import sys
#sys.exit()

"""

#Ms[0] = np.loadtxt(r'.\..\Notebooks\M_N75.txt', dtype = 'c8')
#js[0] = np.loadtxt(r'.\..\Notebooks\j_N75.txt', dtype = 'c8')
print("--- %s minutes to calculate M and j ---" % (time.time()/60 - start_time/60))
path = r'/Users/mariajmelladot/Desktop/Frank2D/1_Frank2D_DEV/data/TestingComplexity/'
np.save(path + 'M_N' + str(N) , Ms[0].real)
np.save(path + 'j_N' + str(N) , js[0].real)


# Compute likelihood normalization H_0, i.e., the
# log-likelihood of a source with I=0.
Expand Down Expand Up @@ -288,10 +312,10 @@ def check_hash(self, hash, multi_freq=False, geometry=None):
self._DFT.Rmax == hash[1].Rmax and
self._DFT.size == hash[1].size and
self._DFT.order == hash[1].order and
geometry.inc == hash[2].inc and
geometry.PA == hash[2].PA and
geometry.dRA == hash[2].dRA and
geometry.dDec == hash[2].dDec and
#geometry.inc == hash[2].inc and
#geometry.PA == hash[2].PA and
#geometry.dRA == hash[2].dRA and
#geometry.dDec == hash[2].dDec and
self._vis_model == hash[3]
)

Expand Down Expand Up @@ -496,8 +520,11 @@ def _get_mapping_coefficients(self, qs, ks, u, v, geometry=None, inverse=False):
if self._vis_model == 'opt_thick':
# Optically thick & geometrically thin
if geometry is None:
geometry = self._geometry
scale = np.cos(geometry.inc * deg_to_rad)
if not self.geometry_on:
scale = 1
else:
geometry = self._geometry
scale = np.cos(geometry.inc * deg_to_rad)
elif self._vis_model == 'opt_thin':
# Optically thin & geometrically thin
scale = 1
Expand Down Expand Up @@ -746,7 +773,6 @@ def __init__(self, DHT, M, j, p=None, scale=None, guess=None,
" New GP "
self.u, self.v = self._DFT.uv_points
self.Ykm = self._DFT.coefficients(direction="backward")
self.Ykm_f = self._DFT.coefficients(direction="forward")

m, c , l = -5, 60, 1e4
#m, c, l = 0.23651345032212925, 60.28747193555951, 1.000389e+05
Expand All @@ -758,10 +784,29 @@ def __init__(self, DHT, M, j, p=None, scale=None, guess=None,
S_real_inv = np.linalg.inv(S_real)
print("--- %s minutes to calculate S_real_inv---" % (time.time()/60 - start_time/60))
self._Sinv = S_real_inv
print(self._Sinv.dtype, " Sinv dtype")
start_time = time.time()
self._fit()
print("--- %s minutes to fit---" % (time.time()/60 - start_time/60))

def calculate_S_real(self, u, v, l, m, c):
start_time = time.time()
S_fspace = self.true_squared_exponential_kernel(u, v, l, m, c)
print("--- %s minutes to calculate S---" % (time.time()/60 - start_time/60))
start_time = time.time()
S_real = np.matmul(self.Ykm, np.matmul(S_fspace, self.Ykm.conj()), dtype = "complex128").real
print("--- %s minutes to calculate S_real---" % (time.time()/60 - start_time/60))

"""
#FRANK 2D: TESTING S_real
print(" S_real")
import matplotlib.pyplot as plt
plt.matshow(S_real, cmap="magma")
plt.colorbar()
plt.title("S matrix, real part ")
plt.show()
"""

return S_real

def true_squared_exponential_kernel(self, u, v, l, m, c):
u1, u2 = np.meshgrid(u, u)
Expand All @@ -784,23 +829,6 @@ def power_spectrum(q, m, c):
SE_Kernel = np.sqrt(p1 * p2) * np.exp(-0.5*((u1-u2)**2 + (v1-v2)**2)/ l**2)
return SE_Kernel

def calculate_S_real(self, u, v, l, m, c):
start_time = time.time()
S_fspace = self.true_squared_exponential_kernel(u, v, l, m, c)
print("--- %s minutes to calculate S---" % (time.time()/60 - start_time/60))
start_time = time.time()
S_real = np.matmul(self.Ykm, np.matmul(S_fspace, self.Ykm_f), dtype = "complex64")
print("--- %s minutes to calculate S_real---" % (time.time()/60 - start_time/60))

print(S_real.dtype, " S_real dtype")
import matplotlib.pyplot as plt
plt.matshow(S_fspace, cmap="magma", vmin = 0, vmax = 1e-3)
plt.colorbar()
plt.title("S real matrix, real part ")
plt.show()

return S_real

def calculate_mu_cholesky(self, Dinv):
print("calculate mu with cholesky")
try:
Expand Down Expand Up @@ -889,7 +917,7 @@ def _fit(self):

Dinv = self._M + Sinv

"""
r""" FRANK 2D: TESTING Dinv
#import scipy.linalg as sc
def is_pos_def(x):
return np.all(np.linalg.eigvals(x) > 0)
Expand Down

0 comments on commit 902a85f

Please sign in to comment.