
Use RISC Zero BigInt multiplier to accelerate p256 within the zkVM guest #5

Open
wants to merge 12 commits into base: risczero

Conversation


@tsumian tsumian commented Jul 29, 2024

Description

Building on risc0/risc0#466 and referencing #1, this PR enables the use of the RISC Zero 256-bit modular multiplication accelerator within guest code for p256 arithmetic, including ECDSA.

A key application, ECDSA verification, is accelerated significantly: from a little over 13M cycles without acceleration support to about 2M cycles, roughly a 6.5× speedup (similar to the speedup achieved for k256).

Based on p256@v0.13.2

Average Cycle Count

Curve              Avg. Cycle Count
p256 unoptimised   13,005,306
p256 optimised      2,005,676

Profiling

Unoptimised version

[profiling screenshot of the unoptimised run]

Optimised version

[profiling screenshot of the optimised run, 2024-07-24]

To Do

  • Run a batch of proofs in the zkVM to get an average
  • Run profiling on the unoptimised versions of k256 and p256 to understand the discrepancy between the cycle counts
  • Run profiling on the optimised versions of k256 and p256 to contrast with the unoptimised versions
  • Remove Montgomery form and use standard form

Testing

@tzerrell tzerrell left a comment


Thank you, this will be good to have acceleration for!

I mentioned a few issues @nategraf & I found inline, including one that could enable incorrect calculations via denormalized inputs. I believe @nategraf is taking a deeper look later today and will probably have more suggestions for how to address this.

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let mut result = U256::ZERO;
    for _i in 0..rhs {
        result = add(a, a)
@tzerrell tzerrell Aug 2, 2024


This looks like the intent is to add a to itself rhs times, but I don't think this does that -- it looks like it sets result to a + a repeatedly, so this will in effect always multiply by 2.

Presumably you wanted the below suggestion, although I'm not convinced an algorithm that could potentially run >2^31 adds (since rhs is a u32) is a good idea.

Suggested change
result = add(a, a)
result = add(result, a)
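
For reference, a sketch of the loop with the suggested change applied (using the patch's add helper; note it still performs rhs modular additions, which is the concern above):

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let mut result = U256::ZERO;
    for _i in 0..rhs {
        // Accumulate a into result, rhs times in total.
        result = add(result, a);
    }
    result
}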

Author

Yes, you are right about this. I have changed the current implementation to the suggested change, although I have another implementation in mind that may address the second part of your comment. Do let me know what you think about this:

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let a = a.as_words();
    let rhs_u64: u64 = rhs as u64;
    U256::from_words([
        a[0] * rhs_u64,
        a[1] * rhs_u64,
        a[2] * rhs_u64,
        a[3] * rhs_u64,
    ])
}

for field64.rs and

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let a = a.as_words();
    U256::from_words([
        a[0] * rhs,
        a[1] * rhs,
        a[2] * rhs,
        a[3] * rhs,
        a[4] * rhs,
        a[5] * rhs,
        a[6] * rhs,
        a[7] * rhs,
    ])
}

for field32.rs

Member

This would only work if there are no carries, and except in special cases there will be carries between the limbs. You could take a similar approach but also track the carries as a separate result, then reduce the carries mod the prime and then do a modular addition of the reduced carries and the non-carries (or I suppose you could skip the "reduce the carries mod the prime" step, but then when you're doing the modular addition you'll have to handle the carry term being a U288 (32-bit) or U320 (64-bit))
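
To make the carry issue concrete, here is a small self-contained toy (plain u32/u64 arithmetic, not the patch types) showing how a single limb product overflows into the next limb:

fn main() {
    // With 32-bit limbs, 0xFFFF_FFFF * 2 = 0x1_FFFF_FFFE: the low 32 bits
    // stay in this limb and the high bit must carry into the next limb.
    let a_limb: u32 = 0xFFFF_FFFF;
    let rhs: u32 = 2;
    let wide = (a_limb as u64) * (rhs as u64);
    let low = wide as u32;           // 0xFFFF_FFFE, kept in this limb
    let carry = (wide >> 32) as u32; // 1, must propagate to the next limb
    assert_eq!(low, 0xFFFF_FFFE);
    assert_eq!(carry, 1);
}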

Author

Since we need to keep track of the carries and mod by the prime at the end, I think we can adopt a similar approach to what was implemented for add, replacing adc with mac so that we can track the carries, and then perform the modulo at the end with sub_inner:

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let a_limbs = a.as_limbs();
    let rhs_limb = Limb::from_u32(rhs);
    let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(a_limbs[1], rhs_limb, carry);
    let (w2, carry) = Limb::ZERO.mac(a_limbs[2], rhs_limb, carry);
    let (w3, w4) = Limb::ZERO.mac(a_limbs[3], rhs_limb, carry);

    // Attempt to subtract the modulus, to ensure the result is in the field
    let modulus = MODULUS.0.as_limbs();

    let (result, _) = sub_inner(
        [w0, w1, w2, w3, w4],
        [modulus[0], modulus[1], modulus[2], modulus[3], Limb::ZERO],
    );
    U256::new([result[0], result[1], result[2], result[3]])
}


let modulus = MODULUS.0.as_limbs();

let (w0, carry) = w0.adc(modulus[0].bitand(borrow), Limb::ZERO);
Member

Nit: BitAnd is actually more expensive in the RISC Zero zkVM than multiply, so prefer multiplying by 1 to bitand'ing with 0xfff...fff

Author

I am not too sure about the fix here; can you elaborate a little more on this?

Member

A fix is not required; this is a minor performance detail.

That said, if you want to improve the performance here, then you should know that, counterintuitively, BitAnd is one of the slowest zkVM operations, and it is faster to multiply instead. So, instead of doing a bunch of modulus[i].bitand(borrow) ops, multiply by 1 if borrow is nonzero, or 0 otherwise. I.e., compute let borrowed = !borrow.is_zero() (I think that's the right syntax, anyway) and then replace modulus[i].bitand(borrow) with modulus[i] * borrowed.
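
A sketch of the suggested transformation using plain u32 limbs rather than the crypto_bigint types (the helper names here are illustrative, not the patch's API):

// borrow_mask is 0x0000_0000 or 0xFFFF_FFFF, as produced by sub_inner.
fn add_limb_bitand(w: u32, m: u32, borrow_mask: u32) -> (u32, u32) {
    let sum = (w as u64) + ((m & borrow_mask) as u64); // BitAnd: slow in the zkVM
    (sum as u32, (sum >> 32) as u32)
}

fn add_limb_mul(w: u32, m: u32, borrow_mask: u32) -> (u32, u32) {
    let borrowed = (borrow_mask != 0) as u64;     // 0 or 1
    let sum = (w as u64) + (m as u64) * borrowed; // multiply: cheaper in the zkVM
    (sum as u32, (sum >> 32) as u32)
}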

// Attempt to subtract the modulus, to ensure the result is in the field.
let modulus = MODULUS.0.as_limbs();

let (result, _) = sub_inner(
Member

If this add is run on denormalized inputs, it can result in incorrect outputs. For example, if a = b = 2^256-1 then when you subtract off the modulus in sub_inner you are left with something a little larger than 2^256. This means the final borrow in the sub_inner algorithm will be 0x0000...01, which will then get bitanded with each limb, which will clearly not be the correct answer.

With an honest host this shouldn't come up (because they'd just use normalized values), but a dishonest host could use denormalized values to "prove" a calculation that actually produces this incorrect answer.

Author

I observed that in the k256 implementation we have a normalize() function that checks if the value is fully normalized. What I have done here is to implement the same function and check if a and b are fully normalized first before performing the addition, to avoid the case of a dishonest host. Will adding this additional check help instead?

Member

If you check that these input values are normalized (or know by construction that they are normalized) then the algorithm you have here is safe. So yes adding the normalization check addresses this.
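
A minimal sketch of such a check, assuming U256 exposes an ordering comparison and reusing the patch's MODULUS constant (a production version would likely want this in constant time):

/// Returns true iff x is fully reduced, i.e. x < MODULUS.
fn is_normalized(x: &U256) -> bool {
    x < &MODULUS.0
}

// Guard the inputs before the addition:
// assert!(is_normalized(&a) && is_normalized(&b));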


nategraf commented Aug 2, 2024

Looking this over with @tzerrell this morning, we focused mostly on the field implementation. We found a few issues Tim wrote up here, but overall the field implementation itself looks good.

Reviewing the rest of the code will be a bit difficult as currently factored. It appears to be copied from elsewhere in the repo, the primeorder crate if I'm not mistaken. The fact that it's copied into a new file, and that there are small changes like using .mul instead of *, makes it harder to compare them side-by-side to determine the actual diff. I'll spend this afternoon tracking down the exact changes, so I can review just those rather than trying to find any issues across all the implemented code from scratch.

As a maintenance concern, we try to make sure our patches have a diff that is easy to rebase onto upstream revisions. In the current form, I anticipate merging with upstream will be difficult. Since we are the maintainers, I am happy to refactor until it matches the structure we feel best fits our workflow. In particular, I'm going to open a PR against this one to move some of the copied code back into the original locations and apply the patches there instead. Does that sound good to you?


tsumian commented Aug 5, 2024

@nategraf the affine and projective files are copied from primeorder crate. I have implemented some changes with regards to the comments above. Go ahead and open a PR against this one if you need to for maintainability. Let me know if you require more changes or clarification about the implementation 😄

@tsumian tsumian requested a review from tzerrell August 5, 2024 03:43
@tzerrell tzerrell left a comment

Thanks for the updates!



@tsumian tsumian requested a review from tzerrell August 6, 2024 04:07
// Attempt to subtract the modulus, to ensure the result is in the field.
let modulus = MODULUS.0.as_limbs();

let (result, _) = sub_inner(
Member

I don't think this step works correctly for a multiply, even though it worked for add. This sub_inner call subtracts off the modulus, and then the internal logic of sub_inner uses the highest-order limb as a bitand mask with the modulus and adds that back on. In the case of add, this works well -- the highest order limb can be either 0 or 1 after the addition, and once the modulus is subtracted off it can only be 0 or -1 (in the normalized inputs case, anyway) -- and -1 is treated as 0xffffffff. But in this case, the highest order limb can be almost anything (for instance, consider what happens if you multiply modulus - 1 by 2^30). Then the sub_inner logic will be bitanding the modulus with some arbitrary u32, and adding that back on. This will produce a result not at all related to the intended multiply.
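
To spell out the range argument: for reduced inputs a, b < MODULUS < 2^256, the sum satisfies a + b < 2^257, so the overflow limb is 0 or 1, and after one subtraction of the modulus the borrow limb is 0 or -1, i.e. a valid all-ones mask. But a product a * rhs with rhs a u32 can reach (2^256 - 1)(2^32 - 1) < 2^288, so the overflow limb can be any 32-bit value, and the mask logic inside sub_inner no longer applies.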

Author

I see what you mean here. Thanks for pointing out the difference! I have attempted to change it into the implementation that you suggested previously, with the help of the U256::const_rem function:

  1. First track the carries
  2. Reduce the carries by modding with prime
  3. Modular addition of reduced carries with non-carries
pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let a_limbs = a.as_limbs();
    let rhs_limb = Limb::from_u32(rhs);
    let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(a_limbs[1], rhs_limb, carry);
    let (w2, carry) = Limb::ZERO.mac(a_limbs[2], rhs_limb, carry);
    let (w3, w4) = Limb::ZERO.mac(a_limbs[3], rhs_limb, carry);

    // Reduce the carry mod prime
    let carry = U256::from(w4);
    let (reduced_carry, _) = carry.const_rem(&MODULUS.0);

    // Modular addition of non-carry and reduced carry
    let non_carries = U256::new([w0, w1, w2, w3]);
    add(non_carries, reduced_carry)
}

@tzerrell tzerrell Aug 7, 2024

Hmm, I think this doesn't quite work either -- w4 is a 64-bit limb here (or 32-bit limb in the other case), and so setting carry = U256::from(w4) gets you something already smaller than MODULUS (and so the reduce is a no-op) but not in the right way -- it's reduced by 2^256, not the modulus.

I have two approaches in mind that I think could work (there are probably more):

  1. Create carry as a larger type (U512?), set it to w4 << 256, and reduce
  2. Note that since 2^256 is slightly larger than MODULUS, you know that 2^256 reduced by MODULUS is 2^256 - MODULUS. Thus, w4 << 256 mod MODULUS is just w4 * (2^256 - MODULUS). This won't even need to be reduced, as both w4 and 2^256 - MODULUS fit within a limb, and thus their product fits within 2 limbs i.e. is less than the modulus. (You do still need to do a modular addition to combine with the non-carry result.)
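
A quick sanity check of the identity behind approach 2, using small stand-in numbers (toy code, not the patch):

fn main() {
    let base: u64 = 1 << 16;   // stand-in for 2^256
    let modulus: u64 = 65_521; // stand-in prime just below the base
    let w4: u64 = 12_345;      // stand-in overflow limb
    // w4 * base ≡ w4 * (base - modulus) (mod modulus)
    let lhs = (w4 % modulus) * (base % modulus) % modulus;
    let rhs = (w4 * (base - modulus)) % modulus;
    assert_eq!(lhs, rhs);
}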

Author

I am looking more at the second approach because I am not too sure how we can define a larger type for the carry if we were to continue using the mac function.

However, I have some clarifications regarding the second approach. The MODULUS of p256 = 2^{224}(2^{32} − 1) + 2^{192} + 2^{96} − 1, and from my calculations I believe 2^256 - MODULUS is equal to 0x00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001 in hexadecimal, which, when converted to limbs, will not fit within a single limb.

From your second approach, I was thinking of implementing something like this in code. After tracking the carries using mac, we can first get 2^256 - MODULUS as a hex value and convert it into limbs:

// Define 2^256 - MODULUS
let subtracted_result_str: &str =
    "00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001";

let subtracted_result = U256::from_be_hex(subtracted_result_str);
let subtracted_result_limb = subtracted_result.as_limbs();

This could potentially be done in some other way, like using Limb::u32 or simply not using the hex value; however, I am not too sure how to achieve that yet, so if you have any suggestions do let me know.
After getting 2^256 - MODULUS, and assuming this fits into a single limb, we can probably just take the product w4 * (2^256 - MODULUS)

let (carry1, carry2) = Limb::ZERO.mac(w4, subtracted_result_limb[0], Limb::ZERO);

which should fit within 2 limbs, carry1 and carry2. Then, to do a modular addition with the non-carries, we can:

let carries = U256::new([Limb::ZERO, Limb::ZERO, carry2, carry1]);
let non_carries = U256::new([w0, w1, w2, w3]);
add(carries, non_carries)

So the issue here I guess is still how does the value (2^256 - MODULUS) fit within a single limb. Also, even if it fits within a single limb for u64, will it fit within a single limb for u32? Perhaps my calculations were wrong, and I really hope you could shed more light on this and provide some examples of how you would go about coding this. Thanks for your help so far! 😄

Member

> the issue here I guess is still how does the value (2^256 - MODULUS) fit within a single limb.

Oh, good call -- this was my mistake; I thought 2^256 - MODULUS was smaller than it actually is. It looks like 2^256 - MODULUS actually needs 7 limbs (of 32 bits) to fit. For the 32-bit case, this will work out pretty well; for the 64-bit case, it's going to be more awkward.

For the 32-bit case, we can compute the carry component using the existing multiplication algorithm but ignoring the final carry, because 1 limb times 7 limbs gives us at most 8 limbs. It might be necessary to split off a mul_inner that we use for this purpose, but fundamentally, we can reuse the existing limbed multiply without needing to worry about the uppermost carry, then do a modular 256-bit addition to combine this with the non-carry result.

For the 64-bit case it's messier. I suppose you can do the same approach as the 32-bit case, but with two rounds of the mul_inner -- the very first time through you have an overflow carry of w4 that might take the full 64 bits; then do the approach of the 32-bit case, but now keep track of the w4 carry because it might have up to 32 bits in it; now do the approach of the 32-bit case again, but because the incoming carry can't be more than 32 bits and the 2^256 - MODULUS fits in 224 bits, their product will fit in 256 bits and the w4 could be dropped. This will need some careful thought for how to combine all three parts (uncarried, partially carried, fully carried), but I think it should all be possible to work out. Definitely double-check my logic here though, this 64-bit case is messy.

Author

So for the implementation, the 32-bit case is pretty straightforward, since the defined mul_inner will just return a reduced carry that fits in at most 8 limbs using the same multiplication logic we have:

fn mul_inner(a: U256, b: Limb) -> U256 {
    let a_limbs = a.as_limbs();
    let (w0, carry) = Limb::ZERO.mac(a_limbs[0], b, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(a_limbs[1], b, carry);
    let (w2, carry) = Limb::ZERO.mac(a_limbs[2], b, carry);
    let (w3, carry) = Limb::ZERO.mac(a_limbs[3], b, carry);
    let (w4, carry) = Limb::ZERO.mac(a_limbs[4], b, carry);
    let (w5, carry) = Limb::ZERO.mac(a_limbs[5], b, carry);
    let (w6, carry) = Limb::ZERO.mac(a_limbs[6], b, carry);
    // We can ignore the last carry
    let (w7, _) = Limb::ZERO.mac(a_limbs[7], b, carry);

    U256::new([w0, w1, w2, w3, w4, w5, w6, w7])
}

and we call mul_inner to get a reduced carry from our initial mul_single function

// Define 2^256 - MODULUS
let subtracted_result_str: &str =
    "00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001";

let subtracted_result = U256::from_be_hex(subtracted_result_str);
// Calculate w8 << 256 mod MODULUS = w8 * (2^256 - MODULUS)
let reduced_carry = mul_inner(subtracted_result, w8);

// Modular addition of non-carry and reduced carry
let non_carries = U256::new([w0, w1, w2, w3, w4, w5, w6, w7]);
add(non_carries, reduced_carry)
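
Assembling the 32-bit pieces, the full mul_single would then look something like this (a sketch combining the snippets above; the 8-limb mac chain producing w0..w8 is assumed rather than quoted from the patch):

pub(super) fn mul_single(a: U256, rhs: u32) -> U256 {
    let a_limbs = a.as_limbs();
    let rhs_limb = Limb::from_u32(rhs);
    let (w0, carry) = Limb::ZERO.mac(a_limbs[0], rhs_limb, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(a_limbs[1], rhs_limb, carry);
    let (w2, carry) = Limb::ZERO.mac(a_limbs[2], rhs_limb, carry);
    let (w3, carry) = Limb::ZERO.mac(a_limbs[3], rhs_limb, carry);
    let (w4, carry) = Limb::ZERO.mac(a_limbs[4], rhs_limb, carry);
    let (w5, carry) = Limb::ZERO.mac(a_limbs[5], rhs_limb, carry);
    let (w6, carry) = Limb::ZERO.mac(a_limbs[6], rhs_limb, carry);
    let (w7, w8) = Limb::ZERO.mac(a_limbs[7], rhs_limb, carry);

    // Reduce the overflow limb: w8 * 2^256 ≡ w8 * (2^256 - MODULUS) (mod MODULUS)
    let subtracted_result =
        U256::from_be_hex("00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001");
    let reduced_carry = mul_inner(subtracted_result, w8);

    // Modular addition of non-carry and reduced carry
    let non_carries = U256::new([w0, w1, w2, w3, w4, w5, w6, w7]);
    add(non_carries, reduced_carry)
}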

For the 64-bit case, mul_inner is defined differently. Once we get the initial carry w4 (which could take up to 64 bits) in mul_single, we calculate w4 * (2^256 - MODULUS), but this operation could result in an overflow of up to 32 bits; hence in mul_inner we need to track this overflow, which is again named w4. Using this overflow of at most 32 bits, we calculate w4 * (2^256 - MODULUS) again, but this time the result should fit in 4 limbs [c0..c3], since 32 + 224 = 256 bits at most, hence ignoring the last carry. Finally, we perform a modular addition of the "inner" carries and non_carries to get the reduced_carry needed for the initial mul_single. This is how I defined the 64-bit mul_inner:

fn mul_inner(a: U256, b: Limb) -> U256 {
    let a_limbs = a.as_limbs();
    let (w0, carry) = Limb::ZERO.mac(a_limbs[0], b, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(a_limbs[1], b, carry);
    let (w2, carry) = Limb::ZERO.mac(a_limbs[2], b, carry);
    let (w3, w4) = Limb::ZERO.mac(a_limbs[3], b, carry);
    let non_carries = U256::new([w0, w1, w2, w3]);

    let (c0, carry) = Limb::ZERO.mac(a_limbs[0], w4, Limb::ZERO);
    let (c1, carry) = Limb::ZERO.mac(a_limbs[1], w4, carry);
    let (c2, carry) = Limb::ZERO.mac(a_limbs[2], w4, carry);
    let (c3, _) = Limb::ZERO.mac(a_limbs[3], w4, carry);
    let reduced_carry = U256::new([c0, c1, c2, c3]);

    add(non_carries, reduced_carry)
}

mul_single in 64 bits is defined similarly to that of 32 bits. Mathematically, the logic and the bits look sound to me! Although I am not too certain about the part on combining uncarried, partially carried, and fully carried that you mentioned, I think what it corresponds to here is:

  • uncarried = non_carries in initial mul_single
  • partially carried = non_carries in mul_inner
  • fully carried = reduced_carries in mul_inner

Member

Sorry for the delayed response, I had an inopportunely timed case of covid. Doing better now though!

The 32-bit case makes sense to me. I don't think the 64-bit case is right though -- it looks like you are multiplying the carry portion by a again, and I think you need to be multiplying it by 0x00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001.

For the 3 parts to add back together, I think they are non_carries, reduced_carry, and _ from mul_inner, except _ is a carry limb and needs to be multiplied by 0x00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001 again to reduce it into 256 bits.
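
For concreteness, a sketch of the 64-bit mul_inner with that correction applied -- the carry is multiplied by 2^256 - MODULUS rather than by a, and the residual overflow gets the same treatment once more (names follow the patch code above; the exact plumbing is an assumption):

fn mul_inner(b: Limb) -> U256 {
    // 2^256 - MODULUS for p256; it fits in 224 bits.
    let r = U256::from_be_hex("00000000FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001");
    let r_limbs = r.as_limbs();

    // b * (2^256 - MODULUS): up to 64 + 224 = 288 bits, so track the overflow w4.
    let (w0, carry) = Limb::ZERO.mac(r_limbs[0], b, Limb::ZERO);
    let (w1, carry) = Limb::ZERO.mac(r_limbs[1], b, carry);
    let (w2, carry) = Limb::ZERO.mac(r_limbs[2], b, carry);
    let (w3, w4) = Limb::ZERO.mac(r_limbs[3], b, carry);
    let partial = U256::new([w0, w1, w2, w3]);

    // w4 is at most 32 bits, so w4 * (2^256 - MODULUS) fits in 256 bits
    // and its own final carry is zero.
    let (c0, carry) = Limb::ZERO.mac(r_limbs[0], w4, Limb::ZERO);
    let (c1, carry) = Limb::ZERO.mac(r_limbs[1], w4, carry);
    let (c2, carry) = Limb::ZERO.mac(r_limbs[2], w4, carry);
    let (c3, _) = Limb::ZERO.mac(r_limbs[3], w4, carry);
    let fully_reduced = U256::new([c0, c1, c2, c3]);

    // mul_single then computes add(non_carries, mul_inner(w4)).
    add(partial, fully_reduced)
}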
