Skip to content

Commit

Permalink
Make ntt panic free
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Sep 13, 2024
1 parent a94aed7 commit da043be
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 25 deletions.
64 changes: 46 additions & 18 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ntt.fst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ let ntt_layer_int_vec_step
in
a, b <: (v_Vector & v_Vector)

let zetas_b_lemma (i:nat{i >= 0 /\ i < 128}) : Lemma
(Spec.Utils.is_i16b 1664 Libcrux_ml_kem.Polynomial.v_ZETAS_TIMES_MONTGOMERY_R.[ sz i ]) =
admit()

let ntt_at_layer_1_
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand All @@ -35,22 +39,29 @@ let ntt_at_layer_1_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
=
let v__zeta_i_init:usize = zeta_i in
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
(sz 16)
(fun temp_0_ temp_1_ ->
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let _:usize = temp_1_ in
true)
let round:usize = round in
v zeta_i == v v__zeta_i_init + v round * 4)
(re, zeta_i <: (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize))
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let round:usize = round in
let zeta_i:usize = zeta_i +! sz 1 in
let _:Prims.unit =
zetas_b_lemma (v zeta_i);
zetas_b_lemma (v zeta_i + 1);
zetas_b_lemma (v zeta_i + 2);
zetas_b_lemma (v zeta_i + 3)
in
let re:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector =
{
re with
Expand Down Expand Up @@ -96,22 +107,27 @@ let ntt_at_layer_2_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
=
let v__zeta_i_init:usize = zeta_i in
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
(sz 16)
(fun temp_0_ temp_1_ ->
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let _:usize = temp_1_ in
true)
let round:usize = round in
v zeta_i == v v__zeta_i_init + v round * 2)
(re, zeta_i <: (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize))
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let round:usize = round in
let zeta_i:usize = zeta_i +! sz 1 in
let _:Prims.unit =
zetas_b_lemma (v zeta_i);
zetas_b_lemma (v zeta_i + 1)
in
let re:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector =
{
re with
Expand Down Expand Up @@ -149,22 +165,24 @@ let ntt_at_layer_3_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
=
let v__zeta_i_init:usize = zeta_i in
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
(sz 16)
(fun temp_0_ temp_1_ ->
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let _:usize = temp_1_ in
true)
let round:usize = round in
v zeta_i == v v__zeta_i_init + v round)
(re, zeta_i <: (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize))
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let round:usize = round in
let zeta_i:usize = zeta_i +! sz 1 in
let _:Prims.unit = zetas_b_lemma (v zeta_i) in
let re:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector =
{
re with
Expand All @@ -188,6 +206,8 @@ let ntt_at_layer_3_
let hax_temp_output:Prims.unit = () <: Prims.unit in
zeta_i, re <: (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)

#push-options "--z3rlimit 200"

let ntt_at_layer_4_plus
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand All @@ -197,30 +217,31 @@ let ntt_at_layer_4_plus
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(layer v__initial_coefficient_bound: usize)
=
let _:Prims.unit =
if true
then
let _:Prims.unit = Hax_lib.v_assert (layer >=. sz 4 <: bool) in
()
in
let step:usize = sz 1 <<! layer in
let v__zeta_i_init:usize = zeta_i in
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
(sz 128 >>! layer <: usize)
(fun temp_0_ temp_1_ ->
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let _:usize = temp_1_ in
true)
let round:usize = round in
v zeta_i == v v__zeta_i_init + v round)
(re, zeta_i <: (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize))
(fun temp_0_ round ->
let re, zeta_i:(Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector & usize) =
temp_0_
in
let round:usize = round in
let _:Prims.unit =
assert (v round < 8);
assert (v step >= 16 /\ v step <= 128);
assert (v (round *! step) >= 0 /\ v (round *! step) <= 112)
in
let zeta_i:usize = zeta_i +! sz 1 in
let offset:usize = (round *! step <: usize) *! sz 2 in
let _:Prims.unit = assert (v offset >= 0 /\ v offset <= 224) in
let offset_vec:usize = offset /! sz 16 in
let step_vec:usize = step /! sz 16 in
let re:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector =
Expand All @@ -234,6 +255,7 @@ let ntt_at_layer_4_plus
(fun re j ->
let re:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector = re in
let j:usize = j in
let _:Prims.unit = zetas_b_lemma (v zeta_i) in
let x, y:(v_Vector & v_Vector) =
ntt_layer_int_vec_step #v_Vector
(re.Libcrux_ml_kem.Polynomial.f_coefficients.[ j ] <: v_Vector)
Expand Down Expand Up @@ -275,6 +297,8 @@ let ntt_at_layer_4_plus
let hax_temp_output:Prims.unit = () <: Prims.unit in
zeta_i, re <: (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)

#pop-options

let ntt_at_layer_7_
(#v_Vector: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand Down Expand Up @@ -396,6 +420,8 @@ let ntt_binomially_sampled_ring_element
in
re

#push-options "--z3rlimit 200"

let ntt_vector_u
(v_VECTOR_U_COMPRESSION_FACTOR: usize)
(#v_Vector: Type0)
Expand Down Expand Up @@ -454,3 +480,5 @@ let ntt_vector_u
(Prims.unit & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
in
re

#pop-options
12 changes: 7 additions & 5 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ntt.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ val ntt_layer_int_vec_step
{| i1: Libcrux_ml_kem.Vector.Traits.t_Operations v_Vector |}
(a b: v_Vector)
(zeta_r: i16)
: Prims.Pure (v_Vector & v_Vector) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (v_Vector & v_Vector)
(requires Spec.Utils.is_i16b 3328 zeta_r)
(fun _ -> Prims.l_True)

val ntt_at_layer_1_
(#v_Vector: Type0)
Expand All @@ -23,7 +25,7 @@ val ntt_at_layer_1_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
: Prims.Pure (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
Prims.l_True
(requires v zeta_i < 64)
(fun _ -> Prims.l_True)

val ntt_at_layer_2_
Expand All @@ -33,7 +35,7 @@ val ntt_at_layer_2_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
: Prims.Pure (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
Prims.l_True
(requires v zeta_i < 96)
(fun _ -> Prims.l_True)

val ntt_at_layer_3_
Expand All @@ -43,7 +45,7 @@ val ntt_at_layer_3_
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(v__layer v__initial_coefficient_bound: usize)
: Prims.Pure (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
Prims.l_True
(requires v zeta_i < 112)
(fun _ -> Prims.l_True)

val ntt_at_layer_4_plus
Expand All @@ -53,7 +55,7 @@ val ntt_at_layer_4_plus
(re: Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
(layer v__initial_coefficient_bound: usize)
: Prims.Pure (usize & Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
Prims.l_True
(requires v layer >= 4 /\ v layer <= 7 /\ v zeta_i + v (sz 128 >>! layer) < 128)
(fun _ -> Prims.l_True)

val ntt_at_layer_7_
Expand Down
1 change: 0 additions & 1 deletion libcrux-ml-kem/proofs/fstar/extraction/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ SLOW_MODULES += Libcrux_ml_kem.Vector.Portable.Serialize.fst

ADMIT_MODULES = Libcrux_ml_kem.Ind_cca.Unpacked.fst \
Libcrux_ml_kem.Invert_ntt.fst \
Libcrux_ml_kem.Ntt.fst \
Libcrux_ml_kem.Vector.Avx2.fst \
Libcrux_ml_kem.Vector.Avx2.Sampling.fst \
Libcrux_ml_kem.Vector.Avx2.Serialize.fst \
Expand Down
33 changes: 32 additions & 1 deletion libcrux-ml-kem/src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,26 @@ use crate::{
};

#[inline(always)]
#[cfg_attr(hax, hax_lib::fstar::before("let zetas_b_lemma (i:nat{i >= 0 /\\ i < 128}) : Lemma
(Spec.Utils.is_i16b 1664 ${ZETAS_TIMES_MONTGOMERY_R}.[ sz i ]) =
admit()"))]
#[hax_lib::requires(fstar!("v ${*zeta_i} < 64"))]
pub(crate) fn ntt_at_layer_1<Vector: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<Vector>,
_layer: usize,
_initial_coefficient_bound: usize,
) {
let _zeta_i_init = *zeta_i;
// The semicolon and parentheses at the end of loop are a workaround
// for the following bug https://github.com/hacspec/hax/issues/720
for round in 0..16 {
hax_lib::loop_invariant!(|round: usize| { fstar!("v zeta_i == v $_zeta_i_init + v $round * 4") });
*zeta_i += 1;
hax_lib::fstar!("zetas_b_lemma (v zeta_i);
zetas_b_lemma (v zeta_i + 1);
zetas_b_lemma (v zeta_i + 2);
zetas_b_lemma (v zeta_i + 3)");
re.coefficients[round] = Vector::ntt_layer_1_step(
re.coefficients[round],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
Expand All @@ -28,16 +38,21 @@ pub(crate) fn ntt_at_layer_1<Vector: Operations>(
}

#[inline(always)]
#[hax_lib::requires(fstar!("v ${*zeta_i} < 96"))]
pub(crate) fn ntt_at_layer_2<Vector: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<Vector>,
_layer: usize,
_initial_coefficient_bound: usize,
) {
let _zeta_i_init = *zeta_i;
// The semicolon and parentheses at the end of loop are a workaround
// for the following bug https://github.com/hacspec/hax/issues/720
for round in 0..16 {
hax_lib::loop_invariant!(|round: usize| { fstar!("v zeta_i == v $_zeta_i_init + v $round * 2") });
*zeta_i += 1;
hax_lib::fstar!("zetas_b_lemma (v zeta_i);
zetas_b_lemma (v zeta_i + 1)");
re.coefficients[round] = Vector::ntt_layer_2_step(
re.coefficients[round],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
Expand All @@ -49,23 +64,28 @@ pub(crate) fn ntt_at_layer_2<Vector: Operations>(
}

#[inline(always)]
#[hax_lib::requires(fstar!("v ${*zeta_i} < 112"))]
pub(crate) fn ntt_at_layer_3<Vector: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<Vector>,
_layer: usize,
_initial_coefficient_bound: usize,
) {
let _zeta_i_init = *zeta_i;
// The semicolon and parentheses at the end of loop are a workaround
// for the following bug https://github.com/hacspec/hax/issues/720
for round in 0..16 {
hax_lib::loop_invariant!(|round: usize| { fstar!("v zeta_i == v $_zeta_i_init + v $round") });
*zeta_i += 1;
hax_lib::fstar!("zetas_b_lemma (v zeta_i)");
re.coefficients[round] =
Vector::ntt_layer_3_step(re.coefficients[round], ZETAS_TIMES_MONTGOMERY_R[*zeta_i]);
}
()
}

#[inline(always)]
#[hax_lib::requires(fstar!("Spec.Utils.is_i16b 3328 $zeta_r"))]
fn ntt_layer_int_vec_step<Vector: Operations>(
mut a: Vector,
mut b: Vector,
Expand All @@ -76,26 +96,36 @@ fn ntt_layer_int_vec_step<Vector: Operations>(
a = Vector::add(a, &t);
(a, b)
}

#[inline(always)]
#[hax_lib::fstar::options("--z3rlimit 200")]
#[hax_lib::requires(fstar!("v $layer >= 4 /\\ v $layer <= 7 /\\
v ${*zeta_i} + v (sz 128 >>! $layer) < 128"))]
pub(crate) fn ntt_at_layer_4_plus<Vector: Operations>(
zeta_i: &mut usize,
re: &mut PolynomialRingElement<Vector>,
layer: usize,
_initial_coefficient_bound: usize,
) {
debug_assert!(layer >= 4);
let step = 1 << layer;

let _zeta_i_init = *zeta_i;
// The semicolon and parentheses at the end of loop are a workaround
// for the following bug https://github.com/hacspec/hax/issues/720
for round in 0..(128 >> layer) {
hax_lib::loop_invariant!(|round: usize| { fstar!("v zeta_i == v $_zeta_i_init + v $round") });
hax_lib::fstar!("assert (v $round < 8);
assert (v $step >= 16 /\\ v $step <= 128);
assert (v ($round *! $step) >= 0 /\\ v ($round *! $step) <= 112)");
*zeta_i += 1;

let offset = round * step * 2;
hax_lib::fstar!("assert (v $offset >= 0 /\\ v $offset <= 224)");
let offset_vec = offset / 16; //FIELD_ELEMENTS_IN_VECTOR;
let step_vec = step / 16; //FIELD_ELEMENTS_IN_VECTOR;

for j in offset_vec..offset_vec + step_vec {
hax_lib::fstar!("zetas_b_lemma (v zeta_i)");
let (x, y) = ntt_layer_int_vec_step(
re.coefficients[j],
re.coefficients[j + step_vec],
Expand Down Expand Up @@ -141,6 +171,7 @@ pub(crate) fn ntt_binomially_sampled_ring_element<Vector: Operations>(
}

#[inline(always)]
#[hax_lib::fstar::options("--z3rlimit 200")]
pub(crate) fn ntt_vector_u<const VECTOR_U_COMPRESSION_FACTOR: usize, Vector: Operations>(
re: &mut PolynomialRingElement<Vector>,
) {
Expand Down

0 comments on commit da043be

Please sign in to comment.