Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update subset-mesh capability of the subset_data tool with nco/esmf two-line approach #1884

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
248 changes: 248 additions & 0 deletions python/ctsm/site_and_regional/regional_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
# -- Import Python Standard Libraries
import logging
import os
import argparse

# -- 3rd party libraries
import numpy as np
import xarray as xr
from tqdm import tqdm
from datetime import datetime

# -- import local classes for this script
from ctsm.site_and_regional.base_case import BaseCase, USRDAT_DIR
Expand Down Expand Up @@ -49,6 +53,15 @@ class RegionalCase(BaseCase):

Methods
-------
check_region_bounds
Check for the regional bounds

check_region_lons
Check for the regional lons

check_region_lats
Check for the regional lats

create_tag
Create a tag for this region which is either
region's name or a combination of bounds of this
Expand Down Expand Up @@ -77,6 +90,7 @@ def __init__(
create_landuse,
create_datm,
create_user_mods,
create_mesh,
out_dir,
overwrite,
):
Expand All @@ -96,7 +110,9 @@ def __init__(
self.lon1 = lon1
self.lon2 = lon2
self.reg_name = reg_name
self.create_mesh = create_mesh
self.out_dir = out_dir
self.check_region_bounds()
self.create_tag()

def create_tag(self):
Expand All @@ -112,6 +128,50 @@ def create_tag(self):
str(self.lon1), str(self.lon2), str(self.lat1), str(self.lat2)
)

def check_region_bounds (self):
"""
Check for the regional bounds
"""
self.check_region_lons()
self.check_region_lats()

def check_region_lons (self):
"""
Check for the regional lon bounds
"""
if self.lon1 >= self.lon2:
err_msg = """
\n
ERROR: lon1 is bigger than lon2.
lon1 points to the westernmost longitude of the region. {}
lon2 points to the easternmost longitude of the region. {}
Please make sure lon1 is smaller than lon2.

Please note that if longitude in -180-0, the code automatically
convert it to 0-360.
""".format(
self.lon1, self.lon2
)
raise argparse.ArgumentTypeError(err_msg)

def check_region_lats (self):
"""
Check for the regional lat bound
"""
if self.lat1 >= self.lat2:
err_msg = """
\n
ERROR: lat1 is bigger than lat2.
lat1 points to the westernmost longitude of the region. {}
lat2 points to the easternmost longitude of the region. {}
Please make sure lat1 is smaller than lat2.

""".format(
self.lat1, self.lat2
)
raise argparse.ArgumentTypeError(err_msg)


def create_domain_at_reg(self, indir, file):
"""
Create domain file for this RegionalCase class.
Expand Down Expand Up @@ -224,3 +284,191 @@ def create_landuse_at_reg(self, indir, file, user_mods_dir):
# line = "landuse = '${}'".format(os.path.join(USRDAT_DIR, fluse_out))
line = "flanduse_timeseries = '${}'".format(os.path.join(USRDAT_DIR, fluse_out))
self.write_to_file(line, nl_clm)


def create_mesh_at_reg(self, mesh_dir, mesh_surf):
"""
Create a mesh subsetted for the RegionalCase class.
"""
logger.info(
"----------------------------------------------------------------------"
)
logger.info(
"Subsetting mesh file for region: %s",
self.tag
)

today = datetime.today()
today_string = today.strftime("%y%m%d")


mesh_in = os.path.join(mesh_dir, mesh_surf)
mesh_out = os.path.join(self.out_dir, os.path.splitext(mesh_surf)[0]+'_'+self.tag+'_c'+today_string+'.nc')

logger.info("mesh_in : %s", mesh_in)
logger.info("mesh_out : %s", mesh_out)

self.mesh = mesh_out

node_coords, subset_element, subset_node, conn_dict = self.subset_mesh_at_reg(mesh_in)

f_in = xr.open_dataset (mesh_in)
self.write_mesh (f_in, node_coords, subset_element, subset_node, conn_dict, mesh_out)


def subset_mesh_at_reg (self, mesh_in):
"""
This function subsets the mesh based on lat and lon bounds given by RegionalCase class.
"""
f_in = xr.open_dataset (mesh_in)
elem_count = len (f_in['elementCount'])
elem_conn = f_in['elementConn']
num_elem_conn = f_in['numElementConn']
center_coords = f_in['centerCoords']
node_count = len (f_in['nodeCount'])
node_coords = f_in['nodeCoords']

subset_element = []
cnt = 0

for n in tqdm(range(elem_count)):
endx = elem_conn[n,:num_elem_conn[n].values].values
endx[:,] -= 1# convert to zero based index
endx = [int(xi) for xi in endx]

nlon = node_coords[endx,0].values
nlat = node_coords[endx,1].values

l1 = np.logical_or(nlon <= self.lon1,nlon >= self.lon2)
l2 = np.logical_or(nlat <= self.lat1,nlat >= self.lat2)
if np.any(np.logical_or(l1,l2)):
pass
else:
subset_element.append(n)
cnt+=1


subset_node = []
conn_dict = {}
cnt = 1

for n in range(node_count):
nlon = node_coords[n,0].values
nlat = node_coords[n,1].values

l1 = np.logical_or(nlon <= self.lon1,nlon >= self.lon2)
l2 = np.logical_or(nlat <= self.lat1,nlat >= self.lat2)
if np.logical_or(l1,l2):
conn_dict[n+1] = -9999
else:
subset_node.append(n)
conn_dict[n+1] = cnt
cnt+=1
return node_coords, subset_element, subset_node, conn_dict



def write_mesh (self, f_in, node_coords, subset_element, subset_node, conn_dict, mesh_out):
"""
This function writes out the subsetted mesh file.
"""
corner_pairs = f_in.variables['nodeCoords'][subset_node,]

dimensions = f_in.dims
variables = f_in.variables
global_attributes = f_in.attrs


max_node_dim = len(f_in['maxNodePElement'])

elem_count = len(subset_element)
elem_conn_out = np.empty(shape=[elem_count, max_node_dim])
elem_conn_index = f_in.variables['elementConn'][subset_element,]

for ni in range(elem_count):
for mi in range(max_node_dim):
ndx = int (elem_conn_index[ni,mi])
elem_conn_out[ni,mi] = conn_dict[ndx]


num_elem_conn_out = np.empty(shape=[elem_count,])
num_elem_conn_out[:] = f_in.variables['numElementConn'][subset_element,]

center_coords_out = np.empty(shape=[elem_count,2])
center_coords_out[:,:]=f_in.variables['centerCoords'][subset_element,:]

if 'elementMask' in variables:
elem_mask_out = np.empty(shape=[elem_count,])
elem_mask_out[:]=f_in.variables['elementMask'][subset_element,]

if 'elementArea' in variables:
elem_area_out = np.empty(shape=[elem_count,])
elem_area_out[:]=f_in.variables['elementArea'][subset_element,]

# -- create output dataset
f_out = xr.Dataset()

f_out['nodeCoords'] = xr.DataArray(corner_pairs,
dims=('nodeCount', 'coordDim'),
attrs={'units': 'degrees'})

f_out['elementConn'] = xr.DataArray(elem_conn_out,
dims=('elementCount', 'maxNodePElement'),
attrs={'long_name': 'Node indices that define the element connectivity'})
f_out.elementConn.encoding = {'dtype': np.int32}

f_out['numElementConn'] = xr.DataArray(num_elem_conn_out,
dims=('elementCount'),
attrs={'long_name': 'Number of nodes per element'})

f_out['centerCoords'] = xr.DataArray(center_coords_out,
dims=('elementCount', 'coordDim'),
attrs={'units': 'degrees'})


#-- add mask
if 'elementMask' in variables:
f_out['elementMask'] = xr.DataArray(elem_mask_out,
dims=('elementCount'),
attrs={'units': 'unitless'})
f_out.elementMask.encoding = {'dtype': np.int32}

if 'elementArea' in variables:
f_out['elementArea'] = xr.DataArray(elem_area_out,
dims=('elementCount'),
attrs={'units': 'unitless'})

#-- setting fill values
for var in variables:
if '_FillValue' in f_in[var].encoding:
f_out[var].encoding['_FillValue'] = f_in[var].encoding['_FillValue']
else:
f_out[var].encoding['_FillValue'] = None

#-- add global attributes
for attr in global_attributes:
if attr != 'timeGenerated':
f_out.attrs[attr] = global_attributes[attr]

f_out.attrs = {'title': 'ESMF unstructured grid file for a region',
'created_by': 'subset_data',
'date_created': '{}'.format(datetime.now()),
}

f_out.to_netcdf(mesh_out)
logger.info("Successfully created file (mesh_out) %s", mesh_out)

def write_shell_commands(self, namelist):
"""
writes out xml commands commands to a file (i.e. shell_commands) for single-point runs
"""
# write_to_file surrounds text with newlines
with open(namelist, "w") as nl_file:
self.write_to_file(
"# Change below line if you move the subset data directory", nl_file
)
self.write_to_file(
"./xmlchange {}={}".format(USRDAT_DIR, self.out_dir), nl_file
)
self.write_to_file("./xmlchange ATM_DOMAIN_MESH={}".format(str(self.mesh)), nl_file)
self.write_to_file("./xmlchange LND_DOMAIN_MESH={}".format(str(self.mesh)), nl_file)
Loading