diff --git a/reference/tests.py b/reference/tests.py index 82714c9..855f2c5 100644 --- a/reference/tests.py +++ b/reference/tests.py @@ -36,13 +36,13 @@ def simulate_simplpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: pmsgs = [ret[1] for ret in prets] cmsg, cout, ceq = simplpedpop.coordinator_step(pmsgs, t, n) - pre_finalize_outputs = [(cout, ceq)] + pre_finalize_rets = [(cout, ceq)] for i in range(n): shares_sum = Scalar.sum(*([pret[2][i] for pret in prets])) - pre_finalize_outputs += [ + pre_finalize_rets += [ simplpedpop.participant_pre_finalize(prets[i][0], cmsg, shares_sum) ] - return pre_finalize_outputs + return pre_finalize_rets def encpedpop_keys(seed: bytes) -> Tuple[bytes, bytes]: @@ -67,12 +67,12 @@ def simulate_encpedpop(seeds, t) -> List[Tuple[simplpedpop.DKGOutput, bytes]]: pstates = [pstate for (pstate, _) in enc_prets1] cmsg, cout, ceq, enc_shares_sums = encpedpop.coordinator_step(pmsgs, t, enckeys) - pre_finalize_outputs = [(cout, ceq)] + pre_finalize_rets = [(cout, ceq)] for i in range(n): - pre_finalize_outputs += [ + pre_finalize_rets += [ encpedpop.participant_pre_finalize(pstates[i], cmsg, enc_shares_sums[i]) ] - return pre_finalize_outputs + return pre_finalize_rets def simulate_chilldkg( @@ -91,13 +91,13 @@ def simulate_chilldkg( for i in range(n): prets1 += [chilldkg.participant_step1(seeds[i], params)] - pstate1s = [pret[0] for pret in prets1] + pstates1 = [pret[0] for pret in prets1] pmsgs = [pret[1] for pret in prets1] cstate, cmsg = chilldkg.coordinator_step(pmsgs, params) prets2 = [] for i in range(n): - prets2 += [chilldkg.participant_step2(seeds[i], pstate1s[i], cmsg)] + prets2 += [chilldkg.participant_step2(seeds[i], pstates1[i], cmsg)] cmsg2, cout, crec = chilldkg.coordinator_finalize( cstate, [pret[1] for pret in prets2] @@ -136,15 +136,7 @@ async def main(): return await asyncio.gather(*coroutines) outputs = asyncio.run(main()) - - # Check coordinator output - return [ - ( - simplpedpop.DKGOutput(out[0][0], out[0][1], out[0][2]), - chilldkg.RecoveryData(out[1]), - ) - for out in outputs - ] + return outputs def derive_interpolating_value(L, x_i): @@ -197,7 +189,6 @@ def test_correctness_dkg_output(t, n, dkg_outputs: List[simplpedpop.DKGOutput]): assert secshares[0] is None # Check that each secshare matches the corresponding pubshare - # (secshares[1:]) for i in range(1, n + 1): assert secshares[i] * G == pubshares[0][i - 1]