Skip to content

Commit

Permalink
Silence PyTorch weights_only warning
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Aug 26, 2024
1 parent 50b0982 commit dc598db
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
10 changes: 7 additions & 3 deletions tests/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_dipole_bmadx_tracking(dtype):
Test that the results of tracking through a dipole with the `"bmadx"` tracking
method match the results from Bmad-X.
"""
incoming = torch.load("tests/resources/bmadx/incoming_beam.pt")
incoming = torch.load(
"tests/resources/bmadx/incoming_beam.pt", weights_only=False
).to(dtype)
mc2 = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6,
dtype=dtype,
Expand Down Expand Up @@ -111,11 +113,13 @@ def test_dipole_bmadx_tracking(dtype):
outgoing_cheetah_bmadx = segment_cheetah_bmadx.track(incoming)

# Load reference result computed with Bmad-X
outgoing_bmadx = torch.load("tests/resources/bmadx/outgoing_bmadx_dipole.pt")
outgoing_bmadx = torch.load(
"tests/resources/bmadx/outgoing_bmadx_dipole.pt", weights_only=False
)

assert torch.allclose(
outgoing_cheetah_bmadx.particles,
outgoing_bmadx if dtype == torch.float64 else outgoing_bmadx.float(),
outgoing_bmadx.to(dtype),
rtol=1e-14 if dtype == torch.float64 else 0.00001,
atol=1e-14 if dtype == torch.float64 else 1e-8,
)
10 changes: 7 additions & 3 deletions tests/test_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def test_drift_bmadx_tracking(dtype):
Test that the results of tracking through a drift with the `"bmadx"` tracking method
match the results from Bmad-X.
"""
incoming_beam = torch.load("tests/resources/bmadx/incoming_beam.pt")
incoming_beam = torch.load(
"tests/resources/bmadx/incoming_beam.pt", weights_only=False
).to(dtype)
drift = cheetah.Drift(
length=torch.tensor([1.0]), tracking_method="bmadx", dtype=dtype
)
Expand All @@ -76,11 +78,13 @@ def test_drift_bmadx_tracking(dtype):
outgoing_beam = drift.track(incoming_beam)

# Load reference result computed with Bmad-X
outgoing_bmadx = torch.load("tests/resources/bmadx/outgoing_bmadx_drift.pt")
outgoing_bmadx = torch.load(
"tests/resources/bmadx/outgoing_bmadx_drift.pt", weights_only=False
)

assert torch.allclose(
outgoing_beam.particles,
outgoing_bmadx if dtype == torch.float64 else outgoing_bmadx.float(),
outgoing_bmadx.to(dtype),
atol=1e-14 if dtype == torch.float64 else 0.00001,
rtol=1e-14 if dtype == torch.float64 else 1e-8,
)
10 changes: 7 additions & 3 deletions tests/test_quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def test_quadrupole_bmadx_tracking(dtype):
Test that the results of tracking through a quadrupole with the `"bmadx"` tracking
method match the results from Bmad-X.
"""
incoming = torch.load("tests/resources/bmadx/incoming_beam.pt")
incoming = torch.load(
"tests/resources/bmadx/incoming_beam.pt", weights_only=False
).to(dtype)
quadrupole = Quadrupole(
length=torch.tensor([1.0]),
k1=torch.tensor([10.0]),
Expand All @@ -156,11 +158,13 @@ def test_quadrupole_bmadx_tracking(dtype):
outgoing = segment.track(incoming)

# Load reference result computed with Bmad-X
outgoing_bmadx = torch.load("tests/resources/bmadx/outgoing_bmadx_quadrupole.pt")
outgoing_bmadx = torch.load(
"tests/resources/bmadx/outgoing_bmadx_quadrupole.pt", weights_only=False
)

assert torch.allclose(
outgoing.particles,
outgoing_bmadx if dtype == torch.float64 else outgoing_bmadx.float(),
outgoing_bmadx.to(dtype),
atol=1e-14 if dtype == torch.float64 else 0.00001,
rtol=1e-14 if dtype == torch.float64 else 1e-8,
)
Expand Down
10 changes: 7 additions & 3 deletions tests/test_transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype):
Test that the results of tracking through a TDC with the `"bmadx"` tracking method
match the results from Bmad-X.
"""
incoming_beam = torch.load("tests/resources/bmadx/incoming_beam.pt")
incoming_beam = torch.load(
"tests/resources/bmadx/incoming_beam.pt", weights_only=False
).to(dtype)
tdc = cheetah.TransverseDeflectingCavity(
length=torch.tensor([1.0]),
voltage=torch.tensor([1e7]),
Expand All @@ -24,11 +26,13 @@ def test_transverse_deflecting_cavity_bmadx_tracking(dtype):
outgoing_beam = tdc.track(incoming_beam)

# Load reference result computed with Bmad-X
outgoing_bmadx = torch.load("tests/resources/bmadx/outgoing_bmadx_crab_cavity.pt")
outgoing_bmadx = torch.load(
"tests/resources/bmadx/outgoing_bmadx_crab_cavity.pt", weights_only=False
)

assert torch.allclose(
outgoing_beam.particles,
outgoing_bmadx if dtype == torch.float64 else outgoing_bmadx.float(),
outgoing_bmadx.to(dtype),
atol=1e-14 if dtype == torch.float64 else 0.00001,
rtol=1e-14 if dtype == torch.float64 else 1e-8,
)

0 comments on commit dc598db

Please sign in to comment.