Skip to content

Commit

Permalink
Refactor measurement with iterators and parallelize
Browse files Browse the repository at this point in the history
  • Loading branch information
smu160 committed Nov 2, 2023
1 parent 0351770 commit 007e42d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 53 deletions.
10 changes: 9 additions & 1 deletion spinoza/benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use spinoza::{
core::{iqft, State},
gates::{apply, c_apply, Gate},
math::{pow2f, Float, PI},
utils::pretty_print_int,
measurement::measure_qubit,
utils::{gen_random_state, pretty_print_int},
};

fn first_rotation(circuit: &mut QuantumCircuit, nqubits: usize, angles: &mut Vec<Float>) {
Expand Down Expand Up @@ -138,6 +139,11 @@ fn pprint_int(i: u128) {
let _res = pretty_print_int(i);
}

fn measure(n: usize) {
let mut state = gen_random_state(n);
measure_qubit(&mut state, 0, true, None);
}

fn criterion_benchmark(c: &mut Criterion) {
let n = 25;

Expand Down Expand Up @@ -174,6 +180,8 @@ fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("pprint_int", |b| {
b.iter(|| pprint_int(black_box(u128::MAX)))
});

c.bench_function("measure", |b| b.iter(|| measure(black_box(n))));
}

criterion_group! {name = benches; config = Criterion::default().sample_size(100); targets = criterion_benchmark}
Expand Down
27 changes: 7 additions & 20 deletions spinoza/examples/measurement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,18 @@ use spinoza::{
config::{Config, QSArgs},
core::CONFIG,
measurement::measure_qubit,
utils::gen_random_state,
utils::{gen_random_state, pretty_print_int},
};

fn measure_qubits(n: usize) {
let mut state = gen_random_state(n);

state
.reals
.iter()
.zip(state.imags.iter())
.for_each(|(re, im)| {
println!("{re},{im}");
});

println!("----------------------------------------");
println!("{state}");

measure_qubit(&mut state, 0, true, Some(0));
println!("{state}");

measure_qubit(&mut state, 1, true, Some(0));
println!("{state}");

measure_qubit(&mut state, 2, true, Some(1));
println!("{state}");
let now = std::time::Instant::now();
for t in 0..n {
measure_qubit(&mut state, t, true, None);
}
let elapsed = now.elapsed().as_micros();
println!("measured all qubits in {} us", pretty_print_int(elapsed));
}

fn main() {
Expand Down
97 changes: 65 additions & 32 deletions spinoza/src/measurement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,87 @@
use crate::{
core::State,
gates::{apply, Gate},
math::modulus,
math::{modulus, Float},
};
use rand_distr::{Binomial, Distribution};
use rayon::prelude::*;

/// Single qubit measurement
pub fn measure_qubit(state: &mut State, target: usize, reset: bool, v: Option<u8>) -> u8 {
let mut prob0 = 0.0;
let mut prob1 = 0.0;
let num_pairs = state.len() >> 1;
let distance = 1 << target;

for i in 0..num_pairs {
let s0 = i + ((i >> target) << target);
let s1 = s0 + distance;

prob0 += modulus(state.reals[s0], state.imags[s0]).powi(2);
prob1 += modulus(state.reals[s1], state.imags[s1]).powi(2);
}
let chunk_size = 1 << (target + 1);
let dist = 1 << target;

let prob0 = state
.reals
.par_chunks_exact(chunk_size)
.zip_eq(state.imags.par_chunks_exact(chunk_size))
.map(|(reals_chunk, imags_chunk)| {
let (reals_s0, _reals_s1) = reals_chunk.split_at(dist);
let (imags_s0, _imags_s1) = imags_chunk.split_at(dist);

reals_s0
.par_iter()
.zip_eq(imags_s0.par_iter())
.with_min_len(1 << 16)
.map(|(re_s0, im_s0)| modulus(*re_s0, *im_s0).powi(2))
.sum::<Float>()
})
.sum::<Float>();

let val = if let Some(_v) = v {
assert!(_v == 0 || _v == 1);
_v
} else {
let bin = Binomial::new(1, prob1).unwrap();
let bin = Binomial::new(1, 1.0 - prob0).unwrap();
bin.sample(&mut rand::thread_rng()) as u8
};

if val == 0 {
for i in 0..num_pairs {
let s0 = i + ((i >> target) << target);
let s1 = s0 + distance;

state.reals[s0] /= prob0.sqrt();
state.imags[s0] /= prob0.sqrt();
state.reals[s1] = 0.0;
state.imags[s1] = 0.0;
}
let prob0_sqrt_recip = prob0.sqrt().recip();
state
.reals
.par_chunks_exact_mut(chunk_size)
.zip_eq(state.imags.par_chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.par_iter_mut()
.zip_eq(reals_s1.par_iter_mut())
.zip_eq(imags_s0.par_iter_mut())
.zip_eq(imags_s1.par_iter_mut())
.for_each(|(((re_s0, re_s1), im_s0), im_s1)| {
*re_s0 *= prob0_sqrt_recip;
*im_s0 *= prob0_sqrt_recip;
*re_s1 = 0.0;
*im_s1 = 0.0;
});
});
} else {
for i in 0..num_pairs {
let s0 = i + ((i >> target) << target);
let s1 = s0 + distance;

state.reals[s0] = 0.0;
state.imags[s0] = 0.0;
state.reals[s1] /= prob1.sqrt();
state.imags[s1] /= prob1.sqrt();
}
let prob1 = 1.0 - prob0;
let prob1_sqrt_recip = prob1.sqrt().recip();

state
.reals
.par_chunks_exact_mut(chunk_size)
.zip_eq(state.imags.par_chunks_exact_mut(chunk_size))
.for_each(|(reals_chunk, imags_chunk)| {
let (reals_s0, reals_s1) = reals_chunk.split_at_mut(dist);
let (imags_s0, imags_s1) = imags_chunk.split_at_mut(dist);

reals_s0
.par_iter_mut()
.zip_eq(reals_s1.par_iter_mut())
.zip_eq(imags_s0.par_iter_mut())
.zip_eq(imags_s1.par_iter_mut())
.for_each(|(((re_s0, re_s1), im_s0), im_s1)| {
*re_s1 *= prob1_sqrt_recip;
*im_s1 *= prob1_sqrt_recip;
*re_s0 = 0.0;
*im_s0 = 0.0;
});
});
if reset {
apply(Gate::X, state, target);
}
Expand Down

0 comments on commit 007e42d

Please sign in to comment.