diff --git a/experimental/__init__.py b/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/experimental/features.py b/experimental/features.py
index 196236da..faeb2c2f 100644
--- a/experimental/features.py
+++ b/experimental/features.py
@@ -1,21 +1,15 @@
+from typing import Optional, Callable
 from jax import random
 from jax import numpy as np
 from jax.numpy.linalg import cholesky
-
 import jax.example_libraries.stax as ostax
-import neural_tangents
-from neural_tangents import stax
-from pkg_resources import parse_version
-if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
-  from neural_tangents._src.utils import utils, dataclasses
-  from neural_tangents._src.stax.linear import _pool_kernel, Padding
-  from neural_tangents._src.stax.linear import _Pooling as Pooling
-else:
-  from neural_tangents.utils import utils, dataclasses
-  from neural_tangents.stax import _pool_kernel, Padding, Pooling
+from neural_tangents import stax
+from neural_tangents._src.utils import dataclasses
+from neural_tangents._src.stax.linear import _pool_kernel, Padding
+from neural_tangents._src.stax.linear import _Pooling as Pooling
 
-from sketching import TensorSRHT2, PolyTensorSRHT
+from experimental.sketching import TensorSRHT2
 
 """ Implementation for NTK Sketching and Random Features """
@@ -50,11 +44,11 @@ def kappa1(x):
 
 @dataclasses.dataclass
 class Features:
-  nngp_feat: np.ndarray
-  ntk_feat: np.ndarray
+  nngp_feat: Optional[np.ndarray] = None
+  ntk_feat: Optional[np.ndarray] = None
 
-  batch_axis: int = dataclasses.field(pytree_node=False)
-  channel_axis: int = dataclasses.field(pytree_node=False)
+  batch_axis: int = 0
+  channel_axis: int = -1
 
   replace = ...  # type: Callable[..., 'Features']
@@ -72,7 +66,7 @@ def _inputs_to_features(x: np.ndarray,
   return Features(nngp_feat=nngp_feat,
                   ntk_feat=ntk_feat,
                   batch_axis=batch_axis,
-                  channel_axis=channel_axis)
+                  channel_axis=channel_axis)  # pytype:disable=wrong-keyword-args
 
 
 # Modified the serial process of feature map blocks.
@@ -95,7 +89,7 @@ def feature_fn(k, inputs, **kwargs):
 
 def DenseFeatures(out_dim: int,
                   W_std: float = 1.,
-                  b_std: float = None,
+                  b_std: float = 1.,
                   parameterization: str = 'ntk',
                   batch_axis: int = 0,
                   channel_axis: int = -1):
@@ -114,7 +108,7 @@ def kernel_fn(f: Features, input, **kwargs):
    nngp_feat *= W_std
    ntk_feat *= W_std
 
-    if ntk_feat.ndim == 0: # check if ntk_feat is empty
+    if ntk_feat.ndim == 0:  # check if ntk_feat is empty
      ntk_feat = nngp_feat
    else:
      ntk_feat = np.concatenate((ntk_feat, nngp_feat), axis=channel_axis)
@@ -153,20 +147,21 @@ def init_fn(rng, input_shape):
    ts2 = TensorSRHT2(rng=rng3,
                      input_dim1=ntk_feat_shape[-1],
                      input_dim2=feature_dim0,
-                      sketch_dim=sketch_dim).init_sketches()
+                      sketch_dim=sketch_dim).init_sketches()  # pytype:disable=wrong-keyword-args
    return (new_nngp_feat_shape, new_ntk_feat_shape), (W0, W1, ts2)
 
  elif method == 'ps':
-    rng1, rng2, rng3 = random.split(rng, 3)
-    # PolySketch algorithm for arc-cosine kernel of order 0.
-    ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
-                         poly_degree0)
-    # PolySketch algorithm for arc-cosine kernel of order 1.
-    ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
-                         poly_degree1)
-    # TensorSRHT of degree 2 for approximating tensor product.
-    ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
-    return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
+    # rng1, rng2, rng3 = random.split(rng, 3)
+    # # PolySketch algorithm for arc-cosine kernel of order 0.
+    # ps0 = PolyTensorSRHT(rng1, nngp_feat_shape[-1], poly_sketch_dim0,
+    #                      poly_degree0)
+    # # PolySketch algorithm for arc-cosine kernel of order 1.
+    # ps1 = PolyTensorSRHT(rng2, nngp_feat_shape[-1], poly_sketch_dim1,
+    #                      poly_degree1)
+    # # TensorSRHT of degree 2 for approximating tensor product.
+    # ts2 = TensorSRHT2(rng3, ntk_feat_shape[-1], feature_dim0, sketch_dim)
+    # return (new_nngp_feat_shape, new_ntk_feat_shape), (ps0, ps1, ts2)
+    raise NotImplementedError
 
  elif method == 'exact':
    # The exact feature map computation is for debug.
@@ -199,9 +194,9 @@ def feature_fn(f: Features, input=None, **kwargs) -> Features:
                             kappa0_feat).reshape(input_shape + (-1,))
 
    elif method == 'ps':
-      ps0: PolyTensorSRHT = input[0]
-      ps1: PolyTensorSRHT = input[1]
-      ts2: TensorSRHT2 = input[2]
+      # ps0: PolyTensorSRHT = input[0]
+      # ps1: PolyTensorSRHT = input[1]
+      # ts2: TensorSRHT2 = input[2]
      raise NotImplementedError
 
    elif method == 'exact':
      # Exact feature extraction via Cholesky decomposition.
@@ -258,7 +253,7 @@ def feature_fn(f, input, **kwargs):
 
    nngp_feat = conv2d_feat(nngp_feat, filter_size) / filter_size * W_std
 
-    if ntk_feat.ndim == 0: # check if ntk_feat is empty
+    if ntk_feat.ndim == 0:  # check if ntk_feat is empty
      ntk_feat = nngp_feat
    else:
      ntk_feat = conv2d_feat(ntk_feat, filter_size) / filter_size * W_std
diff --git a/experimental/sketching.py b/experimental/sketching.py
index 29f60cf4..48b54abf 100644
--- a/experimental/sketching.py
+++ b/experimental/sketching.py
@@ -1,7 +1,7 @@
 from jax import random
 from jax import numpy as np
-from neural_tangents._src.utils import utils, dataclasses
-from neural_tangents._src.utils.typing import Optional
+from neural_tangents._src.utils import dataclasses
+from typing import Optional, Callable
 
 
 # TensorSRHT of degree 2. This version allows different input vectors.
@@ -20,9 +20,9 @@ class TensorSRHT2:
  rand_inds1: Optional[np.ndarray] = None
  rand_inds2: Optional[np.ndarray] = None
 
-  replace = ...
+  replace = ...  # type: Callable[..., 'TensorSRHT2']
 
-  def init_sketches(self):
+  def init_sketches(self) -> 'TensorSRHT2':
    rng1, rng2, rng3, rng4 = random.split(self.rng, 4)
    rand_signs1 = random.choice(rng1, 2, shape=(self.input_dim1,)) * 2 - 1
    rand_signs2 = random.choice(rng2, 2, shape=(self.input_dim2,)) * 2 - 1
@@ -53,7 +53,8 @@ def tensorsrht(x1, x2, rand_inds, rand_signs):
  return np.sqrt(1 / rand_inds.shape[1]) * (x1fft * x2fft)
 
 
-# TensorSRHT of degree p. This operates the same input vectors.
+# pytype: disable=attribute-error
+# TODO: Implement a faster TensorSRHT.
 class PolyTensorSRHT:
 
  def __init__(self, rng, input_dim, sketch_dim, coeffs):
@@ -133,3 +134,4 @@ def sketch(self, x):
      p = p // 2
    U[j] = V[log_degree - 1][0, :, :].clone()
    return U
+# pytype: enable=attribute-error
\ No newline at end of file
diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py
index 23e5d575..c340deac 100644
--- a/experimental/test_fc_ntk.py
+++ b/experimental/test_fc_ntk.py
@@ -2,11 +2,13 @@
 from jax import random
 from jax.config import config
 from jax import jit
+import sys
+sys.path.append("./")
 
 config.update("jax_enable_x64", True)
 
 from neural_tangents import stax
-from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
+from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
 
 seed = 1
 n, d = 6, 4
diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py
index 29d34c73..c89c8046 100644
--- a/experimental/test_myrtle_network.py
+++ b/experimental/test_myrtle_network.py
@@ -1,6 +1,7 @@
 import os
-
 os.environ['CUDA_VISIBLE_DEVICES'] = ''
+import sys
+sys.path.append("./")
 import functools
 from numpy.linalg import norm
 from jax.config import config
@@ -12,7 +13,7 @@ from jax import random
 
 from neural_tangents import stax
 
-from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
+from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
 
 layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
 width = 1
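
For reference, the kappa0/kappa1 functions that the features.py hunks refer to ("arc-cosine kernel of order 0/1") are the standard arc-cosine kernels of Cho & Saul, the dual activations of the ReLU derivative and the ReLU. A sketch under the standard normalization; the repo's exact form and normalization may differ:

from jax import numpy as np

def kappa0(x):
  # Order-0 arc-cosine kernel: dual activation of the step function (ReLU').
  return 1. - np.arccos(x) / np.pi

def kappa1(x):
  # Order-1 arc-cosine kernel: dual activation of the ReLU itself.
  return (np.sqrt(1. - x**2) + (np.pi - np.arccos(x)) * x) / np.pi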
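
A note on the DenseFeatures hunk: the np.concatenate((ntk_feat, nngp_feat), axis=channel_axis) update is the feature-space counterpart of the NTK recursion Θ^(l+1) = Θ^(l) + Σ^(l+1), because concatenating two feature maps along the channel axis sums the Gram matrices they induce. A self-contained check with toy shapes (not taken from the repo):

from jax import random
from jax import numpy as np

rng1, rng2 = random.split(random.PRNGKey(0))
ntk_feat = random.normal(rng1, (4, 8))   # stand-in NTK features
nngp_feat = random.normal(rng2, (4, 8))  # stand-in NNGP features

# [a, b] [a, b]^T = a a^T + b b^T: concatenation realizes the kernel sum.
feat = np.concatenate((ntk_feat, nngp_feat), axis=-1)
assert np.allclose(feat @ feat.T,
                   ntk_feat @ ntk_feat.T + nngp_feat @ nngp_feat.T)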
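
Finally, a minimal end-to-end sketch of the refactored API, mirroring the imports and toy sizes in experimental/test_fc_ntk.py. The ReluFeatures keyword (method='exact') and serial's (init_fn, feature_fn) return signature are assumptions read off the diff, not a confirmed public API:

import sys
sys.path.append("./")

from jax import random
from experimental.features import (_inputs_to_features, DenseFeatures,
                                   ReluFeatures, serial)

seed = 1
n, d = 6, 4  # same toy problem size as test_fc_ntk.py
key = random.PRNGKey(seed)
x = random.normal(key, (n, d))

# Compose feature-map blocks; 'exact' (assumed keyword) sidesteps the
# sketching hyperparameters and the 'ps' path, which now raises
# NotImplementedError.
init_fn, feature_fn = serial(DenseFeatures(512, W_std=1., b_std=1.),
                             ReluFeatures(method='exact'),
                             DenseFeatures(1, W_std=1., b_std=1.))

feat_shapes, params = init_fn(key, x.shape)
feats = feature_fn(_inputs_to_features(x), params)
# feats.nngp_feat @ feats.nngp_feat.T approximates the NNGP kernel,
# feats.ntk_feat @ feats.ntk_feat.T the NTK.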