Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorised simulations #116

Merged
merged 53 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
032ac5d
Convert beams (except for transform) to be batched
jank324 Jan 4, 2024
9f02713
Make beam transforms batch computable
jank324 Jan 5, 2024
f6c39a3
Implement tracking through drift in a batched way as an example
jank324 Jan 5, 2024
dab3f13
Fix all tests except for Cavity ones
jank324 Jan 6, 2024
a077aa2
Fix cavity tests
jank324 Jan 6, 2024
c893beb
Fix properties to be batched
jank324 Jan 6, 2024
accc009
Add some batched tests
jank324 Jan 6, 2024
5b4e39c
Add broadcasting and 100k EA test
jank324 Jan 6, 2024
94064cb
Fix examples to run with batched computations
jank324 Jan 6, 2024
4b7c15f
Fix failing NX tables test
jank324 Jan 6, 2024
df48d85
Fix batching in cavity for particle beam
jank324 Jan 6, 2024
5283bd1
Prevent NaNs in emittance computation
jank324 Jan 7, 2024
b649094
Prvent zero division in twiss parameter computations
jank324 Jan 7, 2024
5b638f9
Move emittance safety to emittance computation
jank324 Jan 7, 2024
a8c5e57
Make sure parameter beam sigmas are always strictly positive
jank324 Jan 7, 2024
892206a
Fix order of sigma safety computations
jank324 Jan 7, 2024
b097bca
Set more sensibel values for parameter beam from twiss
jank324 Jan 7, 2024
5e36218
Ensure unphysical betas don't go unnoticed
jank324 Jan 7, 2024
2a0a8e0
Merge branch 'master' into 100-batched-execution
jank324 Feb 14, 2024
25b9ead
Choose better value for the clipping
jank324 Feb 14, 2024
2ce6e31
Merge branch '100-batched-execution' of github.com:desy-ml/cheetah in…
jank324 Feb 14, 2024
2b2b882
For some elements test that parameters are broadcast correctly
jank324 Feb 14, 2024
d48bbe9
Test tracking after broadcast through EA gives same result
jank324 Feb 14, 2024
0609b45
Add `__repr__` for `CustomTransferMap`
jank324 Feb 14, 2024
d84a501
Add further tests to find bug in vectorised computations on LCLS example
jank324 Feb 14, 2024
5853043
`CustomTransferMap` elements from combination also combine name
jank324 Feb 14, 2024
4494056
Segment default name should work the same as for other elements
jank324 Feb 14, 2024
6b5ac7c
Prevent flooding of unique element names by automatic optimisations
jank324 Feb 14, 2024
0ad75ab
Minor shape fixes
jank324 Feb 14, 2024
f3ea831
Add vectorisation to speed optimisation example notebook
jank324 Mar 17, 2024
7485bfc
Update README to reflect vectorisation
jank324 Mar 17, 2024
a06be55
Merge branch 'master' into 100-batched-execution
jank324 Mar 17, 2024
a82be19
Formatting fix
jank324 Mar 17, 2024
99acb66
Half fix the broken test
jank324 Mar 17, 2024
db1c5f9
Fix merge issue in cavity transfer map
cr-xu Mar 19, 2024
1f3865b
Fix cavity batch execution
cr-xu Mar 19, 2024
958aed8
Merge branch 'master' into 100-batched-execution
jank324 Mar 27, 2024
dc4e5d7
Choose more appropriate name for test file
jank324 Mar 27, 2024
47a5497
Test that n-dimensional inputs work
jank324 Mar 28, 2024
47632b0
Attempt to fix tests that fail on GitHub but succeed locally
jank324 Mar 28, 2024
5f9984b
Merge branch 'master' into 100-batched-execution
jank324 Mar 28, 2024
c212c9e
Add changelog entry for vectorisation
jank324 Mar 28, 2024
0589ee1
Already bumb version to first v0.7 release
jank324 Mar 28, 2024
9b9abe4
Add zero length to `Element` and add new test cases, fixes #143
cr-xu Apr 5, 2024
0dbf71a
Move zero length fix entry to bug fixes
jank324 Apr 5, 2024
2a27cd7
Some vscode stuff that makes sense for others to reuse
jank324 Apr 5, 2024
6faef7d
Test that breaks lengthless element fix
jank324 Apr 5, 2024
806d191
Fix issue with lengthless fix for Screen
jank324 Apr 6, 2024
457c18a
Merge branch 'master' into 100-batched-execution
jank324 Apr 16, 2024
93e9eb6
Add notice for developers on n-dimensional properties
jank324 Apr 16, 2024
ae5b61d
Update GitHub action steps versions
jank324 Apr 16, 2024
0e36c11
Fix action upgrade to actual up-to-date actions
jank324 Apr 16, 2024
0d6ef18
Merge branch 'master' into 100-batched-execution
jank324 Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

### 🚨 Breaking Changes

- Cheetah is not vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. (see #116) (@jank324)
- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. (see #116) (@jank324)

### 🚀 Features

- `CustomTransferMap` elements created by combining multiple other elements will now reflect that in their `name` attribute (see #100) (@jank324)
- Now all `Element` have a default length of `torch.zeros((1))`, i.e. also for `Marker`, `BPM`, `Screen`, and `Aperture`.

### 🐛 Bug fixes

Expand Down
39 changes: 21 additions & 18 deletions cheetah/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Element(ABC, nn.Module):
:param name: Unique identifier of the element.
"""

length: torch.Tensor = torch.zeros((1))
jank324 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, name: Optional[str] = None) -> None:
super().__init__()

Expand Down Expand Up @@ -216,9 +218,7 @@ def from_merging_elements(
tm = torch.matmul(element.transfer_map(incoming_beam.energy), tm)
incoming_beam = element.track(incoming_beam)

combined_length = sum(
element.length for element in elements if hasattr(element, "length")
)
combined_length = sum(element.length for element in elements)

combined_name = "combined_" + "_".join(element.name for element in elements)

Expand Down Expand Up @@ -1275,7 +1275,9 @@ def track(self, incoming: Beam) -> Beam:
return deepcopy(incoming)

def broadcast(self, shape: Size) -> Element:
return self.__class__(is_active=self.is_active, name=self.name)
new_bpm = self.__class__(is_active=self.is_active, name=self.name)
new_bpm.length = self.length.repeat(shape)
return new_bpm

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]
Expand Down Expand Up @@ -1316,7 +1318,9 @@ def track(self, incoming: Beam) -> Beam:
return incoming

def broadcast(self, shape: Size) -> Element:
return self.__class__(name=self.name)
new_marker = self.__class__(name=self.name)
new_marker.length = self.length.repeat(shape)
return new_marker

@property
def is_skippable(self) -> bool:
Expand Down Expand Up @@ -1546,14 +1550,16 @@ def set_read_beam(self, value: Beam) -> None:
self.cached_reading = None

def broadcast(self, shape: Size) -> Element:
return self.__class__(
new_screen = self.__class__(
resolution=self.resolution,
pixel_size=self.pixel_size,
binning=self.binning,
misalignment=self.misalignment.repeat((*shape, 1)),
is_active=self.is_active,
name=self.name,
)
new_screen.length = self.length.repeat(shape)
return new_screen

def split(self, resolution: torch.Tensor) -> list[Element]:
return [self]
Expand Down Expand Up @@ -1678,13 +1684,15 @@ def track(self, incoming: Beam) -> Beam:
)

def broadcast(self, shape: Size) -> Element:
return self.__class__(
new_aperture = self.__class__(
x_max=self.x_max.repeat(shape),
y_max=self.y_max.repeat(shape),
shape=self.shape,
is_active=self.is_active,
name=self.name,
)
new_aperture.length = self.length.repeat(shape)
return new_aperture

def split(self, resolution: torch.Tensor) -> list[Element]:
# TODO: Implement splitting for aperture properly, for now just return self
Expand Down Expand Up @@ -2075,7 +2083,7 @@ def without_inactive_zero_length_elements(
elements=[
element
for element in self.elements
if (hasattr(element, "length") and element.length > 0.0)
if element.length > 0.0
or (hasattr(element, "is_active") and element.is_active)
or element.name in except_for
],
Expand Down Expand Up @@ -2104,7 +2112,7 @@ def inactive_elements_as_drifts(
(
element
if (hasattr(element, "is_active") and element.is_active)
or not hasattr(element, "length")
or element.length == 0.0
or element.name in except_for
else Drift(element.length)
)
Expand Down Expand Up @@ -2217,7 +2225,7 @@ def is_skippable(self) -> bool:
@property
def length(self) -> torch.Tensor:
lengths = torch.stack(
[element.length for element in self.elements if hasattr(element, "length")],
[element.length for element in self.elements],
dim=1,
)
return torch.sum(lengths, dim=1)
Expand Down Expand Up @@ -2265,10 +2273,7 @@ def split(self, resolution: torch.Tensor) -> list[Element]:
]

def plot(self, ax: matplotlib.axes.Axes, s: float) -> None:
element_lengths = [
element.length[0] if hasattr(element, "length") else 0.0
for element in self.elements
]
element_lengths = [element.length[0] for element in self.elements]
element_ss = [0] + [
sum(element_lengths[: i + 1]) for i, _ in enumerate(element_lengths)
]
Expand Down Expand Up @@ -2305,9 +2310,7 @@ def plot_reference_particle_traces(
reference_segment = deepcopy(self)
splits = reference_segment.split(resolution=torch.tensor(resolution))

split_lengths = [
split.length[0] if hasattr(split, "length") else 0.0 for split in splits
]
split_lengths = [split.length[0] for split in splits]
ss = [0] + [sum(split_lengths[: i + 1]) for i, _ in enumerate(split_lengths)]

references = []
Expand Down Expand Up @@ -2393,7 +2396,7 @@ def plot_twiss(self, beam: Beam, ax: Optional[Any] = None) -> None:
longitudinal_beams = [beam]
s_positions = [0.0]
for element in self.elements:
if not hasattr(element, "length") or element.length == 0:
if element.length == 0:
continue

outgoing = element.track(longitudinal_beams[-1])
Expand Down
31 changes: 31 additions & 0 deletions tests/test_tracking_lengthless_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

import cheetah

beam_in = cheetah.ParticleBeam.from_parameters(num_particles=100)


# Only Marker
def test_tracking_marker_only():
segment = cheetah.Segment([cheetah.Marker(name="start")])

beam_out = segment.track(beam_in)

assert torch.allclose(beam_out.particles, beam_in.particles)


# Only length-less elements between non-skippable elements
def test_tracking_lengthless_elements():
segment = cheetah.Segment(
[
cheetah.Cavity(
length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C2"
),
cheetah.Marker(name="start"),
cheetah.Cavity(
length=torch.tensor([0.1]), voltage=torch.tensor([1e6]), name="C1"
),
]
)

_ = segment.track(beam_in)
Loading