Skip to content

Commit

Permalink
pytorch: Add conv2d wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
SamKG committed Feb 28, 2021
1 parent c7d5dc5 commit 32167e2
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
17 changes: 16 additions & 1 deletion psyneulink/library/compositions/pytorchcomponents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from psyneulink.core import llvm as pnlvm
from psyneulink.core.globals.log import LogCondition
from psyneulink.core.components.functions.transferfunctions import Linear, Logistic, ReLU
from psyneulink.core.components.functions.transferfunctions import Linear, Logistic, ReLU, Conv2d
from psyneulink.library.compositions.pytorchllvmhelper import *

__all__ = ['PytorchMechanismWrapper', 'PytorchProjectionWrapper']
Expand Down Expand Up @@ -103,6 +103,21 @@ def get_fct_param_value(param_name):
wrapper = PytorchFunctionWrapper(func, device=device, context=context)
return wrapper

elif isinstance(function, Conv2d):
kernel = get_fct_param_value('kernel')
kernel = torch.nn.Parameter(torch.tensor(np.reshape(kernel, (1, 1, *kernel.shape)), device=device, dtype=torch.double), requires_grad=True)

stride = get_fct_param_value('stride')
padding = get_fct_param_value('padding')
dilation = get_fct_param_value('dilation')

conv2d = torch.nn.functional.conv2d
def func(x):
x = torch.reshape(x, (1, 1, *x.shape))
return conv2d(x, weight=kernel, stride=stride, padding=padding, dilation=dilation)[0][0]

wrapper = PytorchFunctionWrapper(func, learnable_params = [kernel], device=device, context=context)
return wrapper
else:
raise Exception(f"Function {function} is not currently supported in AutodiffCompositions!")

Expand Down
25 changes: 24 additions & 1 deletion tests/composition/test_autodiffcomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@

import psyneulink as pnl

from psyneulink.core.components.functions.transferfunctions import Logistic
from psyneulink.core.components.functions.transferfunctions import Logistic, Conv2d
from psyneulink.core.components.functions.learningfunctions import BackPropagation
from psyneulink.core.compositions.composition import Composition
from psyneulink.core.globals import Context
from psyneulink.core.globals.keywords import TRAINING_SET
from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition

Expand Down Expand Up @@ -1219,6 +1220,28 @@ def test_pytorch_equivalence_with_autodiff_training_disabled_on_proj(self):

assert np.allclose(output, comparator)

def test_conv2d_pytorch_equivalence_with_autodiff_composition(self):
variable, kernel, stride, padding, dilation, target, comparator = (np.ones((2, 2)), np.ones((2, 2)), (1,1), (0,0), (1,1), np.ones((1,1)), [[0.98340034484863281250]])

il = ProcessingMechanism(name='input', function=Conv2d(default_variable=variable, kernel=kernel, stride=stride, padding=padding, dilation=dilation), default_variable=variable)
comp = AutodiffComposition(optimizer_type='adam', learning_rate=1)
comp.add_node(il)

input_set = {
'inputs': {
il: [variable]
},
'targets': {
il: [target]
}
}

results = comp.learn(
inputs=input_set,
epochs=100
)

assert np.allclose(comparator, results[-1][-1])

@pytest.mark.pytorch
@pytest.mark.actime
Expand Down

0 comments on commit 32167e2

Please sign in to comment.