Skip to content

Commit

Permalink
#519 improve set integration bench (#527)
Browse files Browse the repository at this point in the history
* #519 improve set integration bench

closes #519 issue

* fix + improved perf

* fix error name

---------

Co-authored-by: lanaivina <[email protected]>
  • Loading branch information
StringNick and lana-shanghai authored Jul 1, 2024
1 parent 28e4fdc commit 36e27d5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 50 deletions.
4 changes: 2 additions & 2 deletions build.zig.zon
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
.hash = "1220ab73fb7cc11b2308edc3364988e05efcddbcac31b707f55e6216d1b9c0da13f1",
},
.starknet = .{
.url = "https://github.com/StringNick/starknet-zig/archive/8cfb4286ffda4ad2781647c3d96b2aec8ccfeb32.zip",
.hash = "122026eaa24834fd2e2cc7e8b6c4eefb03dda08158a2844615f189758fa24d32fc44",
.url = "https://github.com/StringNick/starknet-zig/archive/57810b7a64364f1bf12725ba823385c2a213bfa5.zip",
.hash = "1220d848be799ff21a80c6751c088ea619891ec450f20017cc7aa5cbbeb5904ae8b8",
},
},
}
17 changes: 15 additions & 2 deletions src/hint_processor/set.zig
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const HintProcessor = @import("hint_processor_def.zig").CairoVMHintProcessor;
const HintData = @import("hint_processor_def.zig").HintData;
const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable;
const MaybeRelocatable = @import("../vm/memory/relocatable.zig").MaybeRelocatable;
const MemoryCell = @import("../vm/memory/memory.zig").MemoryCell;
const Felt252 = @import("../math/fields/starknet.zig").Felt252;
const hint_codes = @import("builtin_hint_codes.zig");
const MathError = @import("../vm/error.zig").MathError;
Expand Down Expand Up @@ -60,10 +61,22 @@ pub fn setAdd(
// Calculate the range limit.
const range_limit = (try set_end_ptr.sub(set_ptr)).offset;

// load all list, and then we compare elements
var elm_segment = vm.segments.memory.getSegmentAtIndex(elm_ptr.segment_index) orelse return HintError.InvalidSetRange;

if (elm_segment.len < elm_ptr.offset + elm_size) return HintError.InvalidSetRange;

var set_segment = vm.segments.memory.getSegmentAtIndex(set_ptr.segment_index) orelse return HintError.InvalidSetRange;

if (set_ptr.offset + range_limit > set_segment.len) return HintError.InvalidSetRange;

elm_segment = elm_segment[elm_ptr.offset .. elm_ptr.offset + elm_size];
set_segment = set_segment[set_ptr.offset .. set_ptr.offset + range_limit];

// Iterate over the set elements.
for (0..range_limit) |i| {
for (0..range_limit / elm_size) |i| {
// Check if the element is in the set.
if (try vm.memEq(elm_ptr, try set_ptr.addUint(elm_size * i), elm_size)) {
if (MemoryCell.eqlSlice(elm_segment, set_segment[i * elm_size .. (i + 1) * elm_size])) {
// Insert index of the element into the virtual machine.
try hint_utils.insertValueFromVarName(
allocator,
Expand Down
4 changes: 2 additions & 2 deletions src/vm/core_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3666,7 +3666,7 @@ test "CairoVM: runInstruction without any insertion in the memory" {
// Compare each cell in VM's memory with the corresponding cell in the expected memory.
for (vm.segments.memory.data.items, 0..) |d, i| {
for (d.items, 0..) |cell, j| {
try expect(cell.eql(expected_memory.data.items[i].items[j]));
try expect(cell.eql(&expected_memory.data.items[i].items[j]));
}
}
}
Expand Down Expand Up @@ -3839,7 +3839,7 @@ test "CairoVM: runInstruction with Op0 being deduced" {
// Compare each cell in VM's memory with the corresponding cell in the expected memory.
for (vm.segments.memory.data.items, 0..) |d, i| {
for (d.items, 0..) |cell, j| {
try expect(cell.eql(expected_memory.data.items[i].items[j]));
try expect(cell.eql(&expected_memory.data.items[i].items[j]));
}
}
}
Expand Down
101 changes: 58 additions & 43 deletions src/vm/memory/memory.zig
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const RangeCheckBuiltinRunner = @import("../builtins/builtin_runner/range_check.
// Function that validates a memory address and returns a list of validated adresses
pub const validation_rule = *const fn (Allocator, *Memory, Relocatable) anyerror!std.ArrayList(Relocatable);

pub const MemoryCell = struct {
pub const MemoryCell = extern struct {
/// Represents a memory cell that holds relocation information and access status.
const Self = @This();
const ACCESS_MASK: u64 = 1 << 62;
Expand Down Expand Up @@ -103,8 +103,12 @@ pub const MemoryCell = struct {
/// # Returns
///
/// Returns `true` if both MemoryCell instances are equal, otherwise `false`.
pub fn eql(self: Self, other: Self) bool {
return std.mem.eql(u64, self.data[0..], other.data[0..]);
pub fn eql(self: *const Self, other: *const Self) bool {
inline for (0..4) |i| {
if (self.data[i] != other.data[i]) return false;
}

return true;
}

/// Checks equality between slices of MemoryCell instances.
Expand All @@ -124,7 +128,7 @@ pub const MemoryCell = struct {
if (a.len != b.len) return false;
if (a.ptr == b.ptr) return true;

for (a, b) |a_elem, b_elem| {
for (a, b) |*a_elem, *b_elem| {
if (!a_elem.eql(b_elem)) return false;
}

Expand Down Expand Up @@ -609,20 +613,11 @@ pub const Memory = struct {
/// # Returns
///
/// Returns the segment of MemoryCell items if it exists, or `null` if not found.
fn getSegmentAtIndex(self: *Self, idx: i64) ?[]MemoryCell {
return switch (idx < 0) {
true => blk: {
const i: usize = @intCast(-(idx + 1));
break :blk if (i < self.temp_data.items.len)
self.temp_data.items[i].items
else
null;
},
false => if (idx < self.data.items.len)
self.data.items[@intCast(idx)].items
else
null,
};
pub inline fn getSegmentAtIndex(self: *const Self, idx: i64) ?[]MemoryCell {
return if (idx < 0) {
const i: usize = @bitCast(-(idx + 1));
return if (i >= self.temp_data.items.len) null else self.temp_data.items[i].items;
} else if (idx >= self.data.items.len) null else self.data.items[@intCast(idx)].items;
}

/// Compares two memory segments within the VM's memory starting from specified addresses
Expand Down Expand Up @@ -663,12 +658,6 @@ pub const Memory = struct {
const l_idx = lhs.offset + i;
const r_idx = rhs.offset + i;

// std.log.err("lhs: {any}, rhs: {any}, i: {any}, {any}", .{
// if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE, if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE, i, MemoryCell.cmp(
// if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE,
// if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE,
// ),
// });
return switch (MemoryCell.cmp(
if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE,
if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE,
Expand Down Expand Up @@ -700,7 +689,7 @@ pub const Memory = struct {
/// # Returns
///
/// Returns `true` if segments are equal up to the specified length, otherwise `false`.
pub fn memEq(self: *Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool {
pub fn memEq(self: *const Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool {
// Check if the left and right addresses are the same, in which case the segments are equal.
if (lhs.eq(rhs)) return true;

Expand All @@ -714,29 +703,25 @@ pub const Memory = struct {
// Get the segment starting from the right-hand address.
const r: ?[]MemoryCell = if (self.getSegmentAtIndex(rhs.segment_index)) |s|
// Check if the offset is within the bounds of the segment.
if (rhs.offset < s.len) s[rhs.offset..] else if (l == null) return true else return false
else if (l == null) return true else return false;
if (rhs.offset < s.len) s[rhs.offset..] else return l == null
else
return l == null;

// If the left segment exists, perform further checks.
if (l) |ls| {
// If the right segment also exists, compare the segments up to the specified length.
if (r) |rs| {
// Determine the actual lengths to compare.
const lhs_len = @min(ls.len, len);
const rhs_len = @min(rs.len, len);
// Determine the actual lengths to compare.
const lhs_len = @min(ls.len, len);
const rhs_len = @min(r.?.len, len);

// Compare slices of MemoryCell items up to the specified length.
if (lhs_len != rhs_len) return false;

return MemoryCell.eqlSlice(ls[0..lhs_len], rs[0..rhs_len]);
}
// Compare slices of MemoryCell items up to the specified length.
if (lhs_len != rhs_len) return false;

// If only the left segment exists, return false.
return false;
return MemoryCell.eqlSlice(ls[0..lhs_len], r.?[0..rhs_len]);
}

// If the left segment does not exist, return true only if the right segment is also null.
return r == null;
// If only the left segment exists, return false.
return false;
}

/// Retrieves a range of memory values starting from a specified address.
Expand Down Expand Up @@ -769,6 +754,36 @@ pub const Memory = struct {
return values;
}

/// Retrieves a range of memory values starting from a specified address.
///
/// # Arguments
///
/// * `allocator`: The allocator used for the memory allocation of the returned list.
/// * `address`: The starting address in the memory from which the range is retrieved.
/// * `size`: The size of the range to be retrieved.
///
/// # Returns
///
/// Returns a list containing memory values retrieved from the specified range starting at the given address.
/// The list may contain `MemoryCell.NONE` elements for inaccessible memory positions.
///
/// # Errors
///
/// Returns an error if there are any issues encountered during the retrieval of the memory range.
pub fn getRangeRaw(
self: *Self,
allocator: Allocator,
address: Relocatable,
size: usize,
) !std.ArrayList(?MaybeRelocatable) {
var values = std.ArrayList(?MaybeRelocatable).init(allocator);
errdefer values.deinit();
for (0..size) |i| {
try values.append(self.get(try address.addUint(i)));
}
return values;
}

/// Counts the number of accessed addresses within a specified segment in the VM memory.
///
/// # Arguments
Expand Down Expand Up @@ -2426,9 +2441,9 @@ test "MemoryCell: eql function" {
memoryCell4.markAccessed();

// Test checks
try expect(memoryCell1.eql(memoryCell2));
try expect(!memoryCell1.eql(memoryCell3));
try expect(!memoryCell1.eql(memoryCell4));
try expect(memoryCell1.eql(&memoryCell2));
try expect(!memoryCell1.eql(&memoryCell3));
try expect(!memoryCell1.eql(&memoryCell4));
}

test "MemoryCell: eqlSlice should return false if slice len are not the same" {
Expand Down
11 changes: 10 additions & 1 deletion src/vm/memory/relocatable.zig
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,16 @@ pub const MaybeRelocatable = union(enum) {
/// * `true` if the two instances are equal.
/// * `false` otherwise.
pub fn eq(self: Self, other: Self) bool {
return std.meta.eql(self, other);
return switch (self) {
inline .felt => |f| switch (other) {
inline .felt => |f1| f.eql(f1),
else => false,
},
inline .relocatable => |r| switch (other) {
inline .relocatable => |r1| r.eq(r1),
else => false,
},
};
}

/// Determines if self is less than other.
Expand Down

0 comments on commit 36e27d5

Please sign in to comment.