diff --git a/tests/test_astra_import.py b/tests/test_astra_import.py index fd688145..2f1547c7 100644 --- a/tests/test_astra_import.py +++ b/tests/test_astra_import.py @@ -1,4 +1,5 @@ import numpy as np +import torch import cheetah @@ -40,3 +41,21 @@ def test_astra_to_particle_beam(): assert np.allclose(beam.sigma_p.cpu().numpy(), 0.0022804534528404474) assert np.allclose(beam.energy.cpu().numpy(), 107315902.44394557) assert np.allclose(beam.total_charge.cpu().numpy(), 5.000000000010205e-13) + + +def test_astra_to_parameter_beam_dtypes(): + """Test that Astra beams are correctly loaded into particle beams.""" + beam = cheetah.ParameterBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001") + + assert beam.mu_x.dtype == torch.float32 + assert beam.mu_xp.dtype == torch.float32 + assert beam.mu_y.dtype == torch.float32 + assert beam.mu_yp.dtype == torch.float32 + assert beam.sigma_x.dtype == torch.float32 + assert beam.sigma_xp.dtype == torch.float32 + assert beam.sigma_y.dtype == torch.float32 + assert beam.sigma_yp.dtype == torch.float32 + assert beam.sigma_s.dtype == torch.float32 + assert beam.sigma_p.dtype == torch.float32 + assert beam.energy.dtype == torch.float32 + assert beam.total_charge.dtype == torch.float32