Skip to content

Commit

Permalink
refactor config
Browse files Browse the repository at this point in the history
  • Loading branch information
jveitchmichaelis committed Aug 12, 2023
1 parent 239ad95 commit d55d363
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 34 deletions.
62 changes: 32 additions & 30 deletions src/rascal/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ def _set_data(
for line in self.atlas_lines:
if (
line.wavelength < 0
or line.wavelength < self.config.data.detector_min_wave
or line.wavelength > self.config.data.detector_max_wave
or line.wavelength < self.config.detector.detector_min_wave
or line.wavelength > self.config.detector.detector_max_wave
):
self.logger.warning(
f"The provided peak {line.wavelength} is outside the given range of the detector."
Expand All @@ -236,19 +236,21 @@ def _set_num_pix(self):

# We're given a spectrum
if self.spectrum is not None:
if self.config.data.num_pix is None:
self.config.data.num_pix = len(self.spectrum)
if self.config.detector.num_pix is None:
self.config.detector.num_pix = len(self.spectrum)
else:
# Assume that if number of pixels *and* spectrum
# are provided, they should be the same length
assert (
len(self.spectrum) == self.config.data.num_pix
len(self.spectrum) == self.config.detector.num_pix
), "The length of the provided spectrum should match the num_pix"

# No spectrum provided and num_pix not provided
elif self.config.data.num_pix is None:
elif self.config.detector.num_pix is None:
if len(self.peaks) > 0:
self.config.data.num_pix = int(round(1.1 * max(self.peaks)))
self.config.detector.num_pix = int(
round(1.1 * max(self.peaks))
)
self.logger.warning(
"Neither num_pix nor spectrum is given, "
"it uses 1.1 times max(peaks) as the "
Expand All @@ -262,23 +264,23 @@ def _set_num_pix(self):
# Only user-provided num pixels, so just check that it's
# greater than the peak with the highest index
else:
if self.config.data.num_pix <= max(self.peaks):
if self.config.detector.num_pix <= max(self.peaks):
self.logger.error(
f"Maximum pixel {self.config.data.num_pix} is too low, max peak provided is {max(self.peaks)}"
f"Maximum pixel {self.config.detector.num_pix} is too low, max peak provided is {max(self.peaks)}"
)
raise ValueError

self.logger.info(f"num_pix is set to {self.config.data.num_pix}.")
self.logger.info(f"num_pix is set to {self.config.detector.num_pix}.")

# Default 1:1 mapping between pixel location and effective pixel location
if self.config.data.contiguous_range is None:
self.contiguous_pixel = list(range(self.config.data.num_pix))
if self.config.detector.contiguous_range is None:
self.contiguous_pixel = list(range(self.config.detector.num_pix))

# Otherwise assert the effective pixel array is the same as the number
# of pixels in the spectrum
else:
contiguous_ranges = np.array(
self.config.data.contiguous_range
self.config.detector.contiguous_range
).flatten()
n_ranges = int(contiguous_ranges.size / 2)
assert n_ranges % 1 == 0
Expand All @@ -289,11 +291,11 @@ def _set_num_pix(self):
self.contiguous_pixel.extend(list(np.arange(x0, x1)))

assert (
len(self.contiguous_pixel) == self.config.data.num_pix
), f"The length of the effective pixel array ({len(self.contiguous_pixel)}) should match num_pix ({self.config.data.num_pix})"
len(self.contiguous_pixel) == self.config.detector.num_pix
), f"The length of the effective pixel array ({len(self.contiguous_pixel)}) should match num_pix ({self.config.detector.num_pix})"

self.pixel_mapping_itp = itp.interp1d(
np.arange(self.config.data.num_pix), self.contiguous_pixel
np.arange(self.config.detector.num_pix), self.contiguous_pixel
)
self.peaks_effective = self.pixel_mapping_itp(np.array(self.peaks))

Expand Down Expand Up @@ -424,13 +426,13 @@ def _set_hough_properties(self):

# Start wavelength in the spectrum, +/- some tolerance
self.config.hough.min_intercept = float(
self.config.data.detector_min_wave
self.config.detector.detector_min_wave
- self.config.hough.range_tolerance
)
self.min_intercept = self.config.hough.min_intercept # TODO fix this

self.config.hough.max_intercept = float(
self.config.data.detector_min_wave
self.config.detector.detector_min_wave
+ self.config.hough.range_tolerance
)
self.max_intercept = self.config.hough.max_intercept # TODO fix this
Expand All @@ -439,7 +441,7 @@ def _set_hough_properties(self):
self.config.hough.min_slope = float(
(
(
self.config.data.detector_max_wave
self.config.detector.detector_max_wave
- self.config.hough.range_tolerance
- self.config.hough.linearity_tolerance
)
Expand All @@ -455,7 +457,7 @@ def _set_hough_properties(self):
self.config.hough.max_slope = float(
(
(
self.config.data.detector_max_wave
self.config.detector.detector_max_wave
+ self.config.hough.range_tolerance
+ self.config.hough.linearity_tolerance
)
Expand All @@ -478,13 +480,13 @@ def _set_hough_properties(self):
self.logger.info(f"Minimum slope: {self.config.hough.min_slope}")
self.logger.info(f"Maximum slope: {self.config.hough.max_slope}")
self.logger.info(
f"Minimum detector wavelength: {self.config.data.detector_min_wave}"
f"Minimum detector wavelength: {self.config.detector.detector_min_wave}"
)
self.logger.info(
f"Maximum detector wavelength: {self.config.data.detector_max_wave}"
f"Maximum detector wavelength: {self.config.detector.detector_max_wave}"
)
self.logger.info(
f"Detector range tolerance: {self.config.data.detector_edge_tolerance}"
f"Detector range tolerance: {self.config.detector.detector_edge_tolerance}"
)

def _merge_candidates(self, candidates: Union[list, np.ndarray]):
Expand Down Expand Up @@ -1172,11 +1174,11 @@ def _fit_valid(self, result: SolveResult):
min_wavelength_px = self.polyval(0, result.fit_coeffs)

if min_wavelength_px < (
self.config.data.detector_min_wave
- self.config.data.detector_edge_tolerance
self.config.detector.detector_min_wave
- self.config.detector.detector_edge_tolerance
) or min_wavelength_px > (
self.config.data.detector_min_wave
+ self.config.data.detector_edge_tolerance
self.config.detector.detector_min_wave
+ self.config.detector.detector_edge_tolerance
):
self.logger.debug(
"Lower wavelength of fit too small, "
Expand All @@ -1196,10 +1198,10 @@ def _fit_valid(self, result: SolveResult):
)

if max_wavelength_px > (
self.config.data.detector_max_wave
self.config.detector.detector_max_wave
+ self.detector_edge_tolerance
) or max_wavelength_px < (
self.config.data.detector_max_wave
self.config.detector.detector_max_wave
- self.detector_edge_tolerance
):
self.logger.debug(
Expand Down Expand Up @@ -1486,7 +1488,7 @@ def summary(self, return_string: bool = False):

output += (
"Calculated detector range: "
+ f"Start: {self.polyval(0, self.fit_coeff):1.6}, End: {self.polyval(self.config.data.num_pix, self.fit_coeff):1.6}{os.linesep}"
+ f"Start: {self.polyval(0, self.fit_coeff):1.6}, End: {self.polyval(self.config.detector.num_pix, self.fit_coeff):1.6}{os.linesep}"
)

output += "RMS of the best fit solution: {self.rms}{os.linesep}"
Expand Down
11 changes: 7 additions & 4 deletions src/rascal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@

@dataclass
class DataConfig:
filename: str = ""
num_pix: Optional[int] = None
contiguous_range: Optional[List[float]] = field(default=None)
peaks: Optional[List[float]] = field(default=None)
spectrum: Optional[List[float]] = field(default=None)


@dataclass
class DetectorConfig:
contiguous_range: Optional[List[float]] = field(default=None)
num_pix: Optional[int] = None
detector_min_wave: float = 3000.0
detector_max_wave: float = 9000.0
detector_edge_tolerance: float = 200.0
Expand Down Expand Up @@ -131,7 +134,7 @@ class CalibratorConfig:
log_level: str = "info"
hide_progress: bool = False
atlases: List[Atlas] = MISSING

data: DataConfig = field(default_factory=DataConfig)
detector: DetectorConfig = field(default_factory=DetectorConfig)
hough: HoughConfig = field(default_factory=HoughConfig)
ransac: RansacConfig = field(default_factory=RansacConfig)

0 comments on commit d55d363

Please sign in to comment.