Skip to content

Commit

Permalink
Resolve pytype tests
Browse files Browse the repository at this point in the history
  • Loading branch information
insuhan committed Mar 10, 2022
1 parent 9dc3536 commit aa12545
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 42 deletions.
Empty file added experimental/__init__.py
Empty file.
63 changes: 29 additions & 34 deletions experimental/features.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
from typing import Optional, Callable
from jax import random
from jax import numpy as np
from jax.numpy.linalg import cholesky

import jax.example_libraries.stax as ostax
import neural_tangents
from neural_tangents import stax

from pkg_resources import parse_version
if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
from neural_tangents._src.utils import utils, dataclasses
from neural_tangents._src.stax.linear import _pool_kernel, Padding
from neural_tangents._src.stax.linear import _Pooling as Pooling
else:
from neural_tangents.utils import utils, dataclasses
from neural_tangents.stax import _pool_kernel, Padding, Pooling
from neural_tangents import stax
from neural_tangents._src.utils import dataclasses
from neural_tangents._src.stax.linear import _pool_kernel, Padding
from neural_tangents._src.stax.linear import _Pooling as Pooling

from sketching import TensorSRHT2, PolyTensorSRHT
from experimental.sketching import TensorSRHT2
""" Implementation for NTK Sketching and Random Features """


Expand Down Expand Up @@ -50,11 +44,11 @@ def kappa1(x):

@dataclasses.dataclass
class Features:
nngp_feat: np.ndarray
ntk_feat: np.ndarray
nngp_feat: Optional[np.ndarray] = None
ntk_feat: Optional[np.ndarray] = None

batch_axis: int = dataclasses.field(pytree_node=False)
channel_axis: int = dataclasses.field(pytree_node=False)
batch_axis: int = 0
channel_axis: int = -1

replace = ... # type: Callable[..., 'Features']

Expand All @@ -72,7 +66,7 @@ def _inputs_to_features(x: np.ndarray,
return Features(nngp_feat=nngp_feat,
ntk_feat=ntk_feat,
batch_axis=batch_axis,
channel_axis=channel_axis)
channel_axis=channel_axis) # pytype:disable=wrong-keyword-args


# Modified the serial process of feature map blocks.
Expand All @@ -95,7 +89,7 @@ def feature_fn(k, inputs, **kwargs):

def DenseFeatures(out_dim: int,
W_std: float = 1.,
b_std: float = None,
b_std: float = 1.,
parameterization: str = 'ntk',
batch_axis: int = 0,
channel_axis: int = -1):
Expand All @@ -114,7 +108,7 @@ def kernel_fn(f: Features, input, **kwargs):
nngp_feat *= W_std
ntk_feat *= W_std

if ntk_feat.ndim == 0: # check if ntk_feat is empty
if ntk_feat.ndim == 0: # check if ntk_feat is empty
ntk_feat = nngp_feat
else:
ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis)
Expand Down Expand Up @@ -153,20 +147,21 @@ def init_fn(rng, input_shape):
ts2 = TensorSRHT2(rng=rng3,
input_dim1=ntk_feat_shape[-1],
input_dim2=feature_dim0,
sketch_dim=sketch_dim).init_sketches()
sketch_dim=sketch_dim).init_sketches() # pytype:disable=wrong-keyword-args
return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2)

elif method == 'ps':
rng1, rng2, rng3 = random.split(rng, 3)
# PolySketch algorithm for arc-cosine kernel of order 0.
ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
poly_degree0)
# PolySketch algorithm for arc-cosine kernel of order 1.
ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
poly_degree1)
# TensorSRHT of degree 2 for approximating tensor product.
ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
# rng1, rng2, rng3 = random.split(rng, 3)
# # PolySketch algorithm for arc-cosine kernel of order 0.
# ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
# poly_degree0)
# # PolySketch algorithm for arc-cosine kernel of order 1.
# ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
# poly_degree1)
# # TensorSRHT of degree 2 for approximating tensor product.
# ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
# return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
raise NotImplementedError

elif method == 'exact':
# The exact feature map computation is for debug.
Expand Down Expand Up @@ -199,9 +194,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features:
kappa0_feat).reshape(input_shape + (-1,))

elif method == 'ps':
ps0: PolyTensorSRHT = input[0]
ps1: PolyTensorSRHT = input[1]
ts2: TensorSRHT2 = input[2]
# ps0: PolyTensorSRHT = input[0]
# ps1: PolyTensorSRHT = input[1]
# ts2: TensorSRHT2 = input[2]
raise NotImplementedError

elif method == 'exact': # Exact feature extraction via Cholesky decomposition.
Expand Down Expand Up @@ -258,7 +253,7 @@ def feature_fn(f, input, **kwargs):

nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std

if ntk_feat.ndim == 0: # check if ntk_feat is empty
if ntk_feat.ndim == 0: # check if ntk_feat is empty
ntk_feat = nngp_feat
else:
ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std
Expand Down
12 changes: 7 additions & 5 deletions experimental/sketching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from jax import random
from jax import numpy as np
from neural_tangents._src.utils import utils, dataclasses
from neural_tangents._src.utils.typing import Optional
from neural_tangents._src.utils import dataclasses
from typing import Optional, Callable


# TensorSRHT of degree 2. This version allows different input vectors.
Expand All @@ -20,9 +20,9 @@ class TensorSRHT2:
rand_inds1: Optional[np.ndarray] = None
rand_inds2: Optional[np.ndarray] = None

replace = ...
replace = ... # type: Callable[..., 'TensorSRHT2']

def init_sketches(self):
def init_sketches(self) -> 'TensorSRHT2':
rng1, rng2, rng3, rng4 = random.split(self.rng, 4)
rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1
rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1
Expand Down Expand Up @@ -53,7 +53,8 @@ def tensorsrht(x1, x2, rand_inds, rand_signs):
return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)


# TensorSRHT of degree p. This operates the same input vectors.
# pytype: disable=attribute-error
# TODO: Improve faster TensorSRHT.
class PolyTensorSRHT:

def __init__(self, rng, input_dim, sketch_dim, coeffs):
Expand Down Expand Up @@ -133,3 +134,4 @@ def sketch(self, x):
p = p // 2
U[j] = V[log_degree - 1][0, :, :].clone()
return U
# pytype: enable=attribute-error
4 changes: 3 additions & 1 deletion experimental/test_fc_ntk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from jax import random
from jax.config import config
from jax import jit
import sys
sys.path.append("./")

config.update("jax_enable_x64", True)
from neural_tangents import stax

from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial

seed = 1
n, d = 6, 4
Expand Down
5 changes: 3 additions & 2 deletions experimental/test_myrtle_network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''
import sys
sys.path.append("./")
import functools
from numpy.linalg import norm
from jax.config import config
Expand All @@ -12,7 +13,7 @@
from jax import random

from neural_tangents import stax
from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features

layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
width = 1
Expand Down

0 comments on commit aa12545

Please sign in to comment.