Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BmadQuadrupole element #153

Merged
merged 39 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
ce88eeb
add `BmadQuadrupole` element
jp-ga May 8, 2024
20dbeb2
remove unused files
jp-ga May 8, 2024
b80bae8
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jp-ga May 8, 2024
ea30066
merge master into 139
jp-ga Jun 18, 2024
f74fb7c
implement bmadx quad
jp-ga Jun 20, 2024
257480d
cleanup
jp-ga Jun 20, 2024
5ce9da8
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jp-ga Jun 20, 2024
0b2ceca
merge hotfix
jp-ga Jun 20, 2024
a106cf0
run isort
jp-ga Jun 20, 2024
75c915d
run black
jp-ga Jun 20, 2024
4afbc7c
hotfix flake8
jp-ga Jun 20, 2024
0d702e6
hotfix flake8 again
jp-ga Jun 20, 2024
32c9adc
add missing docstrings and typing
jp-ga Jun 25, 2024
83c2f46
fix isort
jp-ga Jun 25, 2024
2da4d42
run black
jp-ga Jun 25, 2024
647d401
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jun 26, 2024
90e4e30
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 9, 2024
5c1c17b
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 20, 2024
ac8202a
Move Bmad-X utils to new utils directory
jank324 Jul 20, 2024
3838550
A little cleanup
jank324 Jul 20, 2024
f5ffdaf
Fix import error
jank324 Jul 20, 2024
94b65a9
Light refactoring
jank324 Jul 20, 2024
ddbfb54
Fix typo in PR template
jank324 Jul 20, 2024
ca5ca6b
Reduce cope duplication in `Quadrupole.track`
jank324 Jul 20, 2024
5ad2de2
Fix Bmad-X quadrupole dev notebook
jank324 Jul 20, 2024
073a9bc
Simplify `is_skippable`
jank324 Jul 20, 2024
c5079fd
Clean up test
jank324 Jul 20, 2024
0115d3a
Add a test that finds Ryan's error
jank324 Jul 20, 2024
340ed80
Fix vectorisation issue with Bmad-X quadrupole tracking
jank324 Jul 20, 2024
8505c4e
Rearrange test reources for Bmad-X quadrupole implementation
jank324 Jul 20, 2024
6a15c64
Add changelog entry
jank324 Jul 20, 2024
67b746c
add num_steps and tracking_method to split
jp-ga Jul 22, 2024
3ec7789
Apply suggestions from code review
cr-xu Jul 23, 2024
7cbcaf4
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
3c851ea
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
4f931b2
Apply suggestions from code review
cr-xu Jul 23, 2024
111f32c
Apply suggestions from code review
cr-xu Jul 23, 2024
bddb7c2
Update cheetah/utils/bmadx.py
cr-xu Jul 23, 2024
c1e91ed
Merge branch 'master' into 139-add-quadrupole-with-chromatic-effects
jank324 Jul 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@
- [ ] I have run `pytest` on a machine with a CUDA GPU and made sure all tests pass (**required**).
- [ ] I have checked that the documentation builds (**required**).

Note: We are using a maximum length of 88 characters per line
Note: We are using a maximum length of 88 characters per line.

<!--- This Template is an edited version of the one from https://github.com/DLR-RM/stable-baselines3/ -->
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- `Segment`s can now be imported from Bmad to devices other than `torch.device("cpu")` and dtypes other than `torch.float32` (see #196, #206) (@jank324)
- `Screen` now offers the option to use KDE for differentiable images (see #200) (@cr-xu, @roussel-ryan)
- Moving `Element`s and `Beam`s to a different `device` and changing their `dtype` like with any `torch.nn.Module` is now possible (see #209) (@jank324)
- `Quadrupole` now supports tracking with Cheetah's matrix-based method or with Bmad's more accurate method (see #153) (@jp-ga, @jank324)

### 🐛 Bug fixes

Expand Down
127 changes: 124 additions & 3 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from typing import Optional, Union
from typing import Literal, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.particles import Beam, ParticleBeam
from cheetah.track_methods import base_rmatrix, misalignment_matrix
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, bmadx

from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
)


class Quadrupole(Element):
"""
Expand All @@ -23,6 +29,9 @@ class Quadrupole(Element):
:param misalignment: Misalignment vector of the quadrupole in x- and y-directions.
:param tilt: Tilt angle of the quadrupole in x-y plane [rad]. pi/4 for
skew-quadrupole.
:param num_steps: Number of drift-kick-drift steps to use for tracking through the
element when tracking method is set to `"bmadx"`.
:param tracking_method: Method to use for tracking through the element.
:param name: Unique identifier of the element.
"""

Expand All @@ -32,6 +41,8 @@ def __init__(
k1: Optional[Union[torch.Tensor, nn.Parameter]] = None,
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
tilt: Optional[Union[torch.Tensor, nn.Parameter]] = None,
num_steps: int = 1,
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
Expand Down Expand Up @@ -64,6 +75,8 @@ def __init__(
else torch.zeros_like(self.length)
),
)
self.num_steps = num_steps
self.tracking_method = tracking_method

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
R = base_rmatrix(
Expand All @@ -81,6 +94,110 @@ def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
R = torch.einsum("...ij,...jk,...kl->...il", R_exit, R, R_entry)
return R

def track(self, incoming: Beam) -> Beam:
"""
Track particles through the quadrupole element.

:param incoming: Beam entering the element.
:return: Beam exiting the element.
"""
if self.tracking_method == "cheetah":
return super().track(incoming)
elif self.tracking_method == "bmadx":
assert isinstance(
incoming, ParticleBeam
), "Bmad-X tracking is currently only supported for `ParticleBeam`."
return self._track_bmadx(incoming)
else:
raise ValueError(
f"Invalid tracking method {self.tracking_method}. "
+ "Supported methods are 'cheetah' and 'bmadx'."
)

def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
"""
Track particles through the quadrupole element using the Bmad-X tracking method.

:param incoming: Beam entering the element. Currently only supports
`ParticleBeam`.
:return: Beam exiting the element.
"""
# Compute Bmad coordinates and p0c
mc2 = electron_mass_eV.to(
device=incoming.particles.device, dtype=incoming.particles.dtype
)
bmad_coords, p0c = bmadx.cheetah_to_bmad_coords(
incoming.particles, incoming.energy, mc2
)
x = bmad_coords[..., 0]
px = bmad_coords[..., 1]
y = bmad_coords[..., 2]
py = bmad_coords[..., 3]
z = bmad_coords[..., 4]
pz = bmad_coords[..., 5]

x_offset = self.misalignment[..., 0]
y_offset = self.misalignment[..., 1]

step_length = self.length / self.num_steps
b1 = self.k1 * self.length

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
x_offset, y_offset, self.tilt, x, px, y, py
)

for _ in range(self.num_steps):
rel_p = 1 + pz # Particle's relative momentum (P/P0)
k1 = b1.unsqueeze(-1) / (self.length.unsqueeze(-1) * rel_p)

tx, dzx = bmadx.calculate_quadrupole_coefficients(-k1, step_length, rel_p)
ty, dzy = bmadx.calculate_quadrupole_coefficients(k1, step_length, rel_p)

z = (
z
+ dzx[0] * x**2
+ dzx[1] * x * px
+ dzx[2] * px**2
+ dzy[0] * y**2
+ dzy[1] * y * py
+ dzy[2] * py**2
)

x_next = tx[0][0] * x + tx[0][1] * px
px_next = tx[1][0] * x + tx[1][1] * px
y_next = ty[0][0] * y + ty[0][1] * py
py_next = ty[1][0] * y + ty[1][1] * py

x, px, y, py = x_next, px_next, y_next, py_next

z = z + bmadx.low_energy_z_correction(pz, p0c, mc2, step_length)

# s = s + l
x, px, y, py = bmadx.offset_particle_unset(
x_offset, y_offset, self.tilt, x, px, y, py
)

# End of Bmad-X tracking
bmad_coords[..., 0] = x
bmad_coords[..., 1] = px
bmad_coords[..., 2] = y
bmad_coords[..., 3] = py
bmad_coords[..., 4] = z
bmad_coords[..., 5] = pz

# Convert back to Cheetah coordinates
cheetah_coords, ref_energy = bmadx.bmad_to_cheetah_coords(bmad_coords, p0c, mc2)

outgoing_beam = ParticleBeam(
cheetah_coords,
ref_energy,
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
)
return outgoing_beam

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
Expand All @@ -94,7 +211,7 @@ def broadcast(self, shape: Size) -> Element:

@property
def is_skippable(self) -> bool:
return True
return self.tracking_method == "cheetah"

@property
def is_active(self) -> bool:
Expand All @@ -109,6 +226,8 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
self.k1,
misalignment=self.misalignment,
tilt=self.tilt,
num_steps=self.num_steps,
tracking_method=self.tracking_method,
dtype=self.length.dtype,
device=self.length.device,
)
Expand All @@ -134,5 +253,7 @@ def __repr__(self) -> str:
+ f"k1={repr(self.k1)}, "
+ f"misalignment={repr(self.misalignment)}, "
+ f"tilt={repr(self.tilt)}, "
+ f"num_steps={repr(self.num_steps)}, "
+ f"tracking_method={repr(self.tracking_method)}, "
+ f"name={repr(self.name)})"
)
1 change: 1 addition & 0 deletions cheetah/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import bmadx # noqa: F401
from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401
from .unique_name_generator import UniqueNameGenerator # noqa: F401
Loading