From 15ea7540598c2890b93e85045e0274912187409f Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Tue, 24 Sep 2024 17:35:08 +0200 Subject: [PATCH] Fix broadcasting issues in elements with `"bmax"` tracking methods --- cheetah/accelerator/dipole.py | 72 +++++++++++++++++++++++++---------- cheetah/accelerator/drift.py | 9 ++++- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index fe58c71a..d6aab27c 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -184,6 +184,11 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: `ParticleBeam`. :return: Beam exiting the element. """ + # TODO: The renaming of the compinents of `incoming` to just the component name + # makes things hard to read. The resuse and overwriting of those component names + # throughout the function makes it even hard, is bad practice and should really + # be fixed! + # Compute Bmad coordinates and p0c x = incoming.x px = incoming.px @@ -219,9 +224,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + (x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype, @@ -255,48 +265,70 @@ def _bmadx_body( px_norm = torch.sqrt((1 + pz) ** 2 - py**2) # For simplicity phi1 = torch.arcsin(px / px_norm) g = self.angle / self.length - gp = g / px_norm + gp = g.unsqueeze(-1) / px_norm alpha = ( 2 - * (1 + g * x) - * torch.sin(self.angle + phi1) - * self.length - * bmadx.sinc(self.angle) - - gp * ((1 + g * x) * self.length * bmadx.sinc(self.angle)) ** 2 + * (1 + g.unsqueeze(-1) * x) + * torch.sin(self.angle.unsqueeze(-1) + phi1) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + - gp + * ( + (1 + g.unsqueeze(-1) * x) + * self.length.unsqueeze(-1) + * bmadx.sinc(self.angle).unsqueeze(-1) + ) + ** 2 ) - x2_t1 = x * torch.cos(self.angle) + self.length**2 * g * bmadx.cosc(self.angle) + x2_t1 = x * torch.cos(self.angle.unsqueeze(-1)) + self.length.unsqueeze( + -1 + ) ** 2 * g.unsqueeze(-1) * bmadx.cosc(self.angle.unsqueeze(-1)) - x2_t2 = torch.sqrt((torch.cos(self.angle + phi1) ** 2) + gp * alpha) - x2_t3 = torch.cos(self.angle + phi1) + x2_t2 = torch.sqrt( + (torch.cos(self.angle.unsqueeze(-1) + phi1) ** 2) + gp * alpha + ) + x2_t3 = torch.cos(self.angle.unsqueeze(-1) + phi1) c1 = x2_t1 + alpha / (x2_t2 + x2_t3) c2 = x2_t1 + (x2_t2 - x2_t3) / gp - temp = torch.abs(self.angle + phi1) + temp = torch.abs(self.angle.unsqueeze(-1) + phi1) x2 = c1 * (temp < torch.pi / 2) + c2 * (temp >= torch.pi / 2) Lcu = ( - x2 - self.length**2 * g * bmadx.cosc(self.angle) - x * torch.cos(self.angle) + x2 + - self.length.unsqueeze(-1) ** 2 + * g.unsqueeze(-1) + * bmadx.cosc(self.angle.unsqueeze(-1)) + - x * torch.cos(self.angle.unsqueeze(-1)) ) - Lcv = -self.length * bmadx.sinc(self.angle) - x * torch.sin(self.angle) + Lcv = -self.length.unsqueeze(-1) * bmadx.sinc( + self.angle.unsqueeze(-1) + ) - x * torch.sin(self.angle.unsqueeze(-1)) - theta_p = 2 * (self.angle + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu)) + theta_p = 2 * ( + self.angle.unsqueeze(-1) + phi1 - torch.pi / 2 - torch.arctan2(Lcv, Lcu) + ) Lc = torch.sqrt(Lcu**2 + Lcv**2) Lp = Lc / bmadx.sinc(theta_p / 2) - P = p0c * (1 + pz) # In eV + P = p0c.unsqueeze(-1) * (1 + pz) # In eV E = torch.sqrt(P**2 + mc2**2) # In eV E0 = torch.sqrt(p0c**2 + mc2**2) # In eV beta = P / E beta0 = p0c / E0 x_f = x2 - px_f = px_norm * torch.sin(self.angle + phi1 - theta_p) + px_f = px_norm * torch.sin(self.angle.unsqueeze(-1) + phi1 - theta_p) y_f = y + py * Lp / px_norm - z_f = z + (beta * self.length / beta0) - ((1 + pz) * Lp / px_norm) + z_f = ( + z + + (beta * self.length.unsqueeze(-1) / beta0.unsqueeze(-1)) + - ((1 + pz) * Lp / px_norm) + ) return x_f, px_f, y_f, py, z_f, pz @@ -331,8 +363,8 @@ def _bmadx_fringe_linear( hy = -g * torch.tan( e - 2 * f_int * h_gap * g * (1 + torch.sin(e) ** 2) / torch.cos(e) ) - px_f = px + x * hx - py_f = py + y * hy + px_f = px + x * hx.unsqueeze(-1) + py_f = py + y * hy.unsqueeze(-1) return px_f, py_f diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index 270f923b..4438c376 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -106,9 +106,14 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: z, pz, p0c, electron_mass_eV ) + # Broadcast to align their shapes so that they can be stacked + x, px, y, py, tau, delta = torch.broadcast_tensors(x, px, y, py, tau, delta) + outgoing_beam = ParticleBeam( - torch.stack((x, px, y, py, tau, delta, torch.ones_like(x)), dim=-1), - ref_energy, + particles=torch.stack( + [x, px, y, py, tau, delta, torch.ones_like(x)], dim=-1 + ), + energy=ref_energy, particle_charges=incoming.particle_charges, device=incoming.particles.device, dtype=incoming.particles.dtype,