Skip to content

Commit

Permalink
Sequence calculation fix and setter function (#55)
Browse files Browse the repository at this point in the history
* Fix index error in sequence provider

* Raster helper funtion, 3d tse

* Change raster times

* fixed tse 3d dwell time

* Fix TSE sequence

* Sequence cache as list

* Setter for sequence and parameters

* Update 3D TSE sequence

* Fixed github actions trigger

* Fixed linter findings

* Changed var name of unrolled sequence
  • Loading branch information
schote authored Feb 2, 2024
1 parent 74e2ddb commit 1270378
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Pytest

on:
push:
pull_request:
branches:
- '*'

Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/static-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
push:
branches:
- '*'
pull_request:
branches:
- '*'

jobs:
linting:
Expand Down
3 changes: 2 additions & 1 deletion examples/se_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
)

# Run the acquisition
acq_data: AcquisitionData = acq.run(parameter=params, sequence=seq)
acq.set_sequence(parameter=params, sequence=seq)
acq_data: AcquisitionData = acq.run()

# Get decimated data from acquisition data object
data = acq_data.raw.squeeze()
Expand Down
3 changes: 2 additions & 1 deletion examples/t2_measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
)

# Perform acquisition
acq_data: AcquisitionData = acq.run(parameter=params, sequence=seq)
acq.set_sequence(parameter=params, sequence=seq)
acq_data: AcquisitionData = acq.run()
data = np.mean(acq_data.raw, axis=0).squeeze()

peaks = np.max(data, axis=-1)
Expand Down
39 changes: 24 additions & 15 deletions src/console/pulseq_interpreter/sequence_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(

self.larmor_freq: float = float("nan")
self.sample_count: int = 0
self._seq: np.ndarray | None = None
self._sqnc_cache: list = []

def dict(self) -> dict:
"""Abstract method which returns variables for logging in dictionary."""
Expand Down Expand Up @@ -269,10 +269,10 @@ def calculate_gradient(self, block: SimpleNamespace, unroll_arr: np.ndarray, fov
# Both gradient types have a delay, calculate delay in number of samples
samples_delay = int(block.delay * self.spcm_freq)
# Index of this gradient, dependent on channel designation, offset of 1 to start at channel 1
idx = ["x", "y", "z"].index(block.channel) + 1
idx = ["x", "y", "z"].index(block.channel)

# Calculate gradient offset in mV
offset = unroll_arr[0] / INT16_MAX * self.output_limits[idx]
offset = unroll_arr[0] / INT16_MAX * self.output_limits[idx+1]
# Calculat waveform scaling
# scaling = self.grad_to_volt[idx] * fov_scaling
scaling = fov_scaling / (42.58e3 * self.gpa_gain[idx] * self.gradient_efficiency[idx])
Expand All @@ -282,17 +282,17 @@ def calculate_gradient(self, block: SimpleNamespace, unroll_arr: np.ndarray, fov
if block.type == "grad":
# Arbitrary gradient waveform, interpolate linearly
# This function requires float input => cast to int16 afterwards
if np.amax(waveform := block.waveform * scaling) + offset > self.output_limits[idx]:
if np.amax(waveform := block.waveform * scaling) + offset > self.output_limits[idx+1]:
raise ValueError(
"Amplitude of %s (%s) gradient exceeded output limit (%s)"
% (
block.channel,
np.amax(waveform) + offset,
self.output_limits[idx],
self.output_limits[idx+1],
)
)
# Trasnfer mV floating point waveform values to int16 if amplitude check passed
waveform *= INT16_MAX / self.output_limits[idx]
waveform *= INT16_MAX / self.output_limits[idx+1]

gradient = np.interp(
x=np.linspace(
Expand All @@ -306,11 +306,13 @@ def calculate_gradient(self, block: SimpleNamespace, unroll_arr: np.ndarray, fov

elif block.type == "trap":
# Construct trapezoidal gradient from rise, flat and fall sections
if np.amax(flat_amp := block.amplitude * scaling) + offset > self.output_limits[idx]:
raise ValueError(f"Amplitude of {block.channel} gradient exceeded max. amplitude of channel {idx}.")
if np.amax(flat_amp := block.amplitude * scaling) + offset > self.output_limits[idx+1]:
raise ValueError(
f"Amplitude of {block.channel} gradient exceeded max. amplitude {self.output_limits[idx+1]}."
)

# Trasnfer mV floating point flat amplitude to int16 if amplitude check passed
flat_amp = np.int16(flat_amp * INT16_MAX / self.output_limits[idx])
flat_amp = np.int16(flat_amp * INT16_MAX / self.output_limits[idx+1])

rise = np.linspace(
0,
Expand Down Expand Up @@ -467,6 +469,12 @@ def unroll_sequence(
As the 15th bit is not encoding the sign (as usual for int16), the values are casted to uint16 before shifting.
"""
if self._sqnc_cache:
# Reset unrolled sequence cache to free memory
print("Resetting sequence cache...")
del self._sqnc_cache
self._sqnc_cache = []

try:
# Check larmor frequency
if larmor_freq > 10e6:
Expand Down Expand Up @@ -556,7 +564,7 @@ def unroll_sequence(
)

# Save unrolled sequence in class
self._seq = np.concatenate(_seq)
self._sqnc_cache = _seq

return UnrolledSequence(
seq=_seq,
Expand Down Expand Up @@ -589,18 +597,19 @@ def plot_unrolled(
"""
fig, axis = plt.subplots(5, 1, figsize=(16, 9))

if self._seq is None:
if not self._sqnc_cache:
print("No unrolled sequence...")
return fig, axis

seq_start = int(time_range[0] * self.spcm_freq)
seq_end = int(time_range[1] * self.spcm_freq) if time_range[1] > time_range[0] else -1
samples = np.arange(self.sample_count, dtype=float)[seq_start:seq_end] * self.spcm_dwell_time * 1e3

rf_signal = self._seq[0::4][seq_start:seq_end]
gx_signal = self._seq[1::4][seq_start:seq_end]
gy_signal = self._seq[2::4][seq_start:seq_end]
gz_signal = self._seq[3::4][seq_start:seq_end]
sqnc = np.concatenate(self._sqnc_cache)
rf_signal = sqnc[0::4][seq_start:seq_end]
gx_signal = sqnc[1::4][seq_start:seq_end]
gy_signal = sqnc[2::4][seq_start:seq_end]
gz_signal = sqnc[3::4][seq_start:seq_end]

# Get digital signals
adc_gate = gx_signal.astype(np.uint16) >> 15
Expand Down
77 changes: 48 additions & 29 deletions src/console/spcm_control/acquisition_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
# Set sequence provider max. amplitude per channel according to values from tx_card
self.seq_provider.max_amp_per_channel = self.tx_card.max_amplitude

self.unrolled_sequence: UnrolledSequence | None = None
self.unrolled_seq: UnrolledSequence | None = None

# Attributes for data and dwell time of downsampled signal
self._raw: list[np.ndarray] = []
Expand Down Expand Up @@ -128,8 +128,8 @@ def _setup_logging(self, console_level: int, file_level: int) -> None:
console.setFormatter(formatter)
logging.getLogger("").addHandler(console)

def run(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> AcquisitionData:
"""Run an acquisition job.
def set_sequence(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> None:
"""Set sequence and acquisition parameter.
Parameters
----------
Expand All @@ -140,15 +140,12 @@ def run(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> Acqu
Raises
------
RuntimeError
The measurement cards are not setup properly.
AttributeError
Invalid sequence provided.
FileNotFoundError
Invalid file ending of sequence file.
"""
try:
# Check setup
if not self.is_setup:
raise RuntimeError("Measurement cards are not setup.")
# Check sequence
if isinstance(sequence, Sequence):
self.seq_provider.from_pypulseq(sequence)
Expand All @@ -159,77 +156,99 @@ def run(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> Acqu
else:
raise AttributeError("Invalid sequence, must be either string to .seq file or Sequence instance")

except (RuntimeError, FileNotFoundError, AttributeError) as err:
except (FileNotFoundError, AttributeError) as err:
self.log.exception(err, exc_info=True)
raise err

# Calculate sequence
self.unrolled_seq = None
self.log.info(
"Unrolling sequence: %s",
self.seq_provider.definitions["Name"].replace(" ", "_"),
)
sqnc: UnrolledSequence = self.seq_provider.unroll_sequence(
self.unrolled_seq = self.seq_provider.unroll_sequence(
larmor_freq=parameter.larmor_frequency,
b1_scaling=parameter.b1_scaling,
fov_scaling=parameter.fov_scaling,
grad_offset=parameter.gradient_offset,
)
# Save unrolled sequence
self.unrolled_sequence = sqnc if sqnc else None
self.parameter = parameter


def run(self) -> AcquisitionData:
"""Run an acquisition job.
Raises
------
RuntimeError
The measurement cards are not setup properly
ValueError
Missing raw data or missing averages
"""
try:
# Check setup
if not self.is_setup:
raise RuntimeError("Measurement cards are not setup.")
if self.unrolled_seq is None:
raise ValueError("No sequence set, call set_sequence() to set a sequence and acquisition parameter.")
except (RuntimeError, ValueError) as err:
self.log.exception(err, exc_info=True)
raise err

# Define timeout for acquisition process: 5 sec + sequence duration
timeout = 5 + sqnc.duration
self.log.info("Sequence duration: %s s", sqnc.duration)
timeout = 5 + self.unrolled_seq.duration
self.log.info("Sequence duration: %s s", self.unrolled_seq.duration)

self._unproc = []
self._raw = []

for k in range(parameter.num_averages):
self.log.info("Acquisition %s/%s", k + 1, parameter.num_averages)
for k in range(self.parameter.num_averages):
self.log.info("Acquisition %s/%s", k + 1, self.parameter.num_averages)

# Start masurement card operations
self.rx_card.start_operation()
time.sleep(0.5)
self.tx_card.start_operation(sqnc)
self.tx_card.start_operation(self.unrolled_seq)

# Get start time of acquisition
time_start = time.time()

while len(self.rx_card.rx_data) < sqnc.adc_count:
while len(self.rx_card.rx_data) < self.unrolled_seq.adc_count:
# Delay poll by 10 ms
time.sleep(0.01)

if len(self.rx_card.rx_data) >= sqnc.adc_count:
if len(self.rx_card.rx_data) >= self.unrolled_seq.adc_count:
break

if time.time() - time_start > timeout:
# Could not receive all the data before timeout
self.log.warning(
"Acquisition Timeout: Only received %s/%s adc events",
len(self.rx_card.rx_data),
sqnc.adc_count,
self.unrolled_seq.adc_count,
stack_info=True,
)
break

if len(self.rx_card.rx_data) > 0:
self.post_processing(parameter)
self.post_processing(self.parameter)

self.tx_card.stop_operation()
self.rx_card.stop_operation()

if parameter.averaging_delay > 0:
time.sleep(parameter.averaging_delay)
if self.parameter.averaging_delay > 0:
time.sleep(self.parameter.averaging_delay)

try:
# if not self._raw.size > 0:
if not len(self._raw) > 0:
raise ValueError("Error during post processing or readout, no raw data")
raise ValueError("No raw data acquired...")
# if len(self._raw) != parameter.num_averages:
if not all(gate.shape[0] == parameter.num_averages for gate in self._raw):
if not all(gate.shape[0] == self.parameter.num_averages for gate in self._raw):
raise ValueError(
"Could not acquire all averages. Average dimensions: %s/%s",
"Missing averages: %s/%s",
[gate.shape[0] for gate in self._raw],
parameter.num_averages,
self.parameter.num_averages,
)
except ValueError as err:
self.log.exception(err, exc_info=True)
Expand All @@ -245,8 +264,8 @@ def run(self, sequence: str | Sequence, parameter: AcquisitionParameter) -> Acqu
self.rx_card.__name__: self.rx_card.dict(),
self.seq_provider.__name__: self.seq_provider.dict()
},
dwell_time=parameter.decimation / self.f_spcm,
acquisition_parameters=parameter,
dwell_time=self.parameter.decimation / self.f_spcm,
acquisition_parameters=self.parameter,
)

def post_processing(self, parameter: AcquisitionParameter) -> None:
Expand Down
23 changes: 23 additions & 0 deletions src/console/utilities/sequences/system_settings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Global definition of system settings to be imported by sequence constructors."""

from pypulseq.opts import Opts

system = Opts(
Expand All @@ -20,3 +21,25 @@
max_slew=5000,
slew_unit="T/m/s",
)


# Helper function
def raster(val: float, precision: float) -> float:
"""Fit value to gradient raster.
Parameters
----------
val
Time value to be aligned on the raster.
precision
Raster precision, e.g. system.grad_raster_time or system.adc_raster_time
Returns
-------
Value wih given time/raster precision
"""
# return np.round(val / precision) * precision
gridded_val = round(val / precision) * precision
return gridded_val
# decimals = abs(Decimal(str(precision)).as_tuple().exponent)
# return round(gridded_val, ndigits=decimals)
Loading

0 comments on commit 1270378

Please sign in to comment.