Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OPIK-672 [SDK] Fix cyclic reference/recursion issue in json_encoder #964

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions sdks/python/src/opik/jsonable_encoder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
import dataclasses
import datetime as dt

from typing import Callable, Any, Type, Set, Tuple

import logging
from enum import Enum
from pathlib import PurePath
from types import GeneratorType
from typing import Any, Callable, Optional, Set, Tuple, Type

import pydantic

import opik.rest_api.core.datetime_utils as datetime_utils
Expand All @@ -25,18 +24,29 @@ def register_encoder_extension(obj_type: Type, encoder: Callable[[Any], Any]) ->
_ENCODER_EXTENSIONS.add((obj_type, encoder))


def jsonable_encoder(obj: Any) -> Any:
def jsonable_encoder(obj: Any, seen: Optional[Set[int]] = None) -> Any:
"""
This is a modified version of the serializer generated by Fern in rest_api.core.jsonable_encoder.
The code is simplified to serialize complex objects into a textual representation.
It also handles cyclic references to avoid infinite recursion.
"""
if seen is None:
seen = set()

if hasattr(obj, "__dict__"):
obj_id = id(obj)
if obj_id in seen:
LOGGER.debug(f"Found cyclic reference to {type(obj).__name__} id={obj_id}")
return f"<Cyclic reference to {type(obj).__name__} id={obj_id}>"
seen.add(obj_id)

try:
if dataclasses.is_dataclass(obj) or isinstance(obj, pydantic.BaseModel):
obj_dict = obj.__dict__
return jsonable_encoder(obj_dict)
return jsonable_encoder(obj_dict, seen)

if isinstance(obj, Enum):
return jsonable_encoder(obj.value)
return jsonable_encoder(obj.value, seen)
if isinstance(obj, PurePath):
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
Expand All @@ -50,26 +60,33 @@ def jsonable_encoder(obj: Any) -> Any:
allowed_keys = set(obj.keys())
for key, value in obj.items():
if key in allowed_keys:
encoded_key = jsonable_encoder(key)
encoded_value = jsonable_encoder(value)
encoded_key = jsonable_encoder(key, seen)
encoded_value = jsonable_encoder(value, seen)
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list = []
for item in obj:
encoded_list.append(jsonable_encoder(item))
encoded_list.append(jsonable_encoder(item, seen))
return encoded_list

for type_, encoder in _ENCODER_EXTENSIONS:
if isinstance(obj, type_):
return jsonable_encoder(encoder(obj))
return jsonable_encoder(encoder(obj), seen)

if np is not None and isinstance(obj, np.ndarray):
return jsonable_encoder(obj.tolist())
return jsonable_encoder(obj.tolist(), seen)

except Exception:
LOGGER.debug("Failed to serialize object.", exc_info=True)

finally:
# Once done encoding this object, remove from `seen`,
# so the same object can appear again at a sibling branch.
if hasattr(obj, "__dict__"):
obj_id = id(obj)
seen.remove(obj_id)

data = str(obj)

return data
Original file line number Diff line number Diff line change
@@ -1,14 +1,83 @@
from typing import Any
import dataclasses
from datetime import date, datetime, timezone
from threading import Lock
from typing import Any, Optional

import numpy as np
import pytest
import dataclasses

import opik.jsonable_encoder as jsonable_encoder


@dataclasses.dataclass
class Node:
    """Minimal linked node used by the tests to build object graphs."""

    # Integer payload stored on the node.
    value: int
    # Reference to another node; stays None when nothing is linked.
    child: Optional["Node"] = dataclasses.field(default=None)


def test_jsonable_encoder__cyclic_reference():
    """
    Test that the encoder detects cyclic references and does not infinitely recurse.

    A two-node cycle (A -> B -> A) is encoded starting from A. Both nodes are
    expected to be expanded once into dicts carrying their own values, while
    the back-reference from B to A is expected to come out as a string marker
    instead of a second expansion of A.
    """
    # Create a simple two-node cycle: A -> B -> A
    node_a = Node(value=1)
    node_b = Node(value=2)
    node_a.child = node_b
    node_b.child = node_a

    encoded = jsonable_encoder.jsonable_encoder(node_a)

    # node_a is expanded into a dict holding its own field values.
    assert isinstance(encoded, dict)
    assert encoded.get("value") == 1
    assert "child" in encoded

    # node_a.child (which is node_b) is expanded into a dict as well.
    encoded_b = encoded["child"]
    assert isinstance(encoded_b, dict)
    assert encoded_b.get("value") == 2
    assert "child" in encoded_b

    # node_b.child points back at node_a, so it must be the cycle marker.
    # The exact marker format can vary; only its type and prefix are pinned.
    cycle_marker = encoded_b["child"]
    assert isinstance(
        cycle_marker, str
    ), "Expected a string marker for cyclic reference"
    assert (
        "<Cyclic reference to " in cycle_marker
    ), "Should contain 'Cyclic reference' text"


def test_jsonable_encoder__repeated_objects_in_list():
    """
    Test that one object placed several times in the same list is encoded
    normally at every position: repeated references are not a cycle, so no
    cycle marker may appear.
    """
    shared_node = Node(value=42)
    occurrences = 3

    # The very same instance sits at every position of the input list.
    result = jsonable_encoder.jsonable_encoder([shared_node] * occurrences)

    # One encoded entry is expected per occurrence.
    assert isinstance(result, list)
    assert len(result) == occurrences

    # Every entry is the dict form of the node: `value` = 42, `child` = None.
    for entry in result:
        assert isinstance(entry, dict)
        assert entry.get("value") == 42
        assert entry.get("child") is None

    # Nothing in the list refers back to an ancestor, so the marker text
    # must be absent from every encoded entry.
    assert not any("Cyclic reference" in str(entry) for entry in result)


@pytest.mark.parametrize(
"obj",
[
Expand Down
Loading