Skip to content

Commit

Permalink
update roll (#1228)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zheng-Bicheng authored Apr 13, 2024
1 parent 83ecf73 commit a0b7c04
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 3 deletions.
14 changes: 12 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

*.pyc
.pydevproject
build/*
.eggs/*
dist/*
.setuptools*
Expand All @@ -13,4 +12,15 @@ paddle2onnx.egg-info/*
*_*.onnx
*.log
version.py
paddle2onnx/mappers_registry.h
paddle2onnx/mappers_registry.h

# CMD
build/*
paddle2onnx-*

# Clion
cmake-build-*
.idea

# VSCode
.vscode
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.2.0
1.2.1
63 changes: 63 additions & 0 deletions paddle2onnx/mapper/tensor/roll.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>
#include "paddle2onnx/mapper/tensor/roll.h"

namespace paddle2onnx {
REGISTER_MAPPER(roll, RollMapper)

void RollMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");

std::vector<int64_t> shifts;
GetAttr("shifts", &shifts);

std::vector<int64_t> axis;
GetAttr("axis", &axis);

std::shared_ptr<ONNX_NAMESPACE::NodeProto> temp_node= nullptr;
auto result_name = input_info[0].name;
if (axis.empty())
{
int64_t axes = 0;
result_name = helper_->Flatten(result_name);
for(int i = 0;i < shifts.size();i++) {
auto shift = shifts[i];
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
AddAttribute(temp_node, "axis", axes);
result_name = temp_node->output(0);
}
helper_->Reshape(result_name, output_info[0].name, input_info[0].shape);
// helper_->MakeNode("Reshape", {result_name, input_info[0].shape}, {output_info[0].name});
} else {
for(int i = 0;i < shifts.size();i++) {
auto shift = shifts[i];
int64_t axes = axis[i];
auto result_0 = helper_->Slice(result_name, {axes}, {-shift}, {(std::numeric_limits<int64_t>::max)()});
auto result_1 = helper_->Slice(result_name, {axes}, {0}, {-shift});
if(i+1 == shifts.size()) {
temp_node = helper_->MakeNode("Concat", {result_0, result_1}, {output_info[0].name});
} else {
temp_node = helper_->MakeNode("Concat", {result_0, result_1});
}
AddAttribute(temp_node, "axis", axes);
result_name = temp_node->output(0);
}
}
}
} // namespace paddle2onnx
31 changes: 31 additions & 0 deletions paddle2onnx/mapper/tensor/roll.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <vector>

#include "paddle2onnx/mapper/mapper.h"

namespace paddle2onnx {

class RollMapper : public Mapper {
public:
RollMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
int64_t op_id)
: Mapper(p, helper, block_id, op_id) {}
void Opset7();
};

} // namespace paddle2onnx
54 changes: 54 additions & 0 deletions tests/test_roll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from onnxbase import APIOnnx
from onnxbase import randtool


class Net(paddle.nn.Layer):
"""
simple Net
"""

def __init__(self):
super(Net, self).__init__()

def forward(self, inputs):
"""
forward
"""
x = paddle.roll(inputs, 1)
return x


def test_roll():
"""
api: paddle.roll
op version: 9
"""
op = Net()
op.eval()
# net, name, ver_list, delta=1e-6, rtol=1e-5
obj = APIOnnx(op, 'roll', [9])
input_data = paddle.to_tensor(randtool("float", -1, 1, [2,2]).astype('float32'))
print(input_data)
obj.set_input_data(
"input_data",
input_data
)
obj.run()

if __name__ == "__main__":
test_roll()

0 comments on commit a0b7c04

Please sign in to comment.