Skip to content

Commit

Permalink
Add proxy inheritance (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpenilla authored Nov 15, 2023
1 parent eb2bc39 commit bfcc5d2
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import java.lang.invoke.MethodType;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.UnaryOperator;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -130,6 +136,25 @@ public static MethodHandle handleForDefaultMethod(
);
}

public static List<Class<?>> topDownInterfaceHierarchy(final Class<?> cls) {
if (!cls.isInterface()) {
throw new IllegalStateException("Expected an interface, got " + cls);
}
final Set<Class<?>> set = new LinkedHashSet<>();
set.add(cls);
interfaces(cls, set);
final List<Class<?>> list = new ArrayList<>(set);
Collections.reverse(list);
return Collections.unmodifiableList(list);
}

private static void interfaces(final Class<?> cls, final Collection<Class<?>> list) {
for (final Class<?> iface : cls.getInterfaces()) {
list.add(iface);
interfaces(iface, list);
}
}

public static String descriptorString(final Class<?> clazz) {
if (DESCRIPTOR_STRING != null) {
// jdk 12+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.framework.qual.DefaultQualifier;
import xyz.jpenilla.reflectionremapper.ReflectionRemapper;
import xyz.jpenilla.reflectionremapper.internal.util.Util;
import xyz.jpenilla.reflectionremapper.proxy.annotation.Proxies;

/**
Expand Down Expand Up @@ -53,7 +52,6 @@ public <I> I reflectionProxy(final Class<I> proxyInterface) {
new Class<?>[]{proxyInterface},
new ReflectionProxyInvocationHandler<>(
proxyInterface,
Util.findProxiedClass(proxyInterface, this.reflectionRemapper::remapClassName),
this.reflectionRemapper
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ final class ReflectionProxyInvocationHandler<I> implements InvocationHandler {
private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup();
private static final Object[] EMPTY_OBJECT_ARRAY = new Object[]{};
private final Class<I> interfaceClass;
private final Class<?> proxiedClass;
private final Map<Method, MethodHandle> methods = new HashMap<>();
private final Map<Method, MethodHandle> getters = new HashMap<>();
private final Map<Method, MethodHandle> setters = new HashMap<>();
Expand All @@ -57,11 +56,9 @@ final class ReflectionProxyInvocationHandler<I> implements InvocationHandler {

ReflectionProxyInvocationHandler(
final Class<I> interfaceClass,
final Class<?> proxiedClass,
final ReflectionRemapper reflectionRemapper
) {
this.interfaceClass = interfaceClass;
this.proxiedClass = proxiedClass;
this.scanInterface(reflectionRemapper);
}

Expand All @@ -76,7 +73,7 @@ final class ReflectionProxyInvocationHandler<I> implements InvocationHandler {
} else if (isHashCodeMethod(method)) {
return 0;
} else if (isToStringMethod(method)) {
return String.format("ReflectionProxy[interface=%s, implementation=%s, proxies=%s]", this.interfaceClass.getTypeName(), proxy.getClass().getTypeName(), this.proxiedClass.getTypeName());
return String.format("ReflectionProxy[interface=%s, implementation=%s]", this.interfaceClass.getTypeName(), proxy.getClass().getTypeName());
}

if (args == null) {
Expand Down Expand Up @@ -142,19 +139,40 @@ final class ReflectionProxyInvocationHandler<I> implements InvocationHandler {
}

private void scanInterface(final ReflectionRemapper reflectionRemapper) {
this.scanInterface(
reflectionRemapper::remapClassOrArrayName,
fieldName -> reflectionRemapper.remapFieldName(this.proxiedClass, fieldName),
(methodName, parameters) -> reflectionRemapper.remapMethodName(this.proxiedClass, methodName, parameters)
);
Class<?> prevProxy = null;
Class<?> prevProxied = null;

for (final Class<?> cls : Util.topDownInterfaceHierarchy(this.interfaceClass)) {
final Class<?> proxied = Util.findProxiedClass(cls, reflectionRemapper::remapClassName);

if (prevProxied != null && !prevProxied.isAssignableFrom(proxied)) {
throw new IllegalArgumentException(
"Reflection proxy interface " + cls.getName() + " proxies " + proxied.getName() + ", and extends from reflection proxy interface "
+ prevProxy.getName() + " which proxies " + prevProxied.getName() + ", but the proxied types are not compatible."
);
}

this.scanInterface(
cls,
proxied,
reflectionRemapper::remapClassOrArrayName,
fieldName -> reflectionRemapper.remapFieldName(proxied, fieldName),
(methodName, parameters) -> reflectionRemapper.remapMethodName(proxied, methodName, parameters)
);

prevProxied = proxied;
prevProxy = cls;
}
}

private void scanInterface(
final Class<?> interfaceClass,
final Class<?> proxiedClass,
final UnaryOperator<String> classMapper,
final UnaryOperator<String> fieldMapper,
final BiFunction<String, Class<?>[], String> methodMapper
) {
for (final Method method : this.interfaceClass.getDeclaredMethods()) {
for (final Method method : interfaceClass.getDeclaredMethods()) {
if (isEqualsMethod(method) || isHashCodeMethod(method) || isToStringMethod(method) || Util.isSynthetic(method.getModifiers())) {
continue;
} else if (method.isDefault()) {
Expand All @@ -167,50 +185,50 @@ private void scanInterface(
if (constructorInvoker) {
this.methods.put(
method,
adapt(Util.sneakyThrows(() -> LOOKUP.unreflectConstructor(this.findProxiedConstructor(method, classMapper))))
adapt(Util.sneakyThrows(() -> LOOKUP.unreflectConstructor(this.findProxiedConstructor(proxiedClass, method, classMapper))))
);
continue;
}

final @Nullable FieldGetter getterAnnotation = method.getDeclaredAnnotation(FieldGetter.class);
final @Nullable FieldSetter setterAnnotation = method.getDeclaredAnnotation(FieldSetter.class);
if (getterAnnotation != null && setterAnnotation != null) {
throw new IllegalArgumentException("Method " + method.getName() + " in " + this.interfaceClass.getTypeName() + " is annotated with @FieldGetter and @FieldSetter, don't know which to use.");
throw new IllegalArgumentException("Method " + method.getName() + " in " + interfaceClass.getTypeName() + " is annotated with @FieldGetter and @FieldSetter, don't know which to use.");
}

final boolean hasStaticAnnotation = method.getDeclaredAnnotation(Static.class) != null;

if (getterAnnotation != null) {
final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectGetter(this.findProxiedField(getterAnnotation.value(), fieldMapper)));
final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectGetter(findProxiedField(proxiedClass, getterAnnotation.value(), fieldMapper)));
if (hasStaticAnnotation) {
checkParameterCount(method, this.interfaceClass, 0, "Static @FieldGetters should have no parameters.");
checkParameterCount(method, interfaceClass, 0, "Static @FieldGetters should have no parameters.");
this.staticGetters.put(method, handle.asType(MethodType.methodType(Object.class)));
} else {
checkParameterCount(method, this.interfaceClass, 1, "Non-static @FieldGetters should have one parameter.");
checkParameterCount(method, interfaceClass, 1, "Non-static @FieldGetters should have one parameter.");
this.getters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class)));
}
continue;
}

if (setterAnnotation != null) {
final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectSetter(this.findProxiedField(setterAnnotation.value(), fieldMapper)));
final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectSetter(findProxiedField(proxiedClass, setterAnnotation.value(), fieldMapper)));
if (hasStaticAnnotation) {
checkParameterCount(method, this.interfaceClass, 1, "Static @FieldSetters should have one parameter.");
checkParameterCount(method, interfaceClass, 1, "Static @FieldSetters should have one parameter.");
this.staticSetters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class)));
} else {
checkParameterCount(method, this.interfaceClass, 2, "Non-static @FieldSetters should have two parameters.");
checkParameterCount(method, interfaceClass, 2, "Non-static @FieldSetters should have two parameters.");
this.setters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class, Object.class)));
}
continue;
}

if (!hasStaticAnnotation && method.getParameterCount() < 1) {
throw new IllegalArgumentException("Non-static method invokers should have at least one parameter. Method " + method.getName() + " in " + this.interfaceClass.getTypeName() + " has " + method.getParameterCount());
throw new IllegalArgumentException("Non-static method invokers should have at least one parameter. Method " + method.getName() + " in " + interfaceClass.getTypeName() + " has " + method.getParameterCount());
}

this.methods.put(
method,
adapt(Util.sneakyThrows(() -> LOOKUP.unreflect(this.findProxiedMethod(method, classMapper, methodMapper))))
adapt(Util.sneakyThrows(() -> LOOKUP.unreflect(this.findProxiedMethod(proxiedClass, method, classMapper, methodMapper))))
);
}
}
Expand Down Expand Up @@ -249,47 +267,50 @@ private static boolean isEqualsMethod(final Method method) {
&& method.getReturnType() == boolean.class;
}

private Field findProxiedField(
private static Field findProxiedField(
final Class<?> proxiedClass,
final String fieldName,
final UnaryOperator<String> fieldMapper
) {
final Field field;
try {
field = this.proxiedClass.getDeclaredField(fieldMapper.apply(fieldName));
field = proxiedClass.getDeclaredField(fieldMapper.apply(fieldName));
} catch (final NoSuchFieldException e) {
throw new IllegalArgumentException("Could not find field '" + fieldName + "' in " + this.proxiedClass.getTypeName(), e);
throw new IllegalArgumentException("Could not find field '" + fieldName + "' in " + proxiedClass.getTypeName(), e);
}
try {
field.setAccessible(true);
} catch (final Exception ex) {
throw new IllegalStateException("Could not set access for field '" + fieldName + "' in " + this.proxiedClass.getTypeName(), ex);
throw new IllegalStateException("Could not set access for field '" + fieldName + "' in " + proxiedClass.getTypeName(), ex);
}
return field;
}

private Constructor<?> findProxiedConstructor(
final Class<?> proxiedClass,
final Method method,
final UnaryOperator<String> classMapper
) {
final Class<?>[] actualParams = Arrays.stream(method.getParameters())
.map(p -> this.resolveParameterTypeClass(p, classMapper))
.map(p -> resolveParameterTypeClass(p, classMapper))
.toArray(Class<?>[]::new);

final Constructor<?> constructor;
try {
constructor = this.proxiedClass.getDeclaredConstructor(actualParams);
constructor = proxiedClass.getDeclaredConstructor(actualParams);
} catch (final NoSuchMethodException ex) {
throw new IllegalArgumentException("Could not find constructor of " + this.proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex);
throw new IllegalArgumentException("Could not find constructor of " + proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex);
}
try {
constructor.setAccessible(true);
} catch (final Exception ex) {
throw new IllegalStateException("Could not set access for proxy method target constructor of " + this.proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex);
throw new IllegalStateException("Could not set access for proxy method target constructor of " + proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex);
}
return constructor;
}

private Method findProxiedMethod(
final Class<?> proxiedClass,
final Method method,
final UnaryOperator<String> classMapper,
final BiFunction<String, Class<?>[], String> methodMapper
Expand All @@ -299,33 +320,33 @@ private Method findProxiedMethod(
final Class<?>[] actualParams;
if (hasStaticAnnotation) {
actualParams = Arrays.stream(method.getParameters())
.map(p -> this.resolveParameterTypeClass(p, classMapper))
.map(p -> resolveParameterTypeClass(p, classMapper))
.toArray(Class<?>[]::new);
} else {
actualParams = Arrays.stream(method.getParameters())
.skip(1)
.map(p -> this.resolveParameterTypeClass(p, classMapper))
.map(p -> resolveParameterTypeClass(p, classMapper))
.toArray(Class<?>[]::new);
}

final @Nullable MethodName methodAnnotation = method.getDeclaredAnnotation(MethodName.class);
final String methodName = methodAnnotation == null ? method.getName() : methodAnnotation.value();
final Method proxiedMethod;
try {
proxiedMethod = this.proxiedClass.getDeclaredMethod(methodMapper.apply(methodName, actualParams), actualParams);
proxiedMethod = proxiedClass.getDeclaredMethod(methodMapper.apply(methodName, actualParams), actualParams);
} catch (final NoSuchMethodException e) {
throw new IllegalArgumentException("Could not find proxy method target method: " + this.proxiedClass.getTypeName() + " " + methodName);
throw new IllegalArgumentException("Could not find proxy method target method: " + proxiedClass.getTypeName() + " " + methodName);
}
try {
proxiedMethod.setAccessible(true);
} catch (final Exception ex) {
throw new IllegalStateException("Could not set access for proxy method target method: " + this.proxiedClass.getTypeName() + " " + methodName, ex);
throw new IllegalStateException("Could not set access for proxy method target method: " + proxiedClass.getTypeName() + " " + methodName, ex);
}

return proxiedMethod;
}

private Class<?> resolveParameterTypeClass(
private static Class<?> resolveParameterTypeClass(
final Parameter parameter,
final UnaryOperator<String> classMapper
) {
Expand Down
Loading

0 comments on commit bfcc5d2

Please sign in to comment.