Skip to content

Commit

Permalink
Adds support for scripting engines as Valkey modules (#1277)
Browse files Browse the repository at this point in the history
This PR extends the module API to support the addition of different
scripting engines to execute user defined functions.

The scripting engine can be implemented as a Valkey module, and can be
dynamically loaded with the `loadmodule` config directive, or with the
`MODULE LOAD` command.

This PR also adds an example of a dummy scripting engine module, to show
how to use the new module API. The dummy module is implemented in
`tests/modules/helloscripting.c`.

The current module API support, only allows to load scripting engines to
run functions using `FCALL` command.

The additions to the module API are the following:

```c
/* This struct represents a scripting engine function that results from the
 * compilation of a script by the engine implementation. */
struct ValkeyModuleScriptingEngineCompiledFunction

typedef ValkeyModuleScriptingEngineCompiledFunction **(*ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc)(
    ValkeyModuleScriptingEngineCtx *engine_ctx,
    const char *code,
    size_t timeout,
    size_t *out_num_compiled_functions,
    char **err);

typedef void (*ValkeyModuleScriptingEngineCallFunctionFunc)(
    ValkeyModuleCtx *module_ctx,
    ValkeyModuleScriptingEngineCtx *engine_ctx,
    ValkeyModuleScriptingEngineFunctionCtx *func_ctx,
    void *compiled_function,
    ValkeyModuleString **keys,
    size_t nkeys,
    ValkeyModuleString **args,
    size_t nargs);

typedef size_t (*ValkeyModuleScriptingEngineGetUsedMemoryFunc)(
    ValkeyModuleScriptingEngineCtx *engine_ctx);

typedef size_t (*ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc)(
    void *compiled_function);

typedef size_t (*ValkeyModuleScriptingEngineGetEngineMemoryOverheadFunc)(
    ValkeyModuleScriptingEngineCtx *engine_ctx);

typedef void (*ValkeyModuleScriptingEngineFreeFunctionFunc)(
    ValkeyModuleScriptingEngineCtx *engine_ctx,
    void *compiled_function);

/* This struct stores the callback functions implemented by the scripting
 * engine to provide the functionality for the `FUNCTION *` commands. */
typedef struct ValkeyModuleScriptingEngineMethodsV1 {
    uint64_t version; /* Version of this structure for ABI compat. */

    /* Library create function callback. When a new script is loaded, this
     * callback will be called with the script code, and returns a list of
     * ValkeyModuleScriptingEngineCompiledFunc objects. */
    ValkeyModuleScriptingEngineCreateFunctionsLibraryFunc create_functions_library;

    /* The callback function called when `FCALL` command is called on a function
     * registered in this engine. */
    ValkeyModuleScriptingEngineCallFunctionFunc call_function;

    /* Function callback to get current used memory by the engine. */
    ValkeyModuleScriptingEngineGetUsedMemoryFunc get_used_memory;

    /* Function callback to return memory overhead for a given function. */
    ValkeyModuleScriptingEngineGetFunctionMemoryOverheadFunc get_function_memory_overhead;

    /* Function callback to return memory overhead of the engine. */
    ValkeyModuleScriptingEngineGetEngineMemoryOverheadFunc get_engine_memory_overhead;

    /* Function callback to free the memory of a registered engine function. */
    ValkeyModuleScriptingEngineFreeFunctionFunc free_function;
} ValkeyModuleScriptingEngineMethodsV1;

/* Registers a new scripting engine in the server.
 *
 * - `engine_name`: the name of the scripting engine. This name will match
 *   against the engine name specified in the script header using a shebang.
 *
 * - `engine_ctx`: engine specific context pointer.
 *
 * - `engine_methods`: the struct with the scripting engine callback functions
 * pointers.
 */
int ValkeyModule_RegisterScriptingEngine(ValkeyModuleCtx *ctx,
                                         const char *engine_name,
                                         void *engine_ctx,
                                         ValkeyModuleScriptingEngineMethods engine_methods);

/* Removes the scripting engine from the server.
 *
 * `engine_name` is the name of the scripting engine.
 *
 */
int ValkeyModule_UnregisterScriptingEngine(ValkeyModuleCtx *ctx, const char *engine_name);
```

---------

Signed-off-by: Ricardo Dias <[email protected]>
  • Loading branch information
rjd15372 authored Dec 21, 2024
1 parent 1c97317 commit 6adef8e
Show file tree
Hide file tree
Showing 16 changed files with 1,124 additions and 136 deletions.
205 changes: 130 additions & 75 deletions src/function_lua.c
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,14 @@ typedef struct luaFunctionCtx {
} luaFunctionCtx;

typedef struct loadCtx {
functionLibInfo *li;
list *functions;
monotime start_time;
size_t timeout;
} loadCtx;

typedef struct registerFunctionArgs {
sds name;
sds desc;
luaFunctionCtx *lua_f_ctx;
uint64_t f_flags;
} registerFunctionArgs;
static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx,
engineCtx *engine_ctx,
void *compiled_function);

/* Hook for FUNCTION LOAD execution.
* Used to cancel the execution in case of a timeout (500ms).
Expand All @@ -93,15 +90,42 @@ static void luaEngineLoadHook(lua_State *lua, lua_Debug *ar) {
}
}

static void freeCompiledFunc(ValkeyModuleCtx *module_ctx,
luaEngineCtx *lua_engine_ctx,
void *compiled_func) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

compiledFunction *func = compiled_func;
decrRefCount(func->name);
if (func->desc) {
decrRefCount(func->desc);
}
luaEngineFreeFunction(module_ctx, lua_engine_ctx, func->function);
zfree(func);
}

/*
* Compile a given blob and save it on the registry.
* Return a function ctx with Lua ref that allows to later retrieve the
* function from the registry.
* Compile a given script code by generating a set of compiled functions. These
* functions are also saved into the the registry of the Lua environment.
*
* Returns an array of compiled functions. The `compileFunction` struct stores a
* Lua ref that allows to later retrieve the function from the registry.
* In the `out_num_compiled_functions` parameter is returned the size of the
* array.
*
* Return NULL on compilation error and set the error to the err variable
*/
static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size_t timeout, sds *err) {
int ret = C_ERR;
static compiledFunction **luaEngineCreate(ValkeyModuleCtx *module_ctx,
engineCtx *engine_ctx,
const char *code,
size_t timeout,
size_t *out_num_compiled_functions,
char **err) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

compiledFunction **compiled_functions = NULL;
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;

Expand All @@ -114,15 +138,15 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
lua_pop(lua, 1); /* pop the metatable */

/* compile the code */
if (luaL_loadbuffer(lua, blob, sdslen(blob), "@user_function")) {
*err = sdscatprintf(sdsempty(), "Error compiling function: %s", lua_tostring(lua, -1));
if (luaL_loadbuffer(lua, code, strlen(code), "@user_function")) {
*err = valkey_asprintf("Error compiling function: %s", lua_tostring(lua, -1));
lua_pop(lua, 1); /* pops the error */
goto done;
}
serverAssert(lua_isfunction(lua, -1));

loadCtx load_ctx = {
.li = li,
.functions = listCreate(),
.start_time = getMonotonicUs(),
.timeout = timeout,
};
Expand All @@ -133,13 +157,31 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size
if (lua_pcall(lua, 0, 0, 0)) {
errorInfo err_info = {0};
luaExtractErrorInformation(lua, &err_info);
*err = sdscatprintf(sdsempty(), "Error registering functions: %s", err_info.msg);
*err = valkey_asprintf("Error registering functions: %s", err_info.msg);
lua_pop(lua, 1); /* pops the error */
luaErrorInformationDiscard(&err_info);
listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD);
listNode *node = NULL;
while ((node = listNext(iter)) != NULL) {
freeCompiledFunc(module_ctx, lua_engine_ctx, listNodeValue(node));
}
listReleaseIterator(iter);
listRelease(load_ctx.functions);
goto done;
}

ret = C_OK;
compiled_functions =
zcalloc(sizeof(compiledFunction *) * listLength(load_ctx.functions));
listIter *iter = listGetIterator(load_ctx.functions, AL_START_HEAD);
listNode *node = NULL;
*out_num_compiled_functions = 0;
while ((node = listNext(iter)) != NULL) {
compiledFunction *func = listNodeValue(node);
compiled_functions[*out_num_compiled_functions] = func;
(*out_num_compiled_functions)++;
}
listReleaseIterator(iter);
listRelease(load_ctx.functions);

done:
/* restore original globals */
Expand All @@ -152,19 +194,23 @@ static int luaEngineCreate(void *engine_ctx, functionLibInfo *li, sds blob, size

lua_sethook(lua, NULL, 0, 0); /* Disable hook */
luaSaveOnRegistry(lua, REGISTRY_LOAD_CTX_NAME, NULL);
return ret;
return compiled_functions;
}

/*
* Invole the give function with the given keys and args
*/
static void luaEngineCall(scriptRunCtx *run_ctx,
void *engine_ctx,
static void luaEngineCall(ValkeyModuleCtx *module_ctx,
engineCtx *engine_ctx,
functionCtx *func_ctx,
void *compiled_function,
robj **keys,
size_t nkeys,
robj **args,
size_t nargs) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
luaFunctionCtx *f_ctx = compiled_function;
Expand All @@ -177,52 +223,58 @@ static void luaEngineCall(scriptRunCtx *run_ctx,

serverAssert(lua_isfunction(lua, -1));

scriptRunCtx *run_ctx = (scriptRunCtx *)func_ctx;
luaCallFunction(run_ctx, lua, keys, nkeys, args, nargs, 0);
lua_pop(lua, 1); /* Pop error handler */
}

static size_t luaEngineGetUsedMemoy(void *engine_ctx) {
static engineMemoryInfo luaEngineGetMemoryInfo(ValkeyModuleCtx *module_ctx,
engineCtx *engine_ctx) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

luaEngineCtx *lua_engine_ctx = engine_ctx;
return luaMemory(lua_engine_ctx->lua);

return (engineMemoryInfo){
.used_memory = luaMemory(lua_engine_ctx->lua),
.engine_memory_overhead = zmalloc_size(lua_engine_ctx),
};
}

static size_t luaEngineFunctionMemoryOverhead(void *compiled_function) {
static size_t luaEngineFunctionMemoryOverhead(ValkeyModuleCtx *module_ctx,
void *compiled_function) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

return zmalloc_size(compiled_function);
}

static size_t luaEngineMemoryOverhead(void *engine_ctx) {
luaEngineCtx *lua_engine_ctx = engine_ctx;
return zmalloc_size(lua_engine_ctx);
}
static void luaEngineFreeFunction(ValkeyModuleCtx *module_ctx,
engineCtx *engine_ctx,
void *compiled_function) {
/* The lua engine is implemented in the core, and not in a Valkey Module */
serverAssert(module_ctx == NULL);

static void luaEngineFreeFunction(void *engine_ctx, void *compiled_function) {
luaEngineCtx *lua_engine_ctx = engine_ctx;
lua_State *lua = lua_engine_ctx->lua;
luaFunctionCtx *f_ctx = compiled_function;
lua_unref(lua, f_ctx->lua_function_ref);
zfree(f_ctx);
}

static void luaRegisterFunctionArgsInitialize(registerFunctionArgs *register_f_args,
sds name,
sds desc,
static void luaRegisterFunctionArgsInitialize(compiledFunction *func,
robj *name,
robj *desc,
luaFunctionCtx *lua_f_ctx,
uint64_t flags) {
*register_f_args = (registerFunctionArgs){
*func = (compiledFunction){
.name = name,
.desc = desc,
.lua_f_ctx = lua_f_ctx,
.function = lua_f_ctx,
.f_flags = flags,
};
}

static void luaRegisterFunctionArgsDispose(lua_State *lua, registerFunctionArgs *register_f_args) {
sdsfree(register_f_args->name);
if (register_f_args->desc) sdsfree(register_f_args->desc);
lua_unref(lua, register_f_args->lua_f_ctx->lua_function_ref);
zfree(register_f_args->lua_f_ctx);
}

/* Read function flags located on the top of the Lua stack.
* On success, return C_OK and set the flags to 'flags' out parameter
* Return C_ERR if encounter an unknown flag. */
Expand Down Expand Up @@ -267,10 +319,11 @@ static int luaRegisterFunctionReadFlags(lua_State *lua, uint64_t *flags) {
return ret;
}

static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadNamedArgs(lua_State *lua,
compiledFunction *func) {
char *err = NULL;
sds name = NULL;
sds desc = NULL;
robj *name = NULL;
robj *desc = NULL;
luaFunctionCtx *lua_f_ctx = NULL;
uint64_t flags = 0;
if (!lua_istable(lua, 1)) {
Expand All @@ -287,14 +340,15 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
err = "named argument key given to server.register_function is not a string";
goto error;
}

const char *key = lua_tostring(lua, -2);
if (!strcasecmp(key, "function_name")) {
if (!(name = luaGetStringSds(lua, -1))) {
if (!(name = luaGetStringObject(lua, -1))) {
err = "function_name argument given to server.register_function must be a string";
goto error;
}
} else if (!strcasecmp(key, "description")) {
if (!(desc = luaGetStringSds(lua, -1))) {
if (!(desc = luaGetStringObject(lua, -1))) {
err = "description argument given to server.register_function must be a string";
goto error;
}
Expand Down Expand Up @@ -335,13 +389,17 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
goto error;
}

luaRegisterFunctionArgsInitialize(register_f_args, name, desc, lua_f_ctx, flags);
luaRegisterFunctionArgsInitialize(func,
name,
desc,
lua_f_ctx,
flags);

return C_OK;

error:
if (name) sdsfree(name);
if (desc) sdsfree(desc);
if (name) decrRefCount(name);
if (desc) decrRefCount(desc);
if (lua_f_ctx) {
lua_unref(lua, lua_f_ctx->lua_function_ref);
zfree(lua_f_ctx);
Expand All @@ -350,11 +408,12 @@ static int luaRegisterFunctionReadNamedArgs(lua_State *lua, registerFunctionArgs
return C_ERR;
}

static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadPositionalArgs(lua_State *lua,
compiledFunction *func) {
char *err = NULL;
sds name = NULL;
robj *name = NULL;
luaFunctionCtx *lua_f_ctx = NULL;
if (!(name = luaGetStringSds(lua, 1))) {
if (!(name = luaGetStringObject(lua, 1))) {
err = "first argument to server.register_function must be a string";
goto error;
}
Expand All @@ -369,51 +428,46 @@ static int luaRegisterFunctionReadPositionalArgs(lua_State *lua, registerFunctio
lua_f_ctx = zmalloc(sizeof(*lua_f_ctx));
lua_f_ctx->lua_function_ref = lua_function_ref;

luaRegisterFunctionArgsInitialize(register_f_args, name, NULL, lua_f_ctx, 0);
luaRegisterFunctionArgsInitialize(func, name, NULL, lua_f_ctx, 0);

return C_OK;

error:
if (name) sdsfree(name);
if (name) decrRefCount(name);
luaPushError(lua, err);
return C_ERR;
}

static int luaRegisterFunctionReadArgs(lua_State *lua, registerFunctionArgs *register_f_args) {
static int luaRegisterFunctionReadArgs(lua_State *lua, compiledFunction *func) {
int argc = lua_gettop(lua);
if (argc < 1 || argc > 2) {
luaPushError(lua, "wrong number of arguments to server.register_function");
return C_ERR;
}

if (argc == 1) {
return luaRegisterFunctionReadNamedArgs(lua, register_f_args);
return luaRegisterFunctionReadNamedArgs(lua, func);
} else {
return luaRegisterFunctionReadPositionalArgs(lua, register_f_args);
return luaRegisterFunctionReadPositionalArgs(lua, func);
}
}

static int luaRegisterFunction(lua_State *lua) {
registerFunctionArgs register_f_args = {0};
compiledFunction *func = zcalloc(sizeof(*func));

loadCtx *load_ctx = luaGetFromRegistry(lua, REGISTRY_LOAD_CTX_NAME);
if (!load_ctx) {
zfree(func);
luaPushError(lua, "server.register_function can only be called on FUNCTION LOAD command");
return luaError(lua);
}

if (luaRegisterFunctionReadArgs(lua, &register_f_args) != C_OK) {
if (luaRegisterFunctionReadArgs(lua, func) != C_OK) {
zfree(func);
return luaError(lua);
}

sds err = NULL;
if (functionLibCreateFunction(register_f_args.name, register_f_args.lua_f_ctx, load_ctx->li, register_f_args.desc,
register_f_args.f_flags, &err) != C_OK) {
luaRegisterFunctionArgsDispose(lua, &register_f_args);
luaPushError(lua, err);
sdsfree(err);
return luaError(lua);
}
listAddNodeTail(load_ctx->functions, func);

return 0;
}
Expand Down Expand Up @@ -494,16 +548,17 @@ int luaEngineInitEngine(void) {
lua_enablereadonlytable(lua_engine_ctx->lua, -1, 1); /* protect the new global table */
lua_replace(lua_engine_ctx->lua, LUA_GLOBALSINDEX); /* set new global table as the new globals */


engine *lua_engine = zmalloc(sizeof(*lua_engine));
*lua_engine = (engine){
.engine_ctx = lua_engine_ctx,
.create = luaEngineCreate,
.call = luaEngineCall,
.get_used_memory = luaEngineGetUsedMemoy,
engineMethods lua_engine_methods = {
.version = VALKEYMODULE_SCRIPTING_ENGINE_ABI_VERSION,
.create_functions_library = luaEngineCreate,
.call_function = luaEngineCall,
.get_function_memory_overhead = luaEngineFunctionMemoryOverhead,
.get_engine_memory_overhead = luaEngineMemoryOverhead,
.free_function = luaEngineFreeFunction,
.get_memory_info = luaEngineGetMemoryInfo,
};
return functionsRegisterEngine(LUA_ENGINE_NAME, lua_engine);

return functionsRegisterEngine(LUA_ENGINE_NAME,
NULL,
lua_engine_ctx,
&lua_engine_methods);
}
Loading

0 comments on commit 6adef8e

Please sign in to comment.