Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve unit tests for model_parameters; add read parameters from file #1081

Merged
merged 4 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion simtools/applications/derive_psf_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

Example
-------
LST-1 Prod5
LSTN-01 Prod5

Runtime < 3 min.

Expand Down
14 changes: 2 additions & 12 deletions simtools/applications/validate_cumulative_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import yaml

import simtools.utils.general as gen
from simtools.configuration import configurator
Expand Down Expand Up @@ -111,11 +110,6 @@ def _parse(label):
help="Data file name with the measured PSF vs radius [cm]",
type=str,
)
config.parser.add_argument(
"--mc_parameter_file",
help="Yaml file with the model parameters to be replaced",
type=str,
)
return config.initialize(db_config=True, simulation_model="telescope")


Expand Down Expand Up @@ -149,12 +143,8 @@ def main(): # noqa: D103
model_version=args_dict["model_version"],
label=label,
)

# New parameters
if args_dict.get("pars", None):
with open(args_dict["pars"], encoding="utf-8") as file:
new_pars = yaml.safe_load(file)
tel_model.change_multiple_parameters(**new_pars)
if args_dict.get("telescope_model_file"):
tel_model.change_multiple_parameters_from_file(args_dict["telescope_model_file"])

ray = RayTracing.from_kwargs(
telescope_model=tel_model,
Expand Down
9 changes: 9 additions & 0 deletions simtools/configuration/commandline_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ def initialize_simulation_model_arguments(self, model_options):
help="telescope model name (e.g., LSTN-01, SSTS-design, ...)",
type=self.telescope,
)
_job_group.add_argument(
"--telescope_model_file",
help=(
"File with changes to telescope model "
" (yaml format; experimental with insufficient validation steps)."
),
type=Path,
required=False,
)

if "layout" in model_options or "layout_file" in model_options:
_job_group = self._add_model_option_layout(_job_group, "layout_file" in model_options)
Expand Down
28 changes: 24 additions & 4 deletions simtools/model/model_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def get_parameter_type(self, par_name):
"""
parameter_dict = self._get_parameter_dict(par_name)
try:
return parameter_dict.get("type")
return parameter_dict["type"]
except KeyError:
self._logger.debug(f"Parameter {par_name} does not have a type")
self._logger.debug(f"Parameter {par_name} does not have a type.")
return None

def get_parameter_file_flag(self, par_name):
Expand All @@ -228,8 +228,7 @@ def get_parameter_file_flag(self, par_name):
"""
parameter_dict = self._get_parameter_dict(par_name)
try:
if parameter_dict.get("file"):
return True
return parameter_dict["file"]
except KeyError:
self._logger.debug(f"Parameter {par_name} does not have a file associated with it.")
return False
Expand Down Expand Up @@ -447,6 +446,27 @@ def change_parameter(self, par_name, value):

self._is_config_file_up_to_date = False

def change_multiple_parameters_from_file(self, file_name):
"""
Change values of multiple existing parameters in the model from a file.

This function does not modify the DB, it affects only the current instance.
Experimental feature: insufficient validation of parameters.

Parameters
----------
file_name: str
File containing the parameters to be changed.
"""
self._logger.warning(
"Changing multiple parameters from file is an experimental feature."
"Insufficient validation of parameters."
)
self._logger.debug(f"Changing parameters from file {file_name}")
self.change_multiple_parameters(
**gen.collect_data_from_file_or_dict(file_name=file_name, in_dict=None)
)

def change_multiple_parameters(self, **kwargs):
"""
Change the value of multiple existing parameters in the model.
Expand Down
141 changes: 141 additions & 0 deletions tests/unit_tests/model/test_model_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,40 @@

import copy
import logging
from pathlib import Path

import pytest
from astropy import units as u

import simtools.utils.general as gen
from simtools.db.db_handler import DatabaseHandler
from simtools.model.model_parameter import InvalidModelParameterError
from simtools.model.telescope_model import TelescopeModel

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


def test_get_parameter_type(telescope_model_lst, caplog):

assert telescope_model_lst.get_parameter_type("num_gains") == "int64"
telescope_model_copy = copy.deepcopy(telescope_model_lst)
telescope_model_copy._parameters["num_gains"].pop("type")
with caplog.at_level(logging.DEBUG):
assert telescope_model_copy.get_parameter_type("num_gains") is None
assert "Parameter num_gains does not have a type." in caplog.text


def test_get_parameter_file_flag(telescope_model_lst, caplog):

assert telescope_model_lst.get_parameter_file_flag("num_gains") is False
telescope_model_copy = copy.deepcopy(telescope_model_lst)
telescope_model_copy._parameters["num_gains"].pop("file")
with caplog.at_level(logging.DEBUG):
assert telescope_model_copy.get_parameter_file_flag("num_gains") is False
assert "Parameter num_gains does not have a file associated with it." in caplog.text


def test_get_parameter_dict(telescope_model_lst):
tel_model = telescope_model_lst
assert isinstance(tel_model._get_parameter_dict("num_gains"), dict)
Expand Down Expand Up @@ -94,6 +116,57 @@ def test_handling_parameters(telescope_model_lst):
tel_model._get_parameter_dict("bla_bla")


def test_print_parameters(telescope_model_lst, capsys):
tel_model = telescope_model_lst
tel_model.print_parameters()
assert "quantum_efficiency" in capsys.readouterr().out


def test_set_config_file_directory_and_name(telescope_model_lst, caplog):
telescope_copy = copy.deepcopy(telescope_model_lst)
telescope_copy.name = None
with caplog.at_level(logging.DEBUG):
telescope_copy._set_config_file_directory_and_name()
assert "Config file path" not in caplog.text


def test_get_simulation_software_parameters(telescope_model_lst):
assert isinstance(telescope_model_lst.get_simulation_software_parameters("corsika"), dict)


def test_load_simulation_software_parameter(telescope_model_lst, caplog):
telescope_copy = copy.deepcopy(telescope_model_lst)
telescope_copy._simulation_config_parameters = {"not_corsika": {}, "not_simtel": {}}
with caplog.at_level(logging.WARNING):
telescope_copy._load_simulation_software_parameter()
assert "No not_corsika parameters found for North" in caplog.text


def test_load_parameters_from_db(telescope_model_lst, mocker):
telescope_copy = copy.deepcopy(telescope_model_lst)
mock_db = mocker.patch.object(DatabaseHandler, "get_model_parameters")
telescope_copy._load_parameters_from_db()
mock_db.assert_called_once()

telescope_copy.db = None
telescope_copy._load_parameters_from_db()
not mock_db.assert_called_once()


def test_extra_labels(telescope_model_lst):
telescope_copy = copy.deepcopy(telescope_model_lst)
assert telescope_copy._extra_label is None
assert telescope_copy.extra_label == ""

telescope_copy.set_extra_label("test")
assert telescope_copy._extra_label == "test"
assert telescope_copy.extra_label == "test"


def test_get_simtel_parameters(telescope_model_lst):
assert isinstance(telescope_model_lst.get_simtel_parameters(), dict)


def test_change_parameter(telescope_model_lst):
tel_model = telescope_model_lst

Expand Down Expand Up @@ -122,6 +195,27 @@ def test_change_parameter(telescope_model_lst):
logger.info("Testing changing mirror_focal_length to a nonsense string")
tel_model.change_parameter("mirror_focal_length", "bla_bla")

with pytest.raises(InvalidModelParameterError, match="Parameter bla_bla not in the model"):
tel_model.change_parameter("bla_bla", 9999.9)


def test_change_multiple_parameters_from_file(telescope_model_lst, mocker):
telescope_copy = copy.deepcopy(telescope_model_lst)
mocker_gen = mocker.patch(
"simtools.utils.general.collect_data_from_file_or_dict", return_value={}
)
telescope_copy.change_multiple_parameters_from_file(file_name="test_file")
mocker_gen.assert_called_once()


def test_change_multiple_parameters(telescope_model_lst, mocker):
telescope_copy = copy.deepcopy(telescope_model_lst)
mock_change = mocker.patch.object(TelescopeModel, "change_parameter")
telescope_copy.change_multiple_parameters(**{"camera_pixels": 9999, "mirror_focal_length": 55})
mock_change.assert_any_call("camera_pixels", 9999)
mock_change.assert_any_call("mirror_focal_length", 55)
assert not telescope_copy._is_config_file_up_to_date


def test_flen_type(telescope_model_lst):
tel_model = telescope_model_lst
Expand Down Expand Up @@ -199,3 +293,50 @@ def test_export_derived_files(io_handler, db_config):
assert tel_model.config_file_directory.joinpath(
"ray-tracing-North-LST-1-d10.0-za20.0_validate_optics.ecsv"
).exists()


def test_export_parameter_file(telescope_model_lst, mocker):
parameter = "array_coordinates_UTM"
file_path = "tests/resources/telescope_positions-North-ground.ecsv"
telescope_copy = copy.deepcopy(telescope_model_lst)
mock_copy = mocker.patch("shutil.copy")
telescope_copy.export_parameter_file(par_name=parameter, file_path=file_path)
mock_copy.assert_called_once_with(file_path, telescope_copy.config_file_directory)


def test_export_model_files(telescope_model_lst, mocker):
telescope_copy = copy.deepcopy(telescope_model_lst)
mock_db = mocker.patch.object(DatabaseHandler, "export_model_files")
telescope_copy.export_model_files()
assert telescope_copy._is_exported_model_files_up_to_date
mock_db.assert_called_once()

telescope_copy._added_parameter_files = ["test_file"]
with pytest.raises(KeyError):
telescope_copy.export_model_files()


def test_config_file_path(telescope_model_lst, mocker):
telescope_copy = copy.deepcopy(telescope_model_lst)
telescope_copy._config_file_path = None
mock_config = mocker.patch.object(TelescopeModel, "_set_config_file_directory_and_name")
telescope_copy.config_file_path
mock_config.assert_called_once()

telescope_copy._config_file_path = Path("test_path")
assert telescope_copy.config_file_path == Path("test_path")
not mock_config.assert_called_once()


def test_get_config_file(telescope_model_lst, mocker):
assert isinstance(telescope_model_lst.get_config_file(), Path)

telescope_copy = copy.deepcopy(telescope_model_lst)
telescope_copy._is_config_file_up_to_date = False
mock_export = mocker.patch.object(TelescopeModel, "export_config_file")
telescope_copy.get_config_file()
mock_export.assert_called_once()

telescope_copy._is_config_file_up_to_date = False
telescope_copy.get_config_file(no_export=True)
not mock_export.assert_called_once()