diff --git a/synedrion/src/cggmp21/key_refresh.rs b/synedrion/src/cggmp21/key_refresh.rs index fe0cff3..66c9b8b 100644 --- a/synedrion/src/cggmp21/key_refresh.rs +++ b/synedrion/src/cggmp21/key_refresh.rs @@ -32,7 +32,7 @@ use crate::{ tools::{ bitvec::BitVec, hashing::{Chain, FofHasher, HashOutput, XofHasher}, - DowncastMap, GetRound, SafeGet, Secret, Without, + verify_that, DeserializeAll, DowncastMap, GetRound, SafeGet, Secret, Without, }, }; @@ -51,9 +51,9 @@ impl Protocol for KeyRefreshProtocol { message: &DirectMessage, ) -> Result<(), MessageValidationError> { match round_id { - r if r == &RoundId::new(1) => message.verify_is_some(), - r if r == &RoundId::new(2) => message.verify_is_some(), - r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + r if r == &1 => message.verify_is_some(), + r if r == &2 => message.verify_is_some(), + r if r == &3 => message.verify_is_not::>(deserializer), _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), } } @@ -64,9 +64,9 @@ impl Protocol for KeyRefreshProtocol { message: &EchoBroadcast, ) -> Result<(), MessageValidationError> { match round_id { - r if r == &RoundId::new(1) => message.verify_is_not::(deserializer), - r if r == &RoundId::new(2) => message.verify_is_some(), - r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + r if r == &1 => message.verify_is_not::(deserializer), + r if r == &2 => message.verify_is_some(), + r if r == &3 => message.verify_is_not::>(deserializer), _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), } } @@ -77,9 +77,9 @@ impl Protocol for KeyRefreshProtocol { message: &NormalBroadcast, ) -> Result<(), MessageValidationError> { match round_id { - r if r == &RoundId::new(1) => message.verify_is_some(), - r if r == &RoundId::new(2) => message.verify_is_not::>(deserializer), - r if r == &RoundId::new(3) => message.verify_is_not::>(deserializer), + r if r == &1 => message.verify_is_some(), + r if r == &2 => message.verify_is_not::>(deserializer), + r if r == &3 => message.verify_is_not::>(deserializer), _ => Err(MessageValidationError::InvalidEvidence("Invalid round number".into())), } } @@ -132,15 +132,47 @@ enum KeyRefreshErrorEnum { /// Round2: P_prm verification failed R2PrmFailed, /// Round3: secret share change does not match the public commitment - R3ShareChangeMismatch, + R3ShareChangeMismatch { + /// The index $i$ of the node that produced the evidence. + reported_by: I, + /// $y_{i,j}$, where where $j$ is the index of the guilty party. + y: Scalar, + }, /// Round3: P_mod verification failed R3ModFailed, /// Round3: P_fac verification failed - R3FacFailed, + R3FacFailed { + /// The index $i$ of the node that produced the evidence. + reported_by: I, + }, /// Round3: Wrong IDs in Schnorr proofs map R3WrongIdsHatPsi, /// Round3: P_sch verification failed - R3SchFailed(I), + R3SchFailed { + /// The index $k$ for which the verification of $П^{sch}_{j,k}$ failed + /// (where $j$ is the index of the guilty party). + failed_for: I, + }, +} + +/// Reconstruct `rid` from echoed messages +fn reconstruct_rid( + deserializer: &Deserializer, + previous_messages: &BTreeMap, + combined_echos: &BTreeMap>, +) -> Result { + let r2_messages = combined_echos + .get_round(2)? + .deserialize_all::>(deserializer)?; + let r2_echo = previous_messages + .get_round(2)? + .echo_broadcast + .deserialize::>(deserializer)?; + let mut rid = r2_echo.rid_part; + for message in r2_messages.values() { + rid ^= &message.rid_part; + } + Ok(rid) } impl ProtocolError for KeyRefreshError { @@ -150,13 +182,51 @@ impl ProtocolError for KeyRefreshError { match self.error { KeyRefreshErrorEnum::R2HashMismatch => RequiredMessages::new( RequiredMessageParts::normal_broadcast_only(), - Some([(RoundId::new(1), RequiredMessageParts::echo_broadcast_only())].into()), + Some([(1.into(), RequiredMessageParts::echo_broadcast_only())].into()), None, ), KeyRefreshErrorEnum::R2WrongIdsX => { RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) } - _ => unimplemented!(), + KeyRefreshErrorEnum::R2WrongIdsY => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R2WrongIdsA => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R2PaillierModulusTooSmall => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R2RPModulusTooSmall => { + RequiredMessages::new(RequiredMessageParts::echo_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R2NonZeroSumOfChanges => { + RequiredMessages::new(RequiredMessageParts::normal_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R2PrmFailed => RequiredMessages::new(RequiredMessageParts::all(), None, None), + KeyRefreshErrorEnum::R3ShareChangeMismatch { .. } => RequiredMessages::new( + RequiredMessageParts::direct_message_only(), + Some([(2.into(), RequiredMessageParts::all())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3ModFailed => RequiredMessages::new( + RequiredMessageParts::normal_broadcast_only(), + Some([(2.into(), RequiredMessageParts::all())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3FacFailed { .. } => RequiredMessages::new( + RequiredMessageParts::direct_message_only(), + Some([(2.into(), RequiredMessageParts::all())].into()), + Some([2.into()].into()), + ), + KeyRefreshErrorEnum::R3WrongIdsHatPsi => { + RequiredMessages::new(RequiredMessageParts::echo_broadcast_only(), None, None) + } + KeyRefreshErrorEnum::R3SchFailed { .. } => RequiredMessages::new( + RequiredMessageParts::echo_broadcast_only(), + Some([(2.into(), RequiredMessageParts::all())].into()), + Some([2.into()].into()), + ), } } @@ -168,14 +238,14 @@ impl ProtocolError for KeyRefreshError { associated_data: &Self::AssociatedData, message: ProtocolMessage, previous_messages: BTreeMap, - _combined_echos: BTreeMap>, + combined_echos: BTreeMap>, ) -> Result<(), ProtocolValidationError> { let sid_hash = FofHasher::new_with_dst(b"SID") .chain_type::

() .chain(&shared_randomness) .finalize(); - match self.error { + match &self.error { KeyRefreshErrorEnum::R2HashMismatch => { let r1_message = previous_messages .get_round(1)? @@ -184,25 +254,156 @@ impl ProtocolError for KeyRefreshError { let r2_message = message .normal_broadcast .deserialize::>(deserializer)?; - if r2_message.hash(&sid_hash, guilty_party) != r1_message.cap_v { - Ok(()) - } else { - Err(ProtocolValidationError::InvalidEvidence( - "The received hash is valid".into(), - )) - } + verify_that(r2_message.hash(&sid_hash, guilty_party) != r1_message.cap_v) } KeyRefreshErrorEnum::R2WrongIdsX => { let r2_message = message .normal_broadcast .deserialize::>(deserializer)?; - if &r2_message.cap_xs.keys().cloned().collect::>() != associated_data { - Ok(()) - } else { - Err(ProtocolValidationError::InvalidEvidence("The IDs are correct".into())) + verify_that(&r2_message.cap_xs.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2WrongIdsY => { + let r2_message = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that(&r2_message.cap_ys.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2WrongIdsA => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(&r2_message.cap_as.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R2PaillierModulusTooSmall => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that( + r2_message.paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2, + ) + } + KeyRefreshErrorEnum::R2RPModulusTooSmall => { + let r2_message = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that( + r2_message.rp_params.modulus().bits_vartime() < ::MODULUS_BITS - 2, + ) + } + KeyRefreshErrorEnum::R2NonZeroSumOfChanges => { + let r2_message = message + .normal_broadcast + .deserialize::>(deserializer)?; + verify_that(r2_message.cap_xs.values().sum::() != Point::IDENTITY) + } + KeyRefreshErrorEnum::R2PrmFailed => { + let r2_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + let r2_bc = message + .normal_broadcast + .deserialize::>(deserializer)?; + let aux = (&sid_hash, guilty_party); + let rp_params = r2_eb.rp_params.to_precomputed(); + verify_that(!r2_bc.psi.verify(&rp_params, &aux)) + } + KeyRefreshErrorEnum::R3ShareChangeMismatch { reported_by, y } => { + // Check that `y` attached to the evidence is correct + // (that is, can be verified against something signed by `guilty_party`). + // It is `y_{i,j}` where `i == reported_by` and `j == guilty_party` + let r2_message_i = combined_echos + .get_round(2)? + .try_get("combined echos for Round 2", reported_by)? + .deserialize::>(deserializer)?; + let cap_y_ij = r2_message_i.cap_ys.try_get("public Elgamal values", guilty_party)?; + if &y.mul_by_generator() != cap_y_ij { + return Err(ProtocolValidationError::InvalidEvidence( + "The provided `y` is invalid".into(), + )); } + + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + + let r2_echo = previous_messages + .get_round(2)? + .echo_broadcast + .deserialize::>(deserializer)?; + let cap_y_ji = r2_echo.cap_ys.try_get("public Elgamal values", reported_by)?; + let mut reader = XofHasher::new_with_dst(b"KeyRefresh Round3") + .chain(&sid_hash) + .chain(&rid) + .chain(guilty_party) + .chain(&(cap_y_ji * y)) + .finalize_to_reader(); + let rho = Scalar::from_xof_reader(&mut reader); + + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_message = message + .direct_message + .deserialize::>(deserializer)?; + + let x = r3_message.cap_c - rho; + let cap_x_ji = r2_bc.cap_xs.try_get("public key share changes", reported_by)?; + verify_that(&x.mul_by_generator() != cap_x_ji) + } + KeyRefreshErrorEnum::R3ModFailed => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_bc = message + .normal_broadcast + .deserialize::>(deserializer)?; + let paillier_pk = r2_bc.paillier_pk.into_precomputed(); + verify_that(!r3_bc.psi_prime.verify(&paillier_pk, &aux)) + } + KeyRefreshErrorEnum::R3FacFailed { reported_by } => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + + let r2_eb = combined_echos + .get_round(2)? + .try_get("combined echos for Round 2", reported_by)? + .deserialize::>(deserializer)?; + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_dm = message + .direct_message + .deserialize::>(deserializer)?; + let paillier_pk = r2_bc.paillier_pk.into_precomputed(); + let rp_params = r2_eb.rp_params.to_precomputed(); + verify_that(!r3_dm.psi.verify(&paillier_pk, &rp_params, &aux)) + } + KeyRefreshErrorEnum::R3WrongIdsHatPsi => { + let r3_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + verify_that(&r3_eb.hat_psis.keys().cloned().collect::>() != associated_data) + } + KeyRefreshErrorEnum::R3SchFailed { failed_for } => { + let rid = reconstruct_rid::(deserializer, &previous_messages, &combined_echos)?; + let aux = (&sid_hash, guilty_party, &rid); + + let r2_bc = previous_messages + .get_round(2)? + .normal_broadcast + .deserialize::>(deserializer)?; + let r3_eb = message + .echo_broadcast + .deserialize::>(deserializer)?; + + let cap_a = r2_bc.cap_as.try_get("Schnorr commitments", failed_for)?; + let cap_x = r2_bc.cap_xs.try_get("public share changes", failed_for)?; + let hat_psi = r3_eb.hat_psis.try_get("Schnorr proofs", failed_for)?; + verify_that(!hat_psi.verify(cap_a, cap_x, &aux)) } - _ => unimplemented!(), } } } @@ -296,17 +497,20 @@ impl EntryPoint for KeyRefresh { let u = BitVec::random(rng, P::SECURITY_PARAMETER); // Note: typo in the paper, $V$ hashes in $B_i$ which is not present in the '24 version of the paper. - let r2_broadcast = Round2Broadcast { + let r2_normal_broadcast = Round2Broadcast { cap_xs, - cap_ys, cap_as, paillier_pk: paillier_pk.clone(), - rp_params: rp_params.to_wire(), psi, - rid_part, u, }; + let r2_echo_broadcast = Round2EchoBroadcast { + rp_params: rp_params.to_wire(), + cap_ys, + rid_part, + }; + let context = Context { paillier_sk: paillier_sk.into_precomputed(), rp_params, @@ -319,7 +523,11 @@ impl EntryPoint for KeyRefresh { sid_hash, }; - let round = Round1 { context, r2_broadcast }; + let round = Round1 { + context, + r2_normal_broadcast, + r2_echo_broadcast, + }; Ok(BoxedRound::new_dynamic(round)) } @@ -341,7 +549,8 @@ struct Context { #[derive(Debug)] struct Round1 { context: Context, - r2_broadcast: Round2Broadcast, + r2_normal_broadcast: Round2Broadcast, + r2_echo_broadcast: Round2EchoBroadcast, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -378,7 +587,9 @@ impl Round for Round1 { serializer: &Serializer, ) -> Result { let message = Round1EchoBroadcast { - cap_v: self.r2_broadcast.hash(&self.context.sid_hash, &self.context.my_id), + cap_v: self + .r2_normal_broadcast + .hash(&self.context.sid_hash, &self.context.my_id), }; EchoBroadcast::new(serializer, message) } @@ -410,7 +621,8 @@ impl Round for Round1 { let others_cap_v = payloads.into_iter().map(|(id, payload)| (id, payload.cap_v)).collect(); let next_round = Round2 { context: self.context, - r2_broadcast: self.r2_broadcast, + r2_echo_broadcast: self.r2_echo_broadcast, + r2_normal_broadcast: self.r2_normal_broadcast, others_cap_v, }; Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(next_round))) @@ -420,7 +632,8 @@ impl Round for Round1 { #[derive(Debug)] struct Round2 { context: Context, - r2_broadcast: Round2Broadcast, + r2_normal_broadcast: Round2Broadcast, + r2_echo_broadcast: Round2EchoBroadcast, others_cap_v: BTreeMap, } @@ -434,14 +647,24 @@ struct Round2 { struct Round2Broadcast { cap_xs: BTreeMap, // $X_{i,j}$ where $i$ is this party's index cap_as: BTreeMap, // $A_{i,j}$ where $i$ is this party's index - cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index paillier_pk: PublicKeyPaillierWire, // $N_i$ - rp_params: RPParamsWire, // $\hat{N}_i$, $s_i$, and $t_i$ psi: PrmProof

, - rid_part: BitVec, u: BitVec, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound(serialize = " + I: Serialize, +"))] +#[serde(bound(deserialize = " + I: for<'x> Deserialize<'x>, +"))] +struct Round2EchoBroadcast { + rp_params: RPParamsWire, // $\hat{N}_i$, $s_i$, and $t_i$ + cap_ys: BTreeMap, // $Y_{i,j}$ where $i$ is this party's index + rid_part: BitVec, +} + impl Round2Broadcast { fn hash(&self, sid_hash: &HashOutput, id: &I) -> HashOutput { FofHasher::new_with_dst(b"Auxiliary") @@ -486,7 +709,15 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, ) -> Result { - NormalBroadcast::new(serializer, self.r2_broadcast.clone()) + NormalBroadcast::new(serializer, self.r2_normal_broadcast.clone()) + } + + fn make_echo_broadcast( + &self, + _rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + EchoBroadcast::new(serializer, self.r2_echo_broadcast.clone()) } fn receive_message( @@ -495,8 +726,10 @@ impl Round for Round2 { from: &I, message: ProtocolMessage, ) -> Result> { - message.echo_broadcast.assert_is_none()?; message.direct_message.assert_is_none()?; + let echo_broadcast = message + .echo_broadcast + .deserialize::>(deserializer)?; let normal_broadcast = message .normal_broadcast .deserialize::>(deserializer)?; @@ -515,7 +748,7 @@ impl Round for Round2 { ))); } - if normal_broadcast.cap_ys.keys().cloned().collect::>() != self.context.all_ids { + if echo_broadcast.cap_ys.keys().cloned().collect::>() != self.context.all_ids { return Err(ReceiveError::protocol(KeyRefreshError::new( KeyRefreshErrorEnum::R2WrongIdsY, ))); @@ -528,7 +761,7 @@ impl Round for Round2 { } let paillier_pk = normal_broadcast.paillier_pk.clone().into_precomputed(); - let rp_params = normal_broadcast.rp_params.to_precomputed(); + let rp_params = echo_broadcast.rp_params.to_precomputed(); if paillier_pk.modulus().bits_vartime() < ::MODULUS_BITS - 2 { return Err(ReceiveError::protocol(KeyRefreshError::new( @@ -558,10 +791,10 @@ impl Round for Round2 { let payload = Round2Payload:: { cap_xs: normal_broadcast.cap_xs, cap_as: normal_broadcast.cap_as, - cap_ys: normal_broadcast.cap_ys, + cap_ys: echo_broadcast.cap_ys, paillier_pk: normal_broadcast.paillier_pk.into_precomputed(), - rp_params: normal_broadcast.rp_params.to_precomputed(), - rid_part: normal_broadcast.rid_part, + rp_params: echo_broadcast.rp_params.to_precomputed(), + rid_part: echo_broadcast.rid_part, }; Ok(Payload::new(payload)) @@ -575,19 +808,19 @@ impl Round for Round2 { ) -> Result, LocalError> { let mut payloads = payloads.downcast_all::>()?; - let mut rid = self.r2_broadcast.rid_part.clone(); + let mut rid = self.r2_echo_broadcast.rid_part.clone(); for payload in payloads.values() { rid ^= &payload.rid_part; } // Add in the payload with this node's info, for the sake of uniformity let my_payload = Round2Payload:: { - cap_xs: self.r2_broadcast.cap_xs, - cap_as: self.r2_broadcast.cap_as, - cap_ys: self.r2_broadcast.cap_ys, - paillier_pk: self.r2_broadcast.paillier_pk.into_precomputed(), - rp_params: self.r2_broadcast.rp_params.to_precomputed(), - rid_part: self.r2_broadcast.rid_part, + cap_xs: self.r2_normal_broadcast.cap_xs, + cap_as: self.r2_normal_broadcast.cap_as, + cap_ys: self.r2_echo_broadcast.cap_ys, + paillier_pk: self.r2_normal_broadcast.paillier_pk.into_precomputed(), + rp_params: self.r2_echo_broadcast.rp_params.to_precomputed(), + rid_part: self.r2_echo_broadcast.rid_part, }; payloads.insert(self.context.my_id.clone(), my_payload); @@ -787,9 +1020,11 @@ impl Round for Round3 { let x = Secret::init_with(|| direct_message.cap_c - rho); let my_cap_x = r2_payload.cap_xs.safe_get("public share changes", my_id)?; if &x.mul_by_generator() != my_cap_x { - // TODO: can we put all the necessary info in the proof? return Err(ReceiveError::protocol(KeyRefreshError::new( - KeyRefreshErrorEnum::R3ShareChangeMismatch, + KeyRefreshErrorEnum::R3ShareChangeMismatch { + reported_by: my_id.clone(), + y: *y.expose_secret(), + }, ))); } @@ -805,7 +1040,9 @@ impl Round for Round3 { .verify(&r2_payload.paillier_pk, &self.context.rp_params, &aux) { return Err(ReceiveError::protocol(KeyRefreshError::new( - KeyRefreshErrorEnum::R3FacFailed, + KeyRefreshErrorEnum::R3FacFailed { + reported_by: my_id.clone(), + }, ))); } @@ -820,7 +1057,7 @@ impl Round for Round3 { let cap_x = r2_payload.cap_xs.safe_get("Public share changes", id)?; if !hat_psi.verify(cap_a, cap_x, &aux) { return Err(ReceiveError::protocol(KeyRefreshError::new( - KeyRefreshErrorEnum::R3SchFailed(id.clone()), + KeyRefreshErrorEnum::R3SchFailed { failed_for: id.clone() }, ))); } } diff --git a/synedrion/src/paillier/keys.rs b/synedrion/src/paillier/keys.rs index 4520a93..743f819 100644 --- a/synedrion/src/paillier/keys.rs +++ b/synedrion/src/paillier/keys.rs @@ -293,6 +293,10 @@ impl PublicKeyPaillierWire

{ } } + pub fn modulus(&self) -> &P::Uint { + self.modulus.modulus() + } + pub fn into_precomputed(self) -> PublicKeyPaillier

{ PublicKeyPaillier::new(self.modulus.into_precomputed()) } diff --git a/synedrion/src/paillier/ring_pedersen.rs b/synedrion/src/paillier/ring_pedersen.rs index 5369157..ad6654e 100644 --- a/synedrion/src/paillier/ring_pedersen.rs +++ b/synedrion/src/paillier/ring_pedersen.rs @@ -152,6 +152,10 @@ pub(crate) struct RPParamsWire { } impl RPParamsWire

{ + pub fn modulus(&self) -> &P::Uint { + self.modulus.modulus() + } + pub fn to_precomputed(&self) -> RPParams

{ let modulus = self.modulus.clone().into_precomputed(); let base_randomizer = self.base_randomizer.to_montgomery(modulus.monty_params_mod_n());