Rewrite upsample with convtranspose / jinc
jun committed Jun 22, 2022
1 parent 3715b52 commit f1fddd5
Showing 18 changed files with 145 additions and 122 deletions.
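The rewrite relies on a standard identity: inserting ratio - 1 zeros between samples and then applying a low-pass convolution gives the same result as a single transposed convolution with the flipped kernel and complementary padding. A minimal sketch of that identity, with illustrative values only (this snippet is not code from the repository):

import torch
import torch.nn.functional as F

T, r, K = 16, 2, 9
x = torch.randn(1, 1, T)   # [B, C, T]
w = torch.rand(1, 1, K)    # stand-in kernel; the commit builds a Kaiser-windowed sinc/jinc

# Path A: transposed convolution with stride r and padding p
p = 0
y_a = F.conv_transpose1d(x, w, stride=r, padding=p)

# Path B: insert r - 1 zeros between samples, then an ordinary convolution
# with the flipped kernel and padding K - 1 - p
xx = torch.zeros(1, 1, (T - 1) * r + 1)
xx[..., ::r] = x
y_b = F.conv1d(xx, w.flip(-1), padding=K - 1 - p)

print(torch.allclose(y_a, y_b, atol=1e-6))   # True

The previous UpSample modules built the zero-stuffed tensor explicitly and ran it through LowPassFilter1d / LowPassFilter2d; the rewritten ones below fold both steps into one F.conv_transpose1d / F.conv_transpose2d call and multiply by ratio (or ratio**2) to preserve amplitude.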
4 changes: 2 additions & 2 deletions README.md
@@ -32,8 +32,8 @@ For custom torch users, `pip` will not check torch version.
- [x] divide 1d and 2d modules
- [x] pip packaging
- [ ] documentation
- [ ] rewrite upsample/downsample
- [ ] apply jinc(torch.special.i1) with updated torch
- [x] rewrite upsample
- [x] apply jinc(torch.special.i1) with updated torch

## Test results 1d
| Filter sine | Filter noise |
Binary file modified asset/down10.png
Binary file modified asset/down100.png
Binary file modified asset/downsample2d2.png
Binary file modified asset/downsample2d4.png
Binary file modified asset/up2.png
Binary file modified asset/up256.png
Binary file modified asset/up2d2.png
Binary file modified asset/up2d8.png
3 changes: 1 addition & 2 deletions src/__init__.py
@@ -1,2 +1 @@


from .alias_free_torch import *
6 changes: 3 additions & 3 deletions src/alias_free_torch/__init__.py
@@ -1,3 +1,3 @@
from .filter import LowPassFilter1d, LowPassFilter2d
from .resample import UpSample1d, UpSample2d
from .act import Activation1d, Activation2d
from .filter import *
from .resample import *
from .act import *
167 changes: 95 additions & 72 deletions src/alias_free_torch/filter.py
@@ -18,8 +18,52 @@ def sinc(x: torch.Tensor):
                       torch.sin(math.pi * x) / math.pi / x)


if 'i1' in dir(torch.special):
    i1 = torch.special.i1

    def jinc(x: torch.Tensor):
        return torch.where(
            x == 0,
            torch.tensor(0.25 / math.pi, device=x.device, dtype=x.dtype),
            1 / (2 * math.pi * x) * i1(x))
else:
    jinc = sinc

# This code is adapted from adefossez's julius.lowpass.LowPassFilters
# https://adefossez.github.io/julius/julius/lowpass.html


#return filter [1,1,kernel_size]
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2

    #For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
    #ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
    if even:
        time = (torch.arange(-half_size, half_size) + 0.5)
    else:
        time = torch.arange(kernel_size) - half_size
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
    # Normalize filter to have sum = 1, otherwise we will have a small leakage
    # of the constant component in the input signal.
    filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size)
    return filter


class LowPassFilter1d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
@@ -34,36 +78,13 @@ def __init__(self,
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.stride = stride
self.pad = pad
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.half_size = kernel_size // 2
self.stride = stride
self.pad = pad

        #For kaiser window
        delta_f = 4 * half_width
        A = 2.285 * (self.half_size - 1) * math.pi * delta_f + 7.95
        if A > 50.:
            beta = 0.1102 * (A - 8.7)
        elif A >= 21.:
            beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
        else:
            beta = 0.
        window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
        #ratio = 0.5/cutoff
        if self.even:
            time = (torch.arange(-self.half_size, self.half_size) + 0.5)
        else:
            time = torch.arange(self.kernel_size) - self.half_size
        if cutoff == 0:
            filter_ = torch.zeros_like(time)
        else:
            filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, self.kernel_size)
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    #input [B,T] or [B,C,T]
@@ -76,13 +97,55 @@ def forward(self, x):
                      mode='constant',
                      value=0) # empirically, it is better than replicate
                      #mode='replicate')

        out = F.conv1d(x, self.filter, stride=self.stride)
        if self.even:
            out = F.conv1d(x, self.filter, stride=self.stride)[..., :-1]
        else:
            out = F.conv1d(x, self.filter, stride=self.stride)
        out = out[..., :-1]
        return out.reshape(new_shape)


def kaiser_jinc_filter2d(cutoff, half_width, kernel_size):
    even = (kernel_size % 2 == 0)
    half_size = kernel_size // 2
    #For kaiser window
    delta_f = 4 * half_width
    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if A > 50.:
        beta = 0.1102 * (A - 8.7)
    elif A >= 21.:
        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
    else:
        beta = 0.

    #rotation equivariant grid
    if even:
        time = torch.stack(torch.meshgrid(
            torch.arange(-half_size, half_size) + 0.5,
            torch.arange(-half_size, half_size) + 0.5),
            dim=-1)
    else:
        time = torch.stack(torch.meshgrid(
            torch.arange(kernel_size) - half_size,
            torch.arange(kernel_size) - half_size),
            dim=-1)

    time = torch.norm(time, dim=-1)
    #rotation equivariant window
    window = torch.i0(
        beta * torch.sqrt(1 - (time / half_size / 2**0.5)**2)) / torch.i0(
            torch.tensor([beta]))
    #ratio = 0.5/cutoff
    if cutoff == 0:
        filter_ = torch.zeros_like(time)
    else:
        filter_ = 2 * cutoff * window * jinc(2 * cutoff * time)
    # Normalize filter to have sum = 1, otherwise we will have a small leakage
    # of the constant component in the input signal.
    filter_ /= filter_.sum()
    filter = filter_.view(1, 1, kernel_size, kernel_size)
    return filter


class LowPassFilter2d(nn.Module):
    def __init__(self,
                 cutoff=0.5,
@@ -96,51 +159,13 @@ def __init__(self,
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.stride = stride
self.pad = pad
self.kernel_size = kernel_size
self.even = (kernel_size % 2 == 0)
self.half_size = kernel_size // 2
self.stride = stride
self.pad = pad

        #For kaiser window
        delta_f = 4 * half_width
        A = 2.285 * (self.half_size - 1) * math.pi * delta_f + 7.95
        if A > 50.:
            beta = 0.1102 * (A - 8.7)
        elif A >= 21.:
            beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
        else:
            beta = 0.

        #rotation equivariant grid
        if self.even:
            time = torch.stack(torch.meshgrid(
                torch.arange(-self.half_size, self.half_size) + 0.5,
                torch.arange(-self.half_size, self.half_size) + 0.5),
                dim=-1)
        else:
            time = torch.stack(torch.meshgrid(
                torch.arange(self.kernel_size) - self.half_size,
                torch.arange(self.kernel_size) - self.half_size),
                dim=-1)

        time = torch.norm(time, dim=-1)
        #rotation equivariant window
        window = torch.i0(
            beta * torch.sqrt(1 -
                              (time / self.half_size / 2**0.5)**2)) / torch.i0(
                torch.tensor([beta]))
        #ratio = 0.5/cutoff
        #using sinc instead of jinc
        if cutoff == 0:
            filter_ = torch.zeros_like(time)
        else:
            filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
        # Normalize filter to have sum = 1, otherwise we will have a small leakage
        # of the constant component in the input signal.
        filter_ /= filter_.sum()
        filter = filter_.view(1, 1, self.kernel_size, self.kernel_size)
        filter = kaiser_jinc_filter2d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter)

    #input [B,C,W,H] or [B,W,H] or [W,H]
@@ -154,10 +179,8 @@ def forward(self, x):
                      mode='constant',
                      value=0) # empirically, it is better than replicate or reflect
                      #mode='replicate')
        out = F.conv2d(x, self.filter, stride=self.stride)
        if self.even:
            out = F.conv2d(x, self.filter, stride=self.stride)[..., :-1, :-1]
        else:
            out = F.conv2d(x, self.filter, stride=self.stride)

        out = out[..., :-1, :-1]
        new_shape = shape[:-2] + list(out.shape)[-2:]
        return out.reshape(new_shape)
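The Kaiser-windowed kernels are now built by the standalone helpers kaiser_sinc_filter1d and kaiser_jinc_filter2d, which the resample modules below reuse directly. A short usage sketch, importing the package the way the test files in this commit do (the cutoff, half_width and kernel_size values are arbitrary, not taken from the repository):

import torch
from src.alias_free_torch.filter import (jinc, kaiser_sinc_filter1d,
                                          kaiser_jinc_filter2d)

k1 = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(k1.shape, float(k1.sum()))   # torch.Size([1, 1, 12]), ~1.0 after normalization

k2 = kaiser_jinc_filter2d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(k2.shape, float(k2.sum()))   # torch.Size([1, 1, 12, 12]), ~1.0

# The 2d kernel is built on a radially symmetric grid, so it is unchanged
# by a 90-degree rotation.
print(torch.allclose(k2, torch.rot90(k2, 1, dims=[2, 3])))   # True

# When torch.special.i1 is available, jinc(x) approaches 0.25 / pi as x -> 0,
# matching the value hard-coded for x == 0 (otherwise jinc falls back to sinc).
print(jinc(torch.tensor([0.0, 1e-4])))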
51 changes: 27 additions & 24 deletions src/alias_free_torch/resample.py
@@ -2,23 +2,26 @@
import torch.nn as nn
import torch.nn.functional as F
from .filter import LowPassFilter1d, LowPassFilter2d
from .filter import kaiser_sinc_filter1d, kaiser_jinc_filter2d


class UpSample1d(nn.Module):
    def __init__(self, ratio=2):
    def __init__(self, ratio=2, even=True):
        super().__init__()
        self.ratio = ratio
        self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       kernel_size=int(6 * ratio // 2) * 2)
        self.even = even
        kernel_size = int(6 * ratio // 2) * 2 + int(not (even))
        self.stride = ratio
        self.pad = kernel_size // 2 - ratio // 2
        self.filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
                                           half_width=0.6 / ratio,
                                           kernel_size=kernel_size)

    def forward(self, x):
        shape = list(x.shape)
        new_shape = shape[:-1] + [shape[-1] * self.ratio]
        xx = torch.zeros(new_shape, device=x.device)
        xx[..., ::self.ratio] = x
        xx = self.ratio * xx
        x = self.lowpass(xx.view(new_shape))
        x = self.ratio * F.conv_transpose1d(
            x, self.filter, stride=self.stride, padding=self.pad)
        if not self.even:
            x = x[..., :-1]
        return x


@@ -37,23 +40,24 @@ def forward(self, x):


class UpSample2d(nn.Module):
    def __init__(self, ratio=2):
    def __init__(self, ratio=2, even=True):
        super().__init__()
        self.ratio = ratio
        self.lowpass = LowPassFilter2d(cutoff=0.5 / ratio,
                                       half_width=0.6 / ratio,
                                       kernel_size=int(6 * ratio // 2) * 2)
        self.even = even
        kernel_size = int(6 * ratio // 2) * 2 + int(not (even))
        self.stride = ratio
        self.pad = kernel_size // 2 - ratio // 2
        self.filter = kaiser_jinc_filter2d(cutoff=0.5 / ratio,
                                           half_width=0.6 / ratio,
                                           kernel_size=kernel_size)

    def forward(self, x):
        shape = list(x.shape)
        new_shape = shape[:-2] + [shape[-2] * self.ratio
                                  ] + [shape[-1] * self.ratio]

        xx = torch.zeros(new_shape, device=x.device)
        #shape + [self.ratio**2], device=x.device)
        xx[..., ::self.ratio, ::self.ratio] = x
        xx = self.ratio**2 * xx
        x = self.lowpass(xx)
        print(x.shape)
        x = self.ratio**2 * F.conv_transpose2d(
            x, self.filter, stride=self.stride, padding=self.pad)
        if not self.even:
            x = x[..., :-1, :-1]
        print(x.shape)
        return x


@@ -69,4 +73,3 @@ def __init__(self, ratio=2):
    def forward(self, x):
        xx = self.lowpass(x)
        return xx
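With the conv_transpose rewrite, UpSample1d and UpSample2d operate on single-channel inputs of shape [B, 1, T] and [B, 1, H, W], since the filter built in __init__ has a single input channel. A brief shape check, assuming the default even=True kernel (illustrative only, not part of the commit); note that UpSample2d.forward as committed still calls print(x.shape), so it prints shapes on its own:

import torch
from src.alias_free_torch.resample import UpSample1d, UpSample2d

up1 = UpSample1d(ratio=2)        # kernel_size = 12, stride = 2, pad = 5
x = torch.randn(1, 1, 100)
print(up1(x).shape)              # torch.Size([1, 1, 200])

up2 = UpSample2d(ratio=2)
img = torch.randn(1, 1, 64, 64)
print(up2(img).shape)            # torch.Size([1, 1, 128, 128])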

8 changes: 4 additions & 4 deletions test/act2d_test.py
@@ -1,13 +1,13 @@
import torch
import matplotlib.pyplot as plt
from alias_free_torch.act import Activation2d
from alias_free_torch.resample import UpSample2d, DownSample2d
from alias_free_torch.filter import LowPassFilter2d
from src.alias_free_torch.act import Activation2d
from src.alias_free_torch.resample import UpSample2d, DownSample2d
from src.alias_free_torch.filter import LowPassFilter2d
import math
continuous_ratio = 16
ratio = 2
size = 256
center = [0, 1.5 + step / 5.]
center = [0, 1.5 + 10 / 5.]
t = (torch.stack(torch.meshgrid(
    (torch.arange(-size, size) - center[0]) / size,
    (torch.arange(-size, size) - center[1]) / size),
2 changes: 1 addition & 1 deletion test/down1d_test.py
@@ -1,6 +1,6 @@
import torch
import matplotlib.pyplot as plt
from alias_free_torch.resample import DownSample1d
from src.alias_free_torch.resample import DownSample1d

ratio = 10
t = torch.arange(100) / 100. * 3.141592
4 changes: 2 additions & 2 deletions test/down2d_test.py
@@ -1,8 +1,8 @@
import torch
import matplotlib.pyplot as plt
from alias_free_torch.resample import DownSample2d
from src.alias_free_torch.resample import DownSample2d

ratio = 4
ratio = 2
size = 80
t = (torch.stack(torch.meshgrid(
    torch.arange(-size, size) + 0.5,
9 changes: 4 additions & 5 deletions test/up1d_test.py
@@ -1,9 +1,9 @@
import torch
import matplotlib.pyplot as plt
from alias_free_torch.resample import UpSample1d
from alias_free_torch.filter import LowPassFilter1d
from src.alias_free_torch.resample import UpSample1d
from src.alias_free_torch.filter import LowPassFilter1d

ratio = 2
ratio = 256
t = torch.arange(100) / 10. * 3.141592
tt = torch.arange(100 * ratio) / (10. * ratio) * 3.141592
#low = LowPassFilter1d(cutoff = 0.5/ratio/ratio,
@@ -12,8 +12,7 @@
orig_sin = torch.sin(t) + torch.sin(t * 2)
real_up_sin = torch.sin(tt) + torch.sin(tt * 2)
upsample = UpSample1d(ratio)
print(upsample.lowpass.filter)
up_sin = (upsample(orig_sin))
up_sin = (upsample(orig_sin.view(1,1,100))).view(100*ratio)
#up_sin = low(upsample(orig_sin))

plt.figure(figsize=(7, 5))