Skip to content

Commit

Permalink
Fix broadcasting issues in elements with "bmax" tracking methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jank324 committed Sep 24, 2024
1 parent 0d3eb38 commit 15ea754
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 22 deletions.
72 changes: 52 additions & 20 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
`ParticleBeam`.
:return: Beam exiting the element.
"""
# TODO: The renaming of the compinents of `incoming` to just the component name
# makes things hard to read. The resuse and overwriting of those component names
# throughout the function makes it even hard, is bad practice and should really
# be fixed!

# Compute Bmad coordinates and p0c
x = incoming.x
px = incoming.px
Expand Down Expand Up @@ -219,9 +224,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
z, pz, p0c, electron_mass_eV
)

# Broadcast to align their shapes so that they can be stacked
x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
(x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
Expand Down Expand Up @@ -255,48 +265,70 @@ def _bmadx_body(
px_norm = torch.sqrt((1 + pz) ** 2 - py**2) # For simplicity
phi1 = torch.arcsin(px / px_norm)
g = self.angle / self.length
gp = g / px_norm
gp = g.unsqueeze(-1) / px_norm

alpha = (
2
* (1 + g * x)
* torch.sin(self.angle + phi1)
* self.length
* bmadx.sinc(self.angle)
- gp * ((1 + g * x) * self.length * bmadx.sinc(self.angle)) ** 2
* (1 + g.unsqueeze(-1) * x)
* torch.sin(self.angle.unsqueeze(-1) + phi1)
* self.length.unsqueeze(-1)
* bmadx.sinc(self.angle).unsqueeze(-1)
- gp
* (
(1 + g.unsqueeze(-1) * x)
* self.length.unsqueeze(-1)
* bmadx.sinc(self.angle).unsqueeze(-1)
)
** 2
)

x2_t1 = x * torch.cos(self.angle) + self.length**2 * g * bmadx.cosc(self.angle)
x2_t1 = x * torch.cos(self.angle.unsqueeze(-1)) + self.length.unsqueeze(
-1
) ** 2 * g.unsqueeze(-1) * bmadx.cosc(self.angle.unsqueeze(-1))

x2_t2 = torch.sqrt((torch.cos(self.angle + phi1) ** 2) + gp * alpha)
x2_t3 = torch.cos(self.angle + phi1)
x2_t2 = torch.sqrt(
(torch.cos(self.angle.unsqueeze(-1) + phi1) ** 2) + gp * alpha
)
x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1)

c1 = x2_t1 + alpha / (x2_t2 + x2_t3)
c2 = x2_t1 + (x2_t2 - x2_t3) / gp
temp = torch.abs(self.angle + phi1)
temp = torch.abs(self.angle.unsqueeze(-1) + phi1)
x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2)

Lcu = (
x2 - self.length**2 * g * bmadx.cosc(self.angle) - x * torch.cos(self.angle)
x2
- self.length.unsqueeze(-1) ** 2
* g.unsqueeze(-1)
* bmadx.cosc(self.angle.unsqueeze(-1))
- x * torch.cos(self.angle.unsqueeze(-1))
)

Lcv = -self.length * bmadx.sinc(self.angle) - x * torch.sin(self.angle)
Lcv = -self.length.unsqueeze(-1) * bmadx.sinc(
self.angle.unsqueeze(-1)
) - x * torch.sin(self.angle.unsqueeze(-1))

theta_p = 2 * (self.angle + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu))
theta_p = 2 * (
self.angle.unsqueeze(-1) + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu)
)

Lc = torch.sqrt(Lcu**2 + Lcv**2)
Lp = Lc / bmadx.sinc(theta_p / 2)

P = p0c * (1 + pz) # In eV
P = p0c.unsqueeze(-1) * (1 + pz) # In eV
E = torch.sqrt(P**2 + mc2**2) # In eV
E0 = torch.sqrt(p0c**2 + mc2**2) # In eV
beta = P / E
beta0 = p0c / E0

x_f = x2
px_f = px_norm * torch.sin(self.angle + phi1 - theta_p)
px_f = px_norm * torch.sin(self.angle.unsqueeze(-1) + phi1 - theta_p)
y_f = y + py * Lp / px_norm
z_f = z + (beta * self.length / beta0) - ((1 + pz) * Lp / px_norm)
z_f = (
z
+ (beta * self.length.unsqueeze(-1) / beta0.unsqueeze(-1))
- ((1 + pz) * Lp / px_norm)
)

return x_f, px_f, y_f, py, z_f, pz

Expand Down Expand Up @@ -331,8 +363,8 @@ def _bmadx_fringe_linear(
hy = -g * torch.tan(
e - 2 * f_int * h_gap * g * (1 + torch.sin(e) ** 2) / torch.cos(e)
)
px_f = px + x * hx
py_f = py + y * hy
px_f = px + x * hx.unsqueeze(-1)
py_f = py + y * hy.unsqueeze(-1)

return px_f, py_f

Expand Down
9 changes: 7 additions & 2 deletions cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
z, pz, p0c, electron_mass_eV
)

# Broadcast to align their shapes so that they can be stacked
x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta)

outgoing_beam = ParticleBeam(
torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1),
ref_energy,
particles=torch.stack(
[x, px, y, py, tau, delta, torch.ones_like(x)], dim=-1
),
energy=ref_energy,
particle_charges=incoming.particle_charges,
device=incoming.particles.device,
dtype=incoming.particles.dtype,
Expand Down

0 comments on commit 15ea754

Please sign in to comment.