From e3a2a66411dc44de41460671531bd23b8ea9f731 Mon Sep 17 00:00:00 2001 From: Josh Veitch-Michaelis Date: Sat, 12 Aug 2023 22:03:43 +0200 Subject: [PATCH] refactor atlas attribute names --- examples/example_lt_sprat_manual_atlas.py | 5 +- src/rascal/atlas.py | 48 ++++------ test/test_atlas.py | 48 +++++----- test/test_atlas_yaml_config.py | 102 +++------------------- test/test_config.yaml | 11 +++ test/test_effective_pixel.py | 29 ++++-- test/test_fitted_coefficients.py | 13 +-- test/test_hough_transform.py | 7 +- test/test_lt_sprat_manual_atlas.py | 10 ++- test/test_matched_peaks.py | 5 +- test/test_polynomial_fit.py | 19 ++-- test/test_synthetic_calibration.py | 9 +- 12 files changed, 126 insertions(+), 180 deletions(-) create mode 100644 test/test_config.yaml diff --git a/examples/example_lt_sprat_manual_atlas.py b/examples/example_lt_sprat_manual_atlas.py index 73d650d..5c53cd8 100644 --- a/examples/example_lt_sprat_manual_atlas.py +++ b/examples/example_lt_sprat_manual_atlas.py @@ -3,10 +3,11 @@ import numpy as np from astropy.io import fits from matplotlib import pyplot as plt +from scipy.signal import find_peaks + from rascal import util from rascal.atlas import Atlas from rascal.calibrator import Calibrator -from scipy.signal import find_peaks # Load the LT SPRAT data base_dir = os.path.dirname(os.path.abspath(__file__)) @@ -109,7 +110,7 @@ } atlas = Atlas( - line_list="manual", + source="manual", wavelengths=sprat_atlas_lines, min_wavelength=3800.0, max_wavelength=8000.0, diff --git a/src/rascal/atlas.py b/src/rascal/atlas.py index 47ee383..ee692f1 100644 --- a/src/rascal/atlas.py +++ b/src/rascal/atlas.py @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) -class LineSource(Enum): +class NistSource(Enum): NIST_STRONG = auto() NIST_ALL = auto() @@ -76,8 +76,8 @@ class Atlas: min_wavelength: float = MISSING max_wavelength: float = MISSING - elements: Optional[Any] = None - line_list: str = "nist" + element: Optional[Any] = None + source: str = "nist" range_tolerance: float = 0.0 min_intensity: Optional[float] = 10.0 min_distance: float = 10.0 @@ -90,27 +90,24 @@ class Atlas: intensities: Optional[List[float]] = field(default=None) use_accurate_lines: Optional[bool] = True atlas_lines: Optional[List[AtlasLine]] = field(default=None) - nist_source: LineSource = LineSource.NIST_STRONG + nist_source: NistSource = NistSource.NIST_STRONG def __post_init__(self): - if isinstance(self.elements, str): - self.elements = [self.elements] - self.min_atlas_wavelength = self.min_wavelength - self.range_tolerance self.max_atlas_wavelength = self.max_wavelength + self.range_tolerance self.atlas_lines = [] logger.info( - f"Loading lines from {self.line_list} list between {self.min_atlas_wavelength} and {self.max_atlas_wavelength} for elements: {set(self.elements)}" + f"Loading lines from {self.source} between {self.min_atlas_wavelength} and {self.max_atlas_wavelength} for element: {set(self.element)}" ) logger.info( f"Filtering lines by intensity > {self.min_intensity} and separation > {self.min_distance} Å" ) - if self.line_list == "manual": + if self.source == "manual": self.add_manual() - elif self.line_list == "nist": + elif self.source == "nist": self.add_nist() else: raise NotImplementedError @@ -121,24 +118,12 @@ def add_manual(self): if not isinstance(self.wavelengths, list): self.wavelengths = list(self.wavelengths) - # If a single element is provided, assume that - # all lines are from this element - if len(self.elements) == 1: - self.elements = self.elements * len(self.wavelengths) - # Empty intensity if self.intensities is None: self.intensities = [0] * len(self.wavelengths) elif not isinstance(self.intensities, list): self.intensities = list(self.intensities) - assert len(self.elements) == len(self.wavelengths), ValueError( - "Input elements and wavelengths have different length." - ) - assert len(self.elements) == len(self.intensities), ValueError( - "Input elements and intensities have different length." - ) - if self.vacuum: self.wavelengths = vacuum_to_air_wavelength( @@ -148,8 +133,8 @@ def add_manual(self): self.relative_humidity, ) - for element, wavelength, intensity in list( - zip(self.elements, self.wavelengths, self.intensities) + for wavelength, intensity in list( + zip(self.wavelengths, self.intensities) ): if wavelength < (self.min_wavelength - self.range_tolerance): logger.warning( @@ -170,16 +155,15 @@ def add_manual(self): self.atlas_lines.append( AtlasLine( wavelength=wavelength, - element=element, + element=self.element, intensity=intensity, source="user", ) ) def add_nist(self): - assert len(self.elements) == 1 - s = self.elements[0].split(" ") + s = self.element.split(" ") if len(s) == 2: states = s[1] @@ -458,7 +442,7 @@ def wavelengths(self): def nist_files( element: str, states: List[str] = ["I", "II"], - source: LineSource = LineSource.NIST_STRONG, + source: NistSource = NistSource.NIST_STRONG, ) -> List[str]: """ Locate atlas files for a particular element and an optional state. By default, only the I and II @@ -486,9 +470,9 @@ def nist_files( get_ref = lambda path: (import_resources.files(__package__) / path) - if source == LineSource.NIST_ALL: + if source == NistSource.NIST_ALL: root = f"arc_lines/nist_clean_" - elif source == LineSource.NIST_STRONG: + elif source == NistSource.NIST_STRONG: root = f"arc_lines/strong_lines/" else: raise NotImplementedError( @@ -573,7 +557,7 @@ def open_line_list( def nist_lines( element, states=["I", "II"], - source=LineSource.NIST_STRONG, + source=NistSource.NIST_STRONG, only_accurate=True, ) -> List[Dict]: """Load NIST reference lines for the specified element and ionisation states. @@ -596,7 +580,7 @@ def nist_lines( """ files = nist_files(element, states, source) - if (source == LineSource.NIST_STRONG) and only_accurate: + if (source == NistSource.NIST_STRONG) and only_accurate: only_accurate = False logger.debug( "Disabling accurate line filter as NIST strong lines do not have an associated accuracy." diff --git a/test/test_atlas.py b/test/test_atlas.py index 50889c3..60ce27f 100644 --- a/test/test_atlas.py +++ b/test/test_atlas.py @@ -38,8 +38,8 @@ def test_load_nist_single(elements): time_start = time.perf_counter() nist_atlas = Atlas( - elements=[element], - line_list="nist", + element=element, + source="nist", min_wavelength=min_wavelength, max_wavelength=max_wavelength, brightest_n_lines=None, @@ -63,8 +63,8 @@ def test_load_with_min_intensity(): min_intensity = 10 nist_atlas = Atlas( - elements=["Xe"], - line_list="nist", + element="Xe", + source="nist", min_wavelength=4000, max_wavelength=8000, min_intensity=min_intensity, @@ -75,24 +75,26 @@ def test_load_with_min_intensity(): assert i >= min_intensity, "Line intensity is below minimum" +""" +# TODO: Move to Atlas collection def test_load_nist_all(elements): + collection = AtlasCo nist_atlas = Atlas( - elements=elements, + element=elements, line_list="nist", min_wavelength=0, max_wavelength=8000, ) - - assert len(nist_atlas) > 0 +""" def test_check_nonzero_length_common_elements_single(): for element in ["Xe", "Kr", "Ar", "Ne", "Cu"]: nist_atlas = Atlas( - elements=element, - line_list="nist", + element=element, + source="nist", min_wavelength=4000, max_wavelength=8000, ) @@ -102,8 +104,8 @@ def test_check_nonzero_length_common_elements_single(): def test_load_single_line(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=[5.0], min_wavelength=0, max_wavelength=10, @@ -113,8 +115,8 @@ def test_load_single_line(): def test_load_mutliple_lines(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, @@ -124,8 +126,8 @@ def test_load_mutliple_lines(): def test_setting_a_known_pair(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, @@ -139,8 +141,8 @@ def test_setting_a_known_pair(): def test_setting_known_pairs(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, @@ -155,8 +157,8 @@ def test_setting_known_pairs(): @pytest.mark.xfail() def test_setting_a_none_to_known_pairs_expect_fail(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, @@ -169,8 +171,8 @@ def test_setting_a_none_to_known_pairs_expect_fail(): @pytest.mark.xfail() def test_setting_nones_to_known_pairs_expect_fail(): user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, @@ -182,8 +184,8 @@ def test_setting_nones_to_known_pairs_expect_fail(): element_list = ["Hg", "Ar", "Xe", "Kr"] user_atlas = Atlas( - elements="Test", - line_list="manual", + element="Test", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, diff --git a/test/test_atlas_yaml_config.py b/test/test_atlas_yaml_config.py index 99e9a06..3a5cdd3 100644 --- a/test/test_atlas_yaml_config.py +++ b/test/test_atlas_yaml_config.py @@ -5,118 +5,44 @@ import pkg_resources import pytest import yaml -from rascal.atlas import Atlas # Suppress tqdm output from tqdm import tqdm +from rascal.atlas import Atlas + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) base_dir = os.path.dirname(__file__) def test_load_atlas_from_yaml_file(): - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.add_config( - yaml_config=pkg_resources.resource_filename( - "rascal", "../../atlas_yaml_template.yaml" - ) - ) + yaml_config = os.path.join(os.path.dirname(__file__), "test_config.yaml") + + with open(yaml_config, "r") as stream: + config = yaml.safe_load(stream)["atlases"] + + for c in config: + _ = Atlas(**c) def test_load_atlas_from_pyyaml_object(): - with open( - pkg_resources.resource_filename( - "rascal", "../../atlas_yaml_template.yaml" - ), - "r", - ) as stream: - yaml_object = yaml.safe_load(stream) - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.add_config(config=yaml_object, y_type="object") + pass def test_load_atlas_config_user_linelist(): - with open( - pkg_resources.resource_filename( - "rascal", "../../atlas_yaml_template.yaml" - ), - "r", - ) as stream: - yaml_object = yaml.safe_load(stream) - yaml_object["linelist"] = "user" - yaml_object["element_list"] = np.array(["Xe"] * 10) - yaml_object["wavelength_list"] = np.arange(10) - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.load_config(config=yaml_object) + pass @pytest.mark.xfail() def test_load_atlas_config_expect_fail_ytype(): - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.load_config(np.arange(100), y_type="bla") + pass @pytest.mark.xfail() def test_load_atlas_config_expect_fail_linelist_type(): - with open( - pkg_resources.resource_filename( - "rascal", "../../atlas_yaml_template.yaml" - ), - "r", - ) as stream: - yaml_object = yaml.safe_load(stream) - yaml_object["linelist"] = "blabla" - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.load_config(yaml_config=yaml_object, y_type="object") + pass def test_save_atlas_config(): - with open( - pkg_resources.resource_filename( - "rascal", "../../atlas_yaml_template.yaml" - ), - "r", - ) as stream: - yaml_object = yaml.safe_load(stream) - atlas = Atlas( - elements="Test", - line_list="manual", - wavelengths=np.arange(10), - min_wavelength=0, - max_wavelength=10, - ) - atlas.load_config(yaml_config=yaml_object, y_type="object") - atlas.save_config( - os.path.join(base_dir, "test_output", "test_atlas_config.yaml") - ) + pass diff --git a/test/test_config.yaml b/test/test_config.yaml new file mode 100644 index 0000000..a0f2111 --- /dev/null +++ b/test/test_config.yaml @@ -0,0 +1,11 @@ +atlases: + - + element: ['Hg'] + min_wavelength: 4000 + max_wavelength: 7000 + min_intensity: 50 + - + element: ['Ar'] + min_wavelength: 7000 + max_wavelength: 9000 + min_intensity: 50 \ No newline at end of file diff --git a/test/test_effective_pixel.py b/test/test_effective_pixel.py index e687e1b..f12f8b0 100644 --- a/test/test_effective_pixel.py +++ b/test/test_effective_pixel.py @@ -3,13 +3,14 @@ import numpy as np import pytest from matplotlib.font_manager import X11FontDirectories -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator -from rascal.synthetic import SyntheticSpectrum # Suppress tqdm output from tqdm import tqdm +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator +from rascal.synthetic import SyntheticSpectrum + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) # Create a test spectrum with a simple linear relationship @@ -48,22 +49,36 @@ } +def run_calibration(config): + + a = AtlasCollection(config) + c = Calibrator(atlas_lines=a, config=config) + c.fit() + + return results + + def test_effective_pixel_not_affecting_fit_int_peaks(): # Set up the calibrator with the pixel values of our # wavelengths a = Atlas( elements="Test", - line_list="manual", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, ) + + _config = config.copy() + _config["ransac"]["max_tries"] = 2000 + _config["ransac"]["degree"] = 3 + c = Calibrator( - peaks=peaks.astype("int"), atlas_lines=a.atlas_lines, config=config + peaks=peaks.astype("int"), atlas_lines=a.atlas_lines, config=_config ) # And let's try and fit... - res = c.fit(max_tries=2000, fit_deg=3) + res = c.fit() assert res @@ -91,7 +106,7 @@ def test_effective_pixel_not_affecting_fit_perfect_peaks(): # wavelengths a = Atlas( elements="Test", - line_list="manual", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, diff --git a/test/test_fitted_coefficients.py b/test/test_fitted_coefficients.py index 888c32c..ac5d6e0 100644 --- a/test/test_fitted_coefficients.py +++ b/test/test_fitted_coefficients.py @@ -3,15 +3,16 @@ import numpy as np from astropy.io import fits -from rascal import util -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator from scipy import interpolate from scipy.signal import find_peaks # Suppress tqdm output from tqdm import tqdm +from rascal import util +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) base_dir = os.path.dirname(os.path.abspath(__file__)) @@ -174,7 +175,7 @@ def test_gmos_fit(): element = ["CuAr"] * len(gmos_atlas_lines) atlas = Atlas( - line_list="manual", + source="manual", wavelengths=gmos_atlas_lines, min_wavelength=5000.0, max_wavelength=9500.0, @@ -231,7 +232,7 @@ def test_osiris_fit(): element = ["HgAr"] * len(osiris_atlas_lines) atlas = Atlas( - line_list="manual", + source="manual", wavelengths=osiris_atlas_lines, min_wavelength=3500.0, max_wavelength=8000.0, @@ -368,7 +369,7 @@ def test_sprat_fit(): element = ["Xe"] * len(sprat_atlas_lines) atlas = Atlas( - line_list="manual", + source="manual", wavelengths=sprat_atlas_lines, min_wavelength=3500.0, max_wavelength=8000.0, diff --git a/test/test_hough_transform.py b/test/test_hough_transform.py index 1db9c5c..98419c5 100644 --- a/test/test_hough_transform.py +++ b/test/test_hough_transform.py @@ -3,12 +3,13 @@ import numpy as np import pytest -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator, HoughTransform # Suppress tqdm output from tqdm import tqdm +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator, HoughTransform + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) HERE = os.path.dirname(os.path.realpath(__file__)) @@ -169,7 +170,7 @@ def test_extending_ht_expect_fail(): def test_loading_ht_into_calibrator(): user_atlas = Atlas( elements="Test", - line_list="manual", + source="manual", wavelengths=np.arange(10), min_wavelength=0, max_wavelength=10, diff --git a/test/test_lt_sprat_manual_atlas.py b/test/test_lt_sprat_manual_atlas.py index 7db4143..9423db5 100644 --- a/test/test_lt_sprat_manual_atlas.py +++ b/test/test_lt_sprat_manual_atlas.py @@ -4,14 +4,15 @@ import numpy as np import pytest from astropy.io import fits -from rascal import util -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator from scipy.signal import find_peaks # Suppress tqdm output from tqdm import tqdm +from rascal import util +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) HERE = os.path.dirname(os.path.realpath(__file__)) @@ -82,7 +83,7 @@ element = ["Xe"] * len(sprat_atlas_lines) user_atlas = Atlas( - line_list="manual", + source="manual", wavelengths=sprat_atlas_lines, min_wavelength=3800.0, max_wavelength=8200.0, @@ -115,6 +116,7 @@ "top_n_candidate": 5, "filter_close": True, }, + "atlases": {...}, } diff --git a/test/test_matched_peaks.py b/test/test_matched_peaks.py index 6b23917..59fdc4e 100644 --- a/test/test_matched_peaks.py +++ b/test/test_matched_peaks.py @@ -3,10 +3,11 @@ import numpy as np import pytest from astropy.io import fits +from scipy.signal import find_peaks + from rascal import util from rascal.atlas import Atlas from rascal.calibrator import Calibrator -from scipy.signal import find_peaks @pytest.fixture(scope="session") @@ -111,7 +112,7 @@ def calibrator(base_dir): element = ["Xe"] * len(wavelengths) user_atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths, elements=element, min_wavelength=3500.0, diff --git a/test/test_polynomial_fit.py b/test/test_polynomial_fit.py index 59323f6..5da458e 100644 --- a/test/test_polynomial_fit.py +++ b/test/test_polynomial_fit.py @@ -1,12 +1,13 @@ from functools import partialmethod import numpy as np -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator # Suppress tqdm output from tqdm import tqdm +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) np.random.seed(0) @@ -27,7 +28,7 @@ def test_linear_fit(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_linear, min_wavelength=3500.0, max_wavelength=8000.0, @@ -68,7 +69,7 @@ def test_linear_fit(): def test_manual_refit(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_linear, min_wavelength=3500.0, max_wavelength=8000.0, @@ -106,7 +107,7 @@ def test_manual_refit(): def test_manual_refit_remove_points(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_linear, min_wavelength=3500.0, max_wavelength=8000.0, @@ -147,7 +148,7 @@ def test_manual_refit_remove_points(): def test_manual_refit_add_points(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_linear, min_wavelength=3500.0, max_wavelength=8000.0, @@ -187,7 +188,7 @@ def test_manual_refit_add_points(): def test_quadratic_fit(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_quadratic, min_wavelength=3500.0, max_wavelength=8000.0, @@ -231,7 +232,7 @@ def test_quadratic_fit(): def test_quadratic_fit_legendre(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_quadratic, min_wavelength=3500.0, max_wavelength=8000.0, @@ -276,7 +277,7 @@ def test_quadratic_fit_legendre(): def test_quadratic_fit_chebyshev(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=wavelengths_quadratic, min_wavelength=3500.0, max_wavelength=8000.0, diff --git a/test/test_synthetic_calibration.py b/test/test_synthetic_calibration.py index 8dd36ea..96ae8e8 100644 --- a/test/test_synthetic_calibration.py +++ b/test/test_synthetic_calibration.py @@ -2,13 +2,14 @@ from functools import partialmethod import numpy as np -from rascal.atlas import Atlas -from rascal.calibrator import Calibrator -from rascal.synthetic import SyntheticSpectrum # Suppress tqdm output from tqdm import tqdm +from rascal.atlas import Atlas +from rascal.calibrator import Calibrator +from rascal.synthetic import SyntheticSpectrum + logger = logging.getLogger(__name__) tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) @@ -37,7 +38,7 @@ def test_default(): atlas = Atlas( - line_list="manual", + source="manual", wavelengths=waves, min_wavelength=min_wavelength, max_wavelength=max_wavelength,