diff --git a/cheetah/accelerator.py b/cheetah/accelerator.py index b93b6647..38508282 100644 --- a/cheetah/accelerator.py +++ b/cheetah/accelerator.py @@ -337,23 +337,23 @@ def __repr__(self) -> str: class SpaceChargeKick(Element): """ - Applies the effect of space charge over a length `length`, on the **momentum** - (i.e. divergence and energy spread) of the beam. + Applies the effect of space charge over a length `length`, on the **momentum** + (i.e. divergence and energy spread) of the beam. The positions are unmodified ; this is meant to be combined with another lattice element (e.g. `Drift`) that does modify the positions, but does not take into account space charge. - This uses the integrated Green function method - (https://journals.aps.org/prab/abstract/10.1103/PhysRevSTAB.9.044204) to compute - the effect of space charge. This is similar to the method used in Ocelot. - The main difference is that it solves the Poisson equation in the beam frame, - while here we solve a modified Poisson equation in the laboratory frame + This uses the integrated Green function method + (https://journals.aps.org/prab/abstract/10.1103/PhysRevSTAB.9.044204) to compute + the effect of space charge. This is similar to the method used in Ocelot. + The main difference is that it solves the Poisson equation in the beam frame, + while here we solve a modified Poisson equation in the laboratory frame (https://pubs.aip.org/aip/pop/article-abstract/15/5/056701/1016636/Simulation-of-beams-or-plasmas-crossing-at). The two methods are in principle equivalent. Overview of the method: - Compute the beam charge density on a grid - Convolve the charge density with a Green function (the integrated green function) - to find the potential `phi` on the grid. The convolution uses the Hockney method + to find the potential `phi` on the grid. The convolution uses the Hockney method for open boundaries (allocate 2x larger arrays and perform convolution using FFTs) - Compute the corresponding electromagnetic fields and Lorentz force on the grid - Interpolate the Lorentz force to the particles and update their momentum @@ -387,28 +387,28 @@ def __init__( self.length = torch.as_tensor(length, **self.factory_kwargs) self.grid_shape = (int(num_grid_points_x), int(num_grid_points_y), \ int(num_grid_points_s)) - self.grid_extend_x = torch.as_tensor(grid_extend_x, **self.factory_kwargs) + self.grid_extend_x = torch.as_tensor(grid_extend_x, **self.factory_kwargs) # in multiples of sigma self.grid_extend_y = torch.as_tensor(grid_extend_y, **self.factory_kwargs) self.grid_extend_s = torch.as_tensor(grid_extend_s, **self.factory_kwargs) - + def _compute_grid_dimensions(self,beam: ParticleBeam) -> torch.Tensor: - sigma_x = torch.std(beam.particles[..., 0]) - sigma_y = torch.std(beam.particles[..., 2]) - sigma_s = torch.std(beam.particles[..., 4]) - return torch.tensor([self.grid_extend_x*sigma_x, self.grid_extend_y*sigma_y\ - , self.grid_extend_s*sigma_s]) - + sigma_x = torch.std(beam.particles[:,:,0], dim=1) + sigma_y = torch.std(beam.particles[:,:,2], dim=1) + sigma_s = torch.std(beam.particles[:,:,4], dim=1) + return torch.stack([self.grid_extend_x*sigma_x, self.grid_extend_y*sigma_y\ + , self.grid_extend_s*sigma_s], dim=-1) + def _gammaref(self,beam: ParticleBeam) -> torch.Tensor: return beam.energy / rest_energy - + def _betaref(self,beam: ParticleBeam) -> torch.Tensor: gamma = self._gammaref(beam) if gamma == 0: return torch.tensor(1.0) return torch.sqrt(1 - 1 / gamma**2) - + def _deposit_charge_on_grid(self, beam: ParticleBeam, cell_size, grid_dimensions)\ -> torch.Tensor: """ @@ -416,44 +416,47 @@ def _deposit_charge_on_grid(self, beam: ParticleBeam, cell_size, grid_dimensions grid point method and weighting by the distance to the grid points. Returns a grid of charge density in C/m^3. """ - grid_shape = self.grid_shape - charge = torch.zeros(grid_shape) - - # Get particle positions and charges - particle_pos = beam.particles[..., [0, 2, 4]] - particle_charge = beam.particle_charges - normalized_pos = (particle_pos + grid_dimensions) / cell_size - - # Find the indices of the lower corners of the cells containing the particles - cell_indices = torch.floor(normalized_pos).type(torch.long) - - # Calculate the weights for all surrounding cells - offsets = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0]\ - , [1, 0, 1], [1, 1, 0], [1, 1, 1]]) - surrounding_indices = cell_indices.unsqueeze(1) + offsets - # Shape: (n_particles, 8, 3) - weights = 1 - torch.abs(normalized_pos.unsqueeze(1) - surrounding_indices) - # Shape: (n_particles, 8, 3) - cell_weights = weights.prod(dim=2) # Shape: (n_particles, 8) - - # Add the charge contributions to the cells - idx_x, idx_y, idx_s = surrounding_indices.view(-1, 3).T - # Shape: (3, n_particles*8) - # Check that particles are inside the grid - valid_mask = (idx_x >= 0) & (idx_x < grid_shape[0]) & \ - (idx_y >= 0) & (idx_y < grid_shape[1]) & \ - (idx_s >= 0) & (idx_s < grid_shape[2]) - - # Accumulate the charge contributions - indices = torch.stack([idx_x[valid_mask], idx_y[valid_mask],idx_s[valid_mask]]\ - , dim=0) # Shape: (3, n_valid) n_valid =8*n_particles - repeated_charges = particle_charge.repeat_interleave(8) # Shape:(8*n_particles) - values = (cell_weights.view(-1) * repeated_charges)[valid_mask] - charge.index_put_(tuple(indices), values, accumulate=True) - inv_cell_volume = 1 / (cell_size[0] * cell_size[1] * cell_size[2]) - - return charge * inv_cell_volume # Normalize by the cell volume - + charge = torch.zeros( (self.n_batch,) + self.grid_shape, **self.factory_kwargs ) + + # Loop over batch dimension + for i_batch in range(self.n_batch): + # Get particle positions and charges + particle_pos = beam.particles[i_batch, :, [0, 2, 4]] + particle_charge = beam.particle_charges[i_batch] + normalized_pos = (particle_pos[:, :] + grid_dimensions[i_batch, None, :]) / cell_size[i_batch, None, :] + + # Find the indices of the lower corners of the cells containing the particles + cell_indices = torch.floor(normalized_pos).type(torch.long) + + # Calculate the weights for all surrounding cells + offsets = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0]\ + , [1, 0, 1], [1, 1, 0], [1, 1, 1]]) + surrounding_indices = cell_indices[:, None, :] + offsets[None, :, :] + # Shape: (n_particles, 8, 3) + weights = 1 - torch.abs(normalized_pos[:, None, :] - surrounding_indices) + # Shape: (n_particles, 8, 3) + cell_weights = weights.prod(dim=-1) # Shape: (n_particles, 8) + + # Add the charge contributions to the cells + idx_x = surrounding_indices[:,:,0].flatten() + idx_y = surrounding_indices[:,:,1].flatten() + idx_s = surrounding_indices[:,:,2].flatten() + # Shape: (8*n_particles,) + # Check that particles are inside the grid + valid_mask = (idx_x >= 0) & (idx_x < self.grid_shape[0]) & \ + (idx_y >= 0) & (idx_y < self.grid_shape[1]) & \ + (idx_s >= 0) & (idx_s < self.grid_shape[2]) + + # Accumulate the charge contributions + repeated_charges = particle_charge.repeat_interleave(8) # Shape:(8*n_particles,) + values = (cell_weights.view(-1) * repeated_charges)[valid_mask] + charge[i_batch].index_put_( (idx_x[valid_mask], idx_y[valid_mask], idx_s[valid_mask]), values, accumulate=True) + + # End of loop over batch + inv_cell_volume = 1 / (cell_size[:,0] * cell_size[:,1] * cell_size[:,2]) + + return charge * inv_cell_volume[:, None, None, None] # Normalize by the cell volume + def _integrated_potential(self, x, y, s) -> torch.Tensor: """ @@ -469,7 +472,7 @@ def _integrated_potential(self, x, y, s) -> torch.Tensor: + x * s * torch.asinh(y / torch.sqrt(x**2 + s**2)) + x * y * torch.asinh(s / torch.sqrt(x**2 + y**2))) return G - + def _array_rho(self,beam: ParticleBeam, cell_size, grid_dimensions) ->torch.Tensor: """ @@ -481,31 +484,34 @@ def _array_rho(self,beam: ParticleBeam, cell_size, grid_dimensions) ->torch.Tens new_dims = tuple(dim * 2 for dim in grid_shape) # Create a new tensor with the doubled dimensions, filled with zeros - new_charge_density = torch.zeros(new_dims, **self.factory_kwargs) + new_charge_density = torch.zeros( (self.n_batch,) + new_dims, **self.factory_kwargs) # Copy the original charge_density values to the beginning of the new tensor - new_charge_density[..., :charge_density.shape[0], :charge_density.shape[1],\ - :charge_density.shape[2]] = charge_density - return new_charge_density - + new_charge_density[:, :charge_density.shape[1], :charge_density.shape[2],\ + :charge_density.shape[3]] = charge_density + return new_charge_density + def _IGF(self, beam: ParticleBeam, cell_size) -> torch.Tensor: """ Computes the Integrated Green Function (IGF) with periodic boundary conditions (to perform Hockney's method). """ gamma = self._gammaref(beam) - dx, dy, ds = cell_size[0], cell_size[1], cell_size[2] * gamma #scaled by gamma + dx, dy, ds = cell_size[:,0], cell_size[:,1], cell_size[:,2] * gamma #scaled by gamma num_grid_points_x, num_grid_points_y, num_grid_points_s = self.grid_shape - + # Create coordinate grids - x = torch.arange(num_grid_points_x, **self.factory_kwargs) * dx - y = torch.arange(num_grid_points_y, **self.factory_kwargs) * dy - s = torch.arange(num_grid_points_s, **self.factory_kwargs) * ds - x_grid, y_grid, s_grid = torch.meshgrid(x, y, s, indexing='ij') + x = torch.arange( num_grid_points_x, **self.factory_kwargs) + y = torch.arange( num_grid_points_y, **self.factory_kwargs) + s = torch.arange( num_grid_points_s, **self.factory_kwargs) + ix_grid, iy_grid, is_grid = torch.meshgrid(x, y, s, indexing='ij') + x_grid = ix_grid[None, :, :, :] * dx[:, None, None, None] # Shape: [n_batch, nx, ny, nz] + y_grid = iy_grid[None, :, :, :] * dy[:, None, None, None] # Shape: [n_batch, nx, ny, nz] + s_grid = is_grid[None, :, :, :] * ds[:, None, None, None] # Shape: [n_batch, nx, ny, nz] # Compute the Green's function values G_values = ( - self._integrated_potential(x_grid + 0.5 * dx, y_grid + 0.5 * dy,\ + self._integrated_potential( x_grid + 0.5 * dx, y_grid + 0.5 * dy,\ s_grid + 0.5 * ds) - self._integrated_potential(x_grid - 0.5 * dx, y_grid + 0.5 * dy,\ s_grid + 0.5 * ds) @@ -524,29 +530,29 @@ def _IGF(self, beam: ParticleBeam, cell_size) -> torch.Tensor: ) # Initialize the grid with double dimensions - green_func = torch.zeros(2 * num_grid_points_x, 2 * num_grid_points_y,\ + green_func = torch.zeros( self.n_batch, 2 * num_grid_points_x, 2 * num_grid_points_y,\ 2 * num_grid_points_s, **self.factory_kwargs) # Fill the grid with G_values and its periodic copies - green_func[:num_grid_points_x, :num_grid_points_y, :num_grid_points_s]\ + green_func[:, :num_grid_points_x, :num_grid_points_y, :num_grid_points_s]\ = G_values - green_func[num_grid_points_x+1:, :num_grid_points_y, :num_grid_points_s]\ - = G_values[1:,:,:].flip(dims=[0]) #Reverse x, excluding the first element - green_func[:num_grid_points_x, num_grid_points_y+1:, :num_grid_points_s]\ - = G_values[:, 1:,:].flip(dims=[1])#Reverse y, excluding the first element - green_func[:num_grid_points_x, :num_grid_points_y, num_grid_points_s+1:]\ - = G_values[:, :, 1:].flip(dims=[2])#Reverse s,excluding the first element - green_func[num_grid_points_x+1:, num_grid_points_y+1:, :num_grid_points_s]\ - = G_values[1:, 1:,:].flip(dims=[0, 1]) # Reverse the x and y dimensions - green_func[:num_grid_points_x, num_grid_points_y+1:, num_grid_points_s+1:]\ - = G_values[:, 1:, 1:].flip(dims=[1, 2]) # Reverse the y and s dimensions - green_func[num_grid_points_x+1:, :num_grid_points_y, num_grid_points_s+1:]\ - = G_values[1:, :, 1:].flip(dims=[0, 2]) # Reverse the x and s dimensions - green_func[num_grid_points_x+1:, num_grid_points_y+1:, num_grid_points_s+1:]\ - = G_values[1:, 1:, 1:].flip(dims=[0, 1, 2]) # Reverse all dimensions + green_func[:, num_grid_points_x+1:, :num_grid_points_y, :num_grid_points_s]\ + = G_values[:, 1:, :, :].flip(dims=[1]) #Reverse x, excluding the first element + green_func[:, :num_grid_points_x, num_grid_points_y+1:, :num_grid_points_s]\ + = G_values[:, :, 1:, :].flip(dims=[2])#Reverse y, excluding the first element + green_func[:, :num_grid_points_x, :num_grid_points_y, num_grid_points_s+1:]\ + = G_values[:, :, :, 1:].flip(dims=[3])#Reverse s,excluding the first element + green_func[:, num_grid_points_x+1:, num_grid_points_y+1:, :num_grid_points_s]\ + = G_values[:, 1:, 1:, :].flip(dims=[1, 2]) # Reverse the x and y dimensions + green_func[:, :num_grid_points_x, num_grid_points_y+1:, num_grid_points_s+1:]\ + = G_values[:, :, 1:, 1:].flip(dims=[2, 3]) # Reverse the y and s dimensions + green_func[:, num_grid_points_x+1:, :num_grid_points_y, num_grid_points_s+1:]\ + = G_values[:, 1:, :, 1:].flip(dims=[1, 3]) # Reverse the x and s dimensions + green_func[:, num_grid_points_x+1:, num_grid_points_y+1:, num_grid_points_s+1:]\ + = G_values[:, 1:, 1:, 1:].flip(dims=[1, 2, 3]) # Reverse all dimensions return green_func - + def _solve_poisson_equation(self, beam: ParticleBeam, cell_size, grid_dimensions)\ -> torch.Tensor: #works only for ParticleBeam at this stage @@ -554,15 +560,15 @@ def _solve_poisson_equation(self, beam: ParticleBeam, cell_size, grid_dimensions Solves the Poisson equation for the given charge density, using FFT convolution. """ charge_density = self._array_rho(beam, cell_size, grid_dimensions) - charge_density_ft = torch.fft.fftn(charge_density) + charge_density_ft = torch.fft.fftn(charge_density, dim=[1, 2, 3]) integrated_green_function = self._IGF(beam, cell_size) - integrated_green_function_ft = torch.fft.fftn(integrated_green_function) + integrated_green_function_ft = torch.fft.fftn(integrated_green_function, dim=[1, 2, 3]) potential_ft = charge_density_ft * integrated_green_function_ft - potential = (1/(4*torch.pi*epsilon_0))*torch.fft.ifftn(potential_ft).real + potential = (1/(4*torch.pi*epsilon_0))*torch.fft.ifftn(potential_ft, dim=[1, 2, 3]).real # Return the physical potential - return potential[:charge_density.shape[0]//2, :charge_density.shape[1]//2,\ - :charge_density.shape[2]//2] + return potential[:, :charge_density.shape[1]//2, :charge_density.shape[2]//2,\ + :charge_density.shape[3]//2] def _E_plus_vB_field(self, beam: ParticleBeam, cell_size, grid_dimensions)\ @@ -579,44 +585,43 @@ def _E_plus_vB_field(self, beam: ParticleBeam, cell_size, grid_dimensions)\ else torch.tensor(0.0) ) potential = self._solve_poisson_equation(beam, cell_size, grid_dimensions) - + grad_x = torch.zeros_like(potential) grad_y = torch.zeros_like(potential) grad_s = torch.zeros_like(potential) - # Compute the gradients of the potential, using central differences, with 0 + # Compute the gradients of the potential, using central differences, with 0 #boundary conditions. - grad_x[1:-1, :, :] = ( potential[2:, :, :] - potential[:-2, :, :] )\ - * (0.5 * inv_cell_size[0]) - grad_y[:, 1:-1, :] = ( potential[:, 2:, :] - potential[:, :-2, :] )\ - * (0.5 * inv_cell_size[1]) - grad_s[:, :, 1:-1] = ( potential[:, :, 2:] - potential[:, :, :-2] )\ - * (0.5 * inv_cell_size[2]) + grad_x[:, 1:-1, :, :] = ( potential[:, 2:, :, :] - potential[:, :-2, :, :] )\ + * (0.5 * inv_cell_size[:, 0, None, None, None]) + grad_y[:, :, 1:-1, :] = ( potential[:, :, 2:, :] - potential[:, :, :-2, :] )\ + * (0.5 * inv_cell_size[:, 1, None, None, None]) + grad_s[:, :, :, 1:-1] = ( potential[:, :, :, 2:] - potential[:, :, :, :-2] )\ + * (0.5 * inv_cell_size[:, 2, None, None, None]) # Scale the gradients with lorentz factor - grad_x = -igamma2*grad_x - grad_y = -igamma2*grad_y - grad_s = -igamma2*grad_s + grad_x = -igamma2[:, None, None, None]*grad_x + grad_y = -igamma2[:, None, None, None]*grad_y + grad_s = -igamma2[:, None, None, None]*grad_s return grad_x, grad_y, grad_s def _cheetah_to_moments(self, beam: ParticleBeam) -> torch.Tensor: """ - Converts the Cheetah particle beam parameters to the moments in SI units used + Converts the Cheetah particle beam parameters to the moments in SI units used in the space charge solver. """ - N = beam.particles.shape[0] moments = beam.particles gammaref = self._gammaref(beam) betaref = self._betaref(beam) p0 = gammaref*betaref*electron_mass*c - gamma = gammaref*(torch.ones(N)+beam.particles[:,5]*betaref) + gamma = gammaref[:, None] * ( torch.ones(moments.shape[:-1]) + beam.particles[:,:,5]*betaref[:, None] ) beta = torch.sqrt(1 - 1 / gamma**2) p = gamma*electron_mass*beta*c - moments[:,1] = p0*moments[:,1] - moments[:,3] = p0*moments[:,3] - moments[:,4] = -betaref*moments[:,4] - moments[:,5] = torch.sqrt(p**2 - moments[:,1]**2 - moments[:,3]**2) + moments[:,:,1] = p0[:, None] * moments[:,:,1] + moments[:,:,3] = p0[:, None] * moments[:,:,3] + moments[:,:,4] = -betaref[:, None] * moments[:,:,4] + moments[:,:,5] = torch.sqrt(p**2 - moments[:,:,1]**2 - moments[:,:,3]**2) def _moments_to_cheetah(self, beam: ParticleBeam) \ -> torch.Tensor: @@ -624,70 +629,74 @@ def _moments_to_cheetah(self, beam: ParticleBeam) \ Converts the moments in SI units to the Cheetah particle beam parameters. """ moments = beam.particles - N = moments.shape[0] gammaref = self._gammaref(beam) betaref = self._betaref(beam) p0 = gammaref*betaref*electron_mass*c - p = torch.sqrt(moments[:,1]**2 + moments[:,3]**2 + moments[:,5]**2) + p = torch.sqrt(moments[:,:,1]**2 + moments[:,:,3]**2 + moments[:,:,5]**2) gamma = torch.sqrt(1 + (p / (electron_mass*c))**2) - moments[:,1] = moments[:,1] / p0 - moments[:,3] = moments[:,3] / p0 - moments[:,4] = -moments[:,4] / betaref - moments[:,5] = (gamma-gammaref*torch.ones(N))/(betaref*gammaref) + moments[:,:,1] = moments[:,:,1] / p0[:, None] + moments[:,:,3] = moments[:,:,3] / p0[:, None] + moments[:,:,4] = -moments[:,:,4] / betaref[:, None] + moments[:,:,5] = (gamma-gammaref*torch.ones(gamma.shape))/((betaref*gammaref)[:, None]) def _compute_forces(self, beam: ParticleBeam, cell_size, grid_dimensions)\ -> torch.Tensor: """ - Interpolates the space charge force from the grid onto the macroparticles. + Interpolates the space charge force from the grid onto the macroparticles. Reciprocal function of _deposit_charge_on_grid. """ grad_x, grad_y, grad_z = self._E_plus_vB_field(beam,cell_size, grid_dimensions) grid_shape = self.grid_shape - particle_pos = beam.particles[:, [0, 2, 4]] - normalized_pos = (particle_pos + grid_dimensions) / cell_size - - # Find the indices of the lower corners of the cells containing the particles - cell_indices = torch.floor(normalized_pos).type(torch.long) - - # Calculate the weights for all surrounding cells - offsets = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0],\ - [1, 0, 1], [1, 1, 0], [1, 1, 1]]) - surrounding_indices =cell_indices.unsqueeze(1)+offsets #Shape:(n_particles,8,3) - - weights = 1 - torch.abs(normalized_pos.unsqueeze(1) - surrounding_indices) - # Shape: (n_particles, 8, 3) - cell_weights = weights.prod(dim=2) # Shape: (n_particles, 8) - - # Extract forces from the grids - idx_x,idx_y,idx_s = surrounding_indices.view(-1, 3).T #Shape: (3,n_particles*8) - valid_mask = (idx_x >= 0) & (idx_x < grid_shape[0]) & \ - (idx_y >= 0) & (idx_y < grid_shape[1]) & \ - (idx_s >= 0) & (idx_s < grid_shape[2]) - - valid_indices = torch.stack([idx_x[valid_mask], idx_y[valid_mask],\ - idx_s[valid_mask]], dim=0) - - Fx_values = grad_x[tuple(valid_indices)] - Fy_values = grad_y[tuple(valid_indices)] - Fz_values = grad_z[tuple(valid_indices)] - - # Compute interpolated forces - interpolated_forces = torch.zeros((particle_pos.shape[0], 3)) - valid_cell_weights = cell_weights.view(-1)[valid_mask]*elementary_charge - values_x = valid_cell_weights * Fx_values - values_y = valid_cell_weights * Fy_values - values_z = valid_cell_weights * Fz_values - - indices = torch.arange(particle_pos.shape[0]).repeat_interleave(8)[valid_mask] - interpolated_forces.index_add_(0, indices, torch.stack([values_x,\ - torch.zeros_like(values_x), torch.zeros_like(values_x)], dim=1)) - interpolated_forces.index_add_(0,indices,torch.stack\ - ([torch.zeros_like(values_y), values_y, torch.zeros_like(values_y)],dim=1)) - interpolated_forces.index_add_(0, indices, torch.stack(\ - [torch.zeros_like(values_z), torch.zeros_like(values_z), values_z], dim=1)) + n_particles = beam.particles.shape[1] + interpolated_forces = torch.zeros( (self.n_batch, n_particles, 3), **self.factory_kwargs ) + + # Loop over batch dimension + for i_batch in range(self.n_batch): + + # Get particle positions + particle_pos = beam.particles[i_batch, :, [0, 2, 4]] + normalized_pos = (particle_pos[:, :] + grid_dimensions[i_batch, None, :]) / cell_size[i_batch, None, :] + + # Find the indices of the lower corners of the cells containing the particles + cell_indices = torch.floor(normalized_pos).type(torch.long) + + # Calculate the weights for all surrounding cells + offsets = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0],\ + [1, 0, 1], [1, 1, 0], [1, 1, 1]]) + surrounding_indices = cell_indices[:, None, :] + offsets[None, :, :] # Shape:(n_particles,8,3) + # Shape: (n_particles, 8, 3) + weights = 1 - torch.abs(normalized_pos[:, None, :] - surrounding_indices) + # Shape: (n_particles, 8, 3) + cell_weights = weights.prod(dim=-1) # Shape: (n_particles, 8) + + # Extract forces from the grids + idx_x, idx_y, idx_s = surrounding_indices.view(-1, 3).T #Shape: (3,n_particles*8) + valid_mask = (idx_x >= 0) & (idx_x < grid_shape[0]) & \ + (idx_y >= 0) & (idx_y < grid_shape[1]) & \ + (idx_s >= 0) & (idx_s < grid_shape[2]) + + valid_indices = ( idx_x[valid_mask], idx_y[valid_mask], idx_s[valid_mask] ) + Fx_values = grad_x[ i_batch ][ valid_indices ] + Fy_values = grad_y[ i_batch ][ valid_indices ] + Fz_values = grad_z[ i_batch ][ valid_indices ] + + # Compute interpolated forces + valid_cell_weights = cell_weights.view(-1)[valid_mask]*elementary_charge + values_x = valid_cell_weights * Fx_values + values_y = valid_cell_weights * Fy_values + values_z = valid_cell_weights * Fz_values + + indices = torch.arange(n_particles).repeat_interleave(8)[valid_mask] + interpolated_F = interpolated_forces[i_batch] + interpolated_F.index_add_(0, indices, torch.stack([values_x,\ + torch.zeros_like(values_x), torch.zeros_like(values_x)], dim=1)) + interpolated_F.index_add_(0,indices,torch.stack\ + ([torch.zeros_like(values_y), values_y, torch.zeros_like(values_y)],dim=1)) + interpolated_F.index_add_(0, indices, torch.stack(\ + [torch.zeros_like(values_z), torch.zeros_like(values_z), values_z], dim=1)) return interpolated_forces - + def track(self, incoming: ParticleBeam) -> ParticleBeam: """ @@ -708,6 +717,10 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: device=incoming.particles.device, dtype=incoming.particles.dtype, ) + # Flatten the batch dimensions (to simplify later calculation, is undone at the end of `track`) + n_particles = outcoming.particles.shape[-2] + outcoming.particles.reshape( (-1, n_particles, 7) ) + self.n_batch = outcoming.particles.shape[0] # Compute useful quantities grid_dimensions = self._compute_grid_dimensions(outcoming) cell_size = 2*grid_dimensions / torch.tensor(self.grid_shape) @@ -716,17 +729,13 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: self._cheetah_to_moments(outcoming) particles = outcoming.particles forces = self._compute_forces(outcoming, cell_size, grid_dimensions) - particles[:,1] += forces[:,0]*dt - particles[:,3] += forces[:,1]*dt - particles[:,5] += forces[:,2]*dt + particles[:,:,1] += forces[:,:,0]*dt + particles[:,:,3] += forces[:,:,1]*dt + particles[:,:,5] += forces[:,:,2]*dt self._moments_to_cheetah(outcoming) - return ParticleBeam( - outcoming.particles, - incoming.energy, - particle_charges=outcoming.particle_charges, - device=particles.device, - dtype=particles.dtype, - ) + # Unflatten the batch dimensions + outcoming.particles.reshape( incoming.particles.shape ) + return outcoming else: raise TypeError(f"Parameter incoming is of invalid type {type(incoming)}") diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 9a9eff21..990f7944 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -4,10 +4,10 @@ def test_cold_uniform_beam_expansion(): """ - Tests that that a cold uniform beam doubles in size in both dimensions when - travelling through a drift section with space_charge. (cf ImpactX test: + Tests that that a cold uniform beam doubles in size in both dimensions when + travelling through a drift section with space_charge. (cf ImpactX test: https://impactx.readthedocs.io/en/latest/usage/examples/cfchannel/README.html#constant-focusing-channel-with-space-charge) - See Free Expansion of a Cold Uniform Bunch in + See Free Expansion of a Cold Uniform Bunch in https://accelconf.web.cern.ch/hb2023/papers/thbp44.pdf. """ @@ -51,13 +51,13 @@ def test_cold_uniform_beam_expansion(): cheetah.Drift(L/6) ] ) - outgoing_beam = segment_space_charge.track(incoming) + outgoing_beam = segment_space_charge.track(incoming) # Final beam properties sig_xo = outgoing_beam.sigma_x sig_yo = outgoing_beam.sigma_y sig_so = outgoing_beam.sigma_s - + torch.set_printoptions(precision=16) assert torch.isclose(sig_xo,2*sig_xi,rtol=2e-2,atol=0.0) assert torch.isclose(sig_yo,2*sig_yi,rtol=2e-2,atol=0.0) @@ -76,7 +76,7 @@ def test_incoming_beam_not_modified(): # Initial beam properties incoming_particles0 = incoming_beam.particles - L=torch.tensor(1.0) + L=torch.tensor([1.0]) segment_space_charge = cheetah.Segment( elements=[ cheetah.Drift(L/6), @@ -89,8 +89,8 @@ def test_incoming_beam_not_modified(): ] ) # Calling the track method - outgoing_beam = segment_space_charge.track(incoming_beam) - + outgoing_beam = segment_space_charge.track(incoming_beam) + # Final beam properties incoming_particles1 = incoming_beam.particles