Skip to content

Commit

Permalink
refactor atlas attribute names
Browse files Browse the repository at this point in the history
  • Loading branch information
jveitchmichaelis committed Aug 12, 2023
1 parent 1b51b57 commit e3a2a66
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 180 deletions.
5 changes: 3 additions & 2 deletions examples/example_lt_sprat_manual_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
from astropy.io import fits
from matplotlib import pyplot as plt
from scipy.signal import find_peaks

from rascal import util
from rascal.atlas import Atlas
from rascal.calibrator import Calibrator
from scipy.signal import find_peaks

# Load the LT SPRAT data
base_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -109,7 +110,7 @@
}

atlas = Atlas(
line_list="manual",
source="manual",
wavelengths=sprat_atlas_lines,
min_wavelength=3800.0,
max_wavelength=8000.0,
Expand Down
48 changes: 16 additions & 32 deletions src/rascal/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
logger = logging.getLogger(__name__)


class LineSource(Enum):
class NistSource(Enum):
NIST_STRONG = auto()
NIST_ALL = auto()

Expand Down Expand Up @@ -76,8 +76,8 @@ class Atlas:

min_wavelength: float = MISSING
max_wavelength: float = MISSING
elements: Optional[Any] = None
line_list: str = "nist"
element: Optional[Any] = None
source: str = "nist"
range_tolerance: float = 0.0
min_intensity: Optional[float] = 10.0
min_distance: float = 10.0
Expand All @@ -90,27 +90,24 @@ class Atlas:
intensities: Optional[List[float]] = field(default=None)
use_accurate_lines: Optional[bool] = True
atlas_lines: Optional[List[AtlasLine]] = field(default=None)
nist_source: LineSource = LineSource.NIST_STRONG
nist_source: NistSource = NistSource.NIST_STRONG

def __post_init__(self):

if isinstance(self.elements, str):
self.elements = [self.elements]

self.min_atlas_wavelength = self.min_wavelength - self.range_tolerance
self.max_atlas_wavelength = self.max_wavelength + self.range_tolerance
self.atlas_lines = []

logger.info(
f"Loading lines from {self.line_list} list between {self.min_atlas_wavelength} and {self.max_atlas_wavelength} for elements: {set(self.elements)}"
f"Loading lines from {self.source} between {self.min_atlas_wavelength} and {self.max_atlas_wavelength} for element: {set(self.element)}"
)
logger.info(
f"Filtering lines by intensity > {self.min_intensity} and separation > {self.min_distance} Å"
)

if self.line_list == "manual":
if self.source == "manual":
self.add_manual()
elif self.line_list == "nist":
elif self.source == "nist":
self.add_nist()
else:
raise NotImplementedError
Expand All @@ -121,24 +118,12 @@ def add_manual(self):
if not isinstance(self.wavelengths, list):
self.wavelengths = list(self.wavelengths)

# If a single element is provided, assume that
# all lines are from this element
if len(self.elements) == 1:
self.elements = self.elements * len(self.wavelengths)

# Empty intensity
if self.intensities is None:
self.intensities = [0] * len(self.wavelengths)
elif not isinstance(self.intensities, list):
self.intensities = list(self.intensities)

assert len(self.elements) == len(self.wavelengths), ValueError(
"Input elements and wavelengths have different length."
)
assert len(self.elements) == len(self.intensities), ValueError(
"Input elements and intensities have different length."
)

if self.vacuum:

self.wavelengths = vacuum_to_air_wavelength(
Expand All @@ -148,8 +133,8 @@ def add_manual(self):
self.relative_humidity,
)

for element, wavelength, intensity in list(
zip(self.elements, self.wavelengths, self.intensities)
for wavelength, intensity in list(
zip(self.wavelengths, self.intensities)
):
if wavelength < (self.min_wavelength - self.range_tolerance):
logger.warning(
Expand All @@ -170,16 +155,15 @@ def add_manual(self):
self.atlas_lines.append(
AtlasLine(
wavelength=wavelength,
element=element,
element=self.element,
intensity=intensity,
source="user",
)
)

def add_nist(self):
assert len(self.elements) == 1

s = self.elements[0].split(" ")
s = self.element.split(" ")

if len(s) == 2:
states = s[1]
Expand Down Expand Up @@ -458,7 +442,7 @@ def wavelengths(self):
def nist_files(
element: str,
states: List[str] = ["I", "II"],
source: LineSource = LineSource.NIST_STRONG,
source: NistSource = NistSource.NIST_STRONG,
) -> List[str]:
"""
Locate atlas files for a particular element and an optional state. By default, only the I and II
Expand Down Expand Up @@ -486,9 +470,9 @@ def nist_files(

get_ref = lambda path: (import_resources.files(__package__) / path)

if source == LineSource.NIST_ALL:
if source == NistSource.NIST_ALL:
root = f"arc_lines/nist_clean_"
elif source == LineSource.NIST_STRONG:
elif source == NistSource.NIST_STRONG:
root = f"arc_lines/strong_lines/"
else:
raise NotImplementedError(
Expand Down Expand Up @@ -573,7 +557,7 @@ def open_line_list(
def nist_lines(
element,
states=["I", "II"],
source=LineSource.NIST_STRONG,
source=NistSource.NIST_STRONG,
only_accurate=True,
) -> List[Dict]:
"""Load NIST reference lines for the specified element and ionisation states.
Expand All @@ -596,7 +580,7 @@ def nist_lines(
"""
files = nist_files(element, states, source)

if (source == LineSource.NIST_STRONG) and only_accurate:
if (source == NistSource.NIST_STRONG) and only_accurate:
only_accurate = False
logger.debug(
"Disabling accurate line filter as NIST strong lines do not have an associated accuracy."
Expand Down
48 changes: 25 additions & 23 deletions test/test_atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_load_nist_single(elements):
time_start = time.perf_counter()

nist_atlas = Atlas(
elements=[element],
line_list="nist",
element=element,
source="nist",
min_wavelength=min_wavelength,
max_wavelength=max_wavelength,
brightest_n_lines=None,
Expand All @@ -63,8 +63,8 @@ def test_load_with_min_intensity():
min_intensity = 10

nist_atlas = Atlas(
elements=["Xe"],
line_list="nist",
element="Xe",
source="nist",
min_wavelength=4000,
max_wavelength=8000,
min_intensity=min_intensity,
Expand All @@ -75,24 +75,26 @@ def test_load_with_min_intensity():
assert i >= min_intensity, "Line intensity is below minimum"


"""
# TODO: Move to Atlas collection
def test_load_nist_all(elements):
collection = AtlasCo
nist_atlas = Atlas(
elements=elements,
element=elements,
line_list="nist",
min_wavelength=0,
max_wavelength=8000,
)

assert len(nist_atlas) > 0
"""


def test_check_nonzero_length_common_elements_single():

for element in ["Xe", "Kr", "Ar", "Ne", "Cu"]:
nist_atlas = Atlas(
elements=element,
line_list="nist",
element=element,
source="nist",
min_wavelength=4000,
max_wavelength=8000,
)
Expand All @@ -102,8 +104,8 @@ def test_check_nonzero_length_common_elements_single():

def test_load_single_line():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=[5.0],
min_wavelength=0,
max_wavelength=10,
Expand All @@ -113,8 +115,8 @@ def test_load_single_line():

def test_load_mutliple_lines():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand All @@ -124,8 +126,8 @@ def test_load_mutliple_lines():

def test_setting_a_known_pair():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand All @@ -139,8 +141,8 @@ def test_setting_a_known_pair():

def test_setting_known_pairs():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand All @@ -155,8 +157,8 @@ def test_setting_known_pairs():
@pytest.mark.xfail()
def test_setting_a_none_to_known_pairs_expect_fail():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand All @@ -169,8 +171,8 @@ def test_setting_a_none_to_known_pairs_expect_fail():
@pytest.mark.xfail()
def test_setting_nones_to_known_pairs_expect_fail():
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand All @@ -182,8 +184,8 @@ def test_setting_nones_to_known_pairs_expect_fail():

element_list = ["Hg", "Ar", "Xe", "Kr"]
user_atlas = Atlas(
elements="Test",
line_list="manual",
element="Test",
source="manual",
wavelengths=np.arange(10),
min_wavelength=0,
max_wavelength=10,
Expand Down
Loading

0 comments on commit e3a2a66

Please sign in to comment.