Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #10380 - Memoize should handle lambdas #7507

Merged
merged 5 commits into from
Dec 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 157 additions & 60 deletions std/functional.d
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,15 @@ alias pipe(fun...) = compose!(Reverse!(fun));
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
}

private template getOverloads(alias fun)
{
import std.meta : AliasSeq;
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun), true)))
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun), true);
else
alias getOverloads = AliasSeq!fun;
}

/**
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
* to avoid repeated computation. The memoization structure is a hash table keyed by a
Expand Down Expand Up @@ -1324,87 +1333,131 @@ Note:
*/
template memoize(alias fun)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
// Generic implementation
alias memoize = impl;
}

ReturnType!fun memoize(Parameters!fun args)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
alias Args = Parameters!fun;
import std.typecons : Tuple;
import std.typecons : Tuple, tuple;
import std.traits : Unqual;

static Unqual!(ReturnType!fun)[Tuple!Args] memo;
auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
static if (args.length > 0)
{
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;

auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
}
else
{
static typeof(fun(args)) result;
result = fun(args);
return result;
}
}
}

/// ditto
template memoize(alias fun, uint maxSize)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
ReturnType!fun memoize(Parameters!fun args)
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
static Value[] memo;
static size_t[] initialized;
// Generic implementation
alias memoize = impl;
}

if (!memo.length)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
static if (args.length > 0)
{
import core.memory : GC;
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
alias returnType = typeof(fun(args));
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
static Value[] memo;
static size_t[] initialized;

// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
if (!memo.length)
{
import core.memory : GC;

enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}
// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));

import core.bitop : bt, bts;
import core.lifetime : emplace;
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}

size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
import core.bitop : bt, bts;
import core.lifetime : emplace;

size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

memo[idx1] = Value(args, fun(args));
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))
else
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);
static typeof(fun(args)) result;
result = fun(args);
return result;
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

memo[idx1] = Value(args, fun(args));
return memo[idx1].res;
}
}

Expand Down Expand Up @@ -1464,6 +1517,37 @@ unittest
assert(fact(10) == 3628800);
}

// Issue 20099
@system unittest // not @safe due to memoize
{
int i = 3;
alias a = memoize!((n) => i + n);
Biotronic marked this conversation as resolved.
Show resolved Hide resolved
alias b = memoize!((n) => i + n, 3);

assert(a(3) == 6);
assert(b(3) == 6);
}

@system unittest // not @safe due to memoize
{
static Object objNum(int a) { return new Object(); }
Biotronic marked this conversation as resolved.
Show resolved Hide resolved
assert(memoize!objNum(0) is memoize!objNum(0U));
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
}

@system unittest // not @safe due to memoize
{
struct S
{
static int fun() { return 0; }
static int fun(int i) { return 1; }
}
assert(memoize!(S.fun)() == 0);
assert(memoize!(S.fun)(3) == 1);
assert(memoize!(S.fun, 3)() == 0);
assert(memoize!(S.fun, 3)(3) == 1);
}

@system unittest // not @safe due to memoize
{
import core.math : sqrt;
Expand Down Expand Up @@ -1626,6 +1710,19 @@ unittest
}}
}

// memoize should continue to work with functions that cannot be evaluated at compile time
@system unittest
{
__gshared string[string] glob;

static bool foo()
{
return (":-)" in glob) is null;
}

assert(memoize!foo);
}

private struct DelegateFaker(F)
{
import std.typecons : FuncInfo, MemberFunctionGenerator;
Expand Down
Loading