From bfcc5d2255c9599765d9a6c758a705582ce72147 Mon Sep 17 00:00:00 2001 From: Jason Penilla <11360596+jpenilla@users.noreply.github.com> Date: Tue, 14 Nov 2023 19:42:17 -0700 Subject: [PATCH] Add proxy inheritance (#9) --- .../internal/util/Util.java | 25 ++++ .../proxy/ReflectionProxyFactory.java | 2 - .../ReflectionProxyInvocationHandler.java | 89 ++++++++----- .../ReflectionProxyInheritanceTest.java | 124 ++++++++++++++++++ ...pperTest.java => ReflectionProxyTest.java} | 8 +- 5 files changed, 208 insertions(+), 40 deletions(-) create mode 100644 src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyInheritanceTest.java rename src/test/java/xyz/jpenilla/reflectionremapper/{ReflectionRemapperTest.java => ReflectionProxyTest.java} (97%) diff --git a/src/main/java/xyz/jpenilla/reflectionremapper/internal/util/Util.java b/src/main/java/xyz/jpenilla/reflectionremapper/internal/util/Util.java index 138e885..439d642 100644 --- a/src/main/java/xyz/jpenilla/reflectionremapper/internal/util/Util.java +++ b/src/main/java/xyz/jpenilla/reflectionremapper/internal/util/Util.java @@ -22,6 +22,12 @@ import java.lang.invoke.MethodType; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; import java.util.function.UnaryOperator; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; @@ -130,6 +136,25 @@ public static MethodHandle handleForDefaultMethod( ); } + public static List> topDownInterfaceHierarchy(final Class cls) { + if (!cls.isInterface()) { + throw new IllegalStateException("Expected an interface, got " + cls); + } + final Set> set = new LinkedHashSet<>(); + set.add(cls); + interfaces(cls, set); + final List> list = new ArrayList<>(set); + Collections.reverse(list); + return Collections.unmodifiableList(list); + } + + private static void interfaces(final Class cls, final Collection> list) { + for (final Class iface : cls.getInterfaces()) { + list.add(iface); + interfaces(iface, list); + } + } + public static String descriptorString(final Class clazz) { if (DESCRIPTOR_STRING != null) { // jdk 12+ diff --git a/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyFactory.java b/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyFactory.java index a39b2d9..df3b17b 100644 --- a/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyFactory.java +++ b/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyFactory.java @@ -21,7 +21,6 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.framework.qual.DefaultQualifier; import xyz.jpenilla.reflectionremapper.ReflectionRemapper; -import xyz.jpenilla.reflectionremapper.internal.util.Util; import xyz.jpenilla.reflectionremapper.proxy.annotation.Proxies; /** @@ -53,7 +52,6 @@ public I reflectionProxy(final Class proxyInterface) { new Class[]{proxyInterface}, new ReflectionProxyInvocationHandler<>( proxyInterface, - Util.findProxiedClass(proxyInterface, this.reflectionRemapper::remapClassName), this.reflectionRemapper ) ); diff --git a/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyInvocationHandler.java b/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyInvocationHandler.java index 703aa05..f49d0a1 100644 --- a/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyInvocationHandler.java +++ b/src/main/java/xyz/jpenilla/reflectionremapper/proxy/ReflectionProxyInvocationHandler.java @@ -47,7 +47,6 @@ final class ReflectionProxyInvocationHandler implements InvocationHandler { private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); private static final Object[] EMPTY_OBJECT_ARRAY = new Object[]{}; private final Class interfaceClass; - private final Class proxiedClass; private final Map methods = new HashMap<>(); private final Map getters = new HashMap<>(); private final Map setters = new HashMap<>(); @@ -57,11 +56,9 @@ final class ReflectionProxyInvocationHandler implements InvocationHandler { ReflectionProxyInvocationHandler( final Class interfaceClass, - final Class proxiedClass, final ReflectionRemapper reflectionRemapper ) { this.interfaceClass = interfaceClass; - this.proxiedClass = proxiedClass; this.scanInterface(reflectionRemapper); } @@ -76,7 +73,7 @@ final class ReflectionProxyInvocationHandler implements InvocationHandler { } else if (isHashCodeMethod(method)) { return 0; } else if (isToStringMethod(method)) { - return String.format("ReflectionProxy[interface=%s, implementation=%s, proxies=%s]", this.interfaceClass.getTypeName(), proxy.getClass().getTypeName(), this.proxiedClass.getTypeName()); + return String.format("ReflectionProxy[interface=%s, implementation=%s]", this.interfaceClass.getTypeName(), proxy.getClass().getTypeName()); } if (args == null) { @@ -142,19 +139,40 @@ final class ReflectionProxyInvocationHandler implements InvocationHandler { } private void scanInterface(final ReflectionRemapper reflectionRemapper) { - this.scanInterface( - reflectionRemapper::remapClassOrArrayName, - fieldName -> reflectionRemapper.remapFieldName(this.proxiedClass, fieldName), - (methodName, parameters) -> reflectionRemapper.remapMethodName(this.proxiedClass, methodName, parameters) - ); + Class prevProxy = null; + Class prevProxied = null; + + for (final Class cls : Util.topDownInterfaceHierarchy(this.interfaceClass)) { + final Class proxied = Util.findProxiedClass(cls, reflectionRemapper::remapClassName); + + if (prevProxied != null && !prevProxied.isAssignableFrom(proxied)) { + throw new IllegalArgumentException( + "Reflection proxy interface " + cls.getName() + " proxies " + proxied.getName() + ", and extends from reflection proxy interface " + + prevProxy.getName() + " which proxies " + prevProxied.getName() + ", but the proxied types are not compatible." + ); + } + + this.scanInterface( + cls, + proxied, + reflectionRemapper::remapClassOrArrayName, + fieldName -> reflectionRemapper.remapFieldName(proxied, fieldName), + (methodName, parameters) -> reflectionRemapper.remapMethodName(proxied, methodName, parameters) + ); + + prevProxied = proxied; + prevProxy = cls; + } } private void scanInterface( + final Class interfaceClass, + final Class proxiedClass, final UnaryOperator classMapper, final UnaryOperator fieldMapper, final BiFunction[], String> methodMapper ) { - for (final Method method : this.interfaceClass.getDeclaredMethods()) { + for (final Method method : interfaceClass.getDeclaredMethods()) { if (isEqualsMethod(method) || isHashCodeMethod(method) || isToStringMethod(method) || Util.isSynthetic(method.getModifiers())) { continue; } else if (method.isDefault()) { @@ -167,7 +185,7 @@ private void scanInterface( if (constructorInvoker) { this.methods.put( method, - adapt(Util.sneakyThrows(() -> LOOKUP.unreflectConstructor(this.findProxiedConstructor(method, classMapper)))) + adapt(Util.sneakyThrows(() -> LOOKUP.unreflectConstructor(this.findProxiedConstructor(proxiedClass, method, classMapper)))) ); continue; } @@ -175,42 +193,42 @@ private void scanInterface( final @Nullable FieldGetter getterAnnotation = method.getDeclaredAnnotation(FieldGetter.class); final @Nullable FieldSetter setterAnnotation = method.getDeclaredAnnotation(FieldSetter.class); if (getterAnnotation != null && setterAnnotation != null) { - throw new IllegalArgumentException("Method " + method.getName() + " in " + this.interfaceClass.getTypeName() + " is annotated with @FieldGetter and @FieldSetter, don't know which to use."); + throw new IllegalArgumentException("Method " + method.getName() + " in " + interfaceClass.getTypeName() + " is annotated with @FieldGetter and @FieldSetter, don't know which to use."); } final boolean hasStaticAnnotation = method.getDeclaredAnnotation(Static.class) != null; if (getterAnnotation != null) { - final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectGetter(this.findProxiedField(getterAnnotation.value(), fieldMapper))); + final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectGetter(findProxiedField(proxiedClass, getterAnnotation.value(), fieldMapper))); if (hasStaticAnnotation) { - checkParameterCount(method, this.interfaceClass, 0, "Static @FieldGetters should have no parameters."); + checkParameterCount(method, interfaceClass, 0, "Static @FieldGetters should have no parameters."); this.staticGetters.put(method, handle.asType(MethodType.methodType(Object.class))); } else { - checkParameterCount(method, this.interfaceClass, 1, "Non-static @FieldGetters should have one parameter."); + checkParameterCount(method, interfaceClass, 1, "Non-static @FieldGetters should have one parameter."); this.getters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class))); } continue; } if (setterAnnotation != null) { - final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectSetter(this.findProxiedField(setterAnnotation.value(), fieldMapper))); + final MethodHandle handle = Util.sneakyThrows(() -> LOOKUP.unreflectSetter(findProxiedField(proxiedClass, setterAnnotation.value(), fieldMapper))); if (hasStaticAnnotation) { - checkParameterCount(method, this.interfaceClass, 1, "Static @FieldSetters should have one parameter."); + checkParameterCount(method, interfaceClass, 1, "Static @FieldSetters should have one parameter."); this.staticSetters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class))); } else { - checkParameterCount(method, this.interfaceClass, 2, "Non-static @FieldSetters should have two parameters."); + checkParameterCount(method, interfaceClass, 2, "Non-static @FieldSetters should have two parameters."); this.setters.put(method, handle.asType(MethodType.methodType(Object.class, Object.class, Object.class))); } continue; } if (!hasStaticAnnotation && method.getParameterCount() < 1) { - throw new IllegalArgumentException("Non-static method invokers should have at least one parameter. Method " + method.getName() + " in " + this.interfaceClass.getTypeName() + " has " + method.getParameterCount()); + throw new IllegalArgumentException("Non-static method invokers should have at least one parameter. Method " + method.getName() + " in " + interfaceClass.getTypeName() + " has " + method.getParameterCount()); } this.methods.put( method, - adapt(Util.sneakyThrows(() -> LOOKUP.unreflect(this.findProxiedMethod(method, classMapper, methodMapper)))) + adapt(Util.sneakyThrows(() -> LOOKUP.unreflect(this.findProxiedMethod(proxiedClass, method, classMapper, methodMapper)))) ); } } @@ -249,47 +267,50 @@ private static boolean isEqualsMethod(final Method method) { && method.getReturnType() == boolean.class; } - private Field findProxiedField( + private static Field findProxiedField( + final Class proxiedClass, final String fieldName, final UnaryOperator fieldMapper ) { final Field field; try { - field = this.proxiedClass.getDeclaredField(fieldMapper.apply(fieldName)); + field = proxiedClass.getDeclaredField(fieldMapper.apply(fieldName)); } catch (final NoSuchFieldException e) { - throw new IllegalArgumentException("Could not find field '" + fieldName + "' in " + this.proxiedClass.getTypeName(), e); + throw new IllegalArgumentException("Could not find field '" + fieldName + "' in " + proxiedClass.getTypeName(), e); } try { field.setAccessible(true); } catch (final Exception ex) { - throw new IllegalStateException("Could not set access for field '" + fieldName + "' in " + this.proxiedClass.getTypeName(), ex); + throw new IllegalStateException("Could not set access for field '" + fieldName + "' in " + proxiedClass.getTypeName(), ex); } return field; } private Constructor findProxiedConstructor( + final Class proxiedClass, final Method method, final UnaryOperator classMapper ) { final Class[] actualParams = Arrays.stream(method.getParameters()) - .map(p -> this.resolveParameterTypeClass(p, classMapper)) + .map(p -> resolveParameterTypeClass(p, classMapper)) .toArray(Class[]::new); final Constructor constructor; try { - constructor = this.proxiedClass.getDeclaredConstructor(actualParams); + constructor = proxiedClass.getDeclaredConstructor(actualParams); } catch (final NoSuchMethodException ex) { - throw new IllegalArgumentException("Could not find constructor of " + this.proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex); + throw new IllegalArgumentException("Could not find constructor of " + proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex); } try { constructor.setAccessible(true); } catch (final Exception ex) { - throw new IllegalStateException("Could not set access for proxy method target constructor of " + this.proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex); + throw new IllegalStateException("Could not set access for proxy method target constructor of " + proxiedClass.getTypeName() + " with parameter types " + Arrays.toString(method.getParameterTypes()), ex); } return constructor; } private Method findProxiedMethod( + final Class proxiedClass, final Method method, final UnaryOperator classMapper, final BiFunction[], String> methodMapper @@ -299,12 +320,12 @@ private Method findProxiedMethod( final Class[] actualParams; if (hasStaticAnnotation) { actualParams = Arrays.stream(method.getParameters()) - .map(p -> this.resolveParameterTypeClass(p, classMapper)) + .map(p -> resolveParameterTypeClass(p, classMapper)) .toArray(Class[]::new); } else { actualParams = Arrays.stream(method.getParameters()) .skip(1) - .map(p -> this.resolveParameterTypeClass(p, classMapper)) + .map(p -> resolveParameterTypeClass(p, classMapper)) .toArray(Class[]::new); } @@ -312,20 +333,20 @@ private Method findProxiedMethod( final String methodName = methodAnnotation == null ? method.getName() : methodAnnotation.value(); final Method proxiedMethod; try { - proxiedMethod = this.proxiedClass.getDeclaredMethod(methodMapper.apply(methodName, actualParams), actualParams); + proxiedMethod = proxiedClass.getDeclaredMethod(methodMapper.apply(methodName, actualParams), actualParams); } catch (final NoSuchMethodException e) { - throw new IllegalArgumentException("Could not find proxy method target method: " + this.proxiedClass.getTypeName() + " " + methodName); + throw new IllegalArgumentException("Could not find proxy method target method: " + proxiedClass.getTypeName() + " " + methodName); } try { proxiedMethod.setAccessible(true); } catch (final Exception ex) { - throw new IllegalStateException("Could not set access for proxy method target method: " + this.proxiedClass.getTypeName() + " " + methodName, ex); + throw new IllegalStateException("Could not set access for proxy method target method: " + proxiedClass.getTypeName() + " " + methodName, ex); } return proxiedMethod; } - private Class resolveParameterTypeClass( + private static Class resolveParameterTypeClass( final Parameter parameter, final UnaryOperator classMapper ) { diff --git a/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyInheritanceTest.java b/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyInheritanceTest.java new file mode 100644 index 0000000..4e684b1 --- /dev/null +++ b/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyInheritanceTest.java @@ -0,0 +1,124 @@ +/* + * reflection-remapper + * + * Copyright (c) 2021-2023 Jason Penilla + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package xyz.jpenilla.reflectionremapper; + +import java.nio.file.Path; +import org.junit.jupiter.api.Test; +import xyz.jpenilla.reflectionremapper.proxy.ReflectionProxyFactory; +import xyz.jpenilla.reflectionremapper.proxy.annotation.FieldGetter; +import xyz.jpenilla.reflectionremapper.proxy.annotation.Proxies; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ReflectionProxyInheritanceTest { + private ReflectionProxyFactory factory() { + return ReflectionProxyFactory.create( + ReflectionRemapper.noop(), + this.getClass().getClassLoader() + ); + } + + @Proxies(String.class) + interface StringProxy {} + + @Proxies(Path.class) + interface InvalidPathProxy extends StringProxy {} + + @Test + void testInvalidHierarchy() { + // Path does not extend String + assertThrows(IllegalArgumentException.class, () -> this.factory().reflectionProxy(InvalidPathProxy.class)); + } + + static class Level { + final int number = 50; + + String name() { + return this.level(); + } + + String level() { + return Level.class.getName(); + } + } + + static class ServerLevel extends Level { + final int number1 = 55; + + @Override + String name() { + return this.serverLevel(); + } + + String serverLevel() { + return ServerLevel.class.getName(); + } + } + + @Proxies(Level.class) + interface LevelProxy { + String name(Level instance); + + @FieldGetter("number") + int number(Level instance); + + default String test0() { + return "LP 0"; + } + + default String test1() { + return "LP 1"; + } + } + + @Proxies(ServerLevel.class) + interface ServerLevelProxy extends LevelProxy { + @FieldGetter("number1") + int number1(ServerLevel level); + + @Override + default String test1() { + return "S" + LevelProxy.super.test1(); + } + + default String test2() { + return "SLP 2"; + } + } + + @Test + void testValidHierarchy() { + final LevelProxy levelProxy = this.factory().reflectionProxy(LevelProxy.class); + final ServerLevelProxy serverLevelProxy = this.factory().reflectionProxy(ServerLevelProxy.class); + + final ServerLevel sl = new ServerLevel(); + final Level l = new Level(); + + assertEquals(levelProxy.name(l), serverLevelProxy.name(l)); + assertEquals(levelProxy.name(sl), serverLevelProxy.name(sl)); + assertEquals(levelProxy.number(sl), 50); + assertEquals(levelProxy.number(sl), serverLevelProxy.number(sl)); + assertEquals(serverLevelProxy.number1(sl), 55); + + // test default methods on proxy interfaces + assertEquals(levelProxy.test0(), serverLevelProxy.test0()); + assertEquals("S" + levelProxy.test1(), serverLevelProxy.test1()); + assertEquals("SLP 2", serverLevelProxy.test2()); + } +} diff --git a/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionRemapperTest.java b/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyTest.java similarity index 97% rename from src/test/java/xyz/jpenilla/reflectionremapper/ReflectionRemapperTest.java rename to src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyTest.java index a102118..d999655 100644 --- a/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionRemapperTest.java +++ b/src/test/java/xyz/jpenilla/reflectionremapper/ReflectionProxyTest.java @@ -31,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -class ReflectionRemapperTest { +class ReflectionProxyTest { private ReflectionProxyFactory factory() { return ReflectionProxyFactory.create( ReflectionRemapper.noop(), @@ -89,13 +89,13 @@ void testSynthetics() { assertEquals("nothing5", proxy.get(() -> "nothing", 5).get()); } - @Proxies(className = "xyz.jpenilla.reflectionremapper.ReflectionRemapperTest$PrivateClass") + @Proxies(className = "xyz.jpenilla.reflectionremapper.ReflectionProxyTest$PrivateClass") interface PrivateClassProxy { String secret(Object instance); String useSecretClass( Object instance, - @Type(className = "xyz.jpenilla.reflectionremapper.ReflectionRemapperTest$AnotherPrivateClass") Object anotherPrivateClass + @Type(className = "xyz.jpenilla.reflectionremapper.ReflectionProxyTest$AnotherPrivateClass") Object anotherPrivateClass ); @MethodName("useSecretClass") @@ -153,7 +153,7 @@ private static int staticMethod() { } } - @Proxies(className = "xyz.jpenilla.reflectionremapper.ReflectionRemapperTest$AnotherPrivateClass") + @Proxies(className = "xyz.jpenilla.reflectionremapper.ReflectionProxyTest$AnotherPrivateClass") interface AnotherPrivateClassProxy { }