From d8f9d1bac97ba677d66701d4203f5b202b193e72 Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Fri, 2 Jul 2021 17:13:49 +0530 Subject: [PATCH 1/4] Conv1d, BacthNorm1d and AvgPool1d Layers --- .gitattributes | 11 + c_reference/include/conv1d.h | 243 +++++++ c_reference/include/dscnn.h | 104 +++ c_reference/include/rnn_bricked.h | 105 +++ c_reference/include/utils.h | 88 +++ c_reference/src/Makefile | 11 +- c_reference/src/conv1d.c | 610 ++++++++++++++++++ c_reference/src/dscnn.c | 112 ++++ c_reference/src/rnn_bricked.c | 303 +++++++++ c_reference/src/utils.c | 153 +++++ c_reference/tests/Makefile | 16 +- .../conv1d_depthwise/conv_param_depth.h | 3 + .../tests/conv1d/conv1d_lr/conv_param_lr.h | 3 + .../tests/conv1d/conv1d_regular/conv_param.h | 3 + c_reference/tests/conv1d/test_conv1d.c | 130 ++++ c_reference/tests/kws/keyword_spotting_io_1.h | 3 + c_reference/tests/kws/keyword_spotting_io_2.h | 3 + c_reference/tests/kws/keyword_spotting_io_3.h | 3 + c_reference/tests/kws/postcnn_params.h | 3 + c_reference/tests/kws/precnn_params.h | 3 + c_reference/tests/kws/rnn_params.h | 3 + .../tests/kws/test_phoneme_det_cnn_rnn.c | 275 ++++++++ .../tests/rnn_bricked/rnn_bricked_io.h | 3 + c_reference/tests/rnn_bricked/rnn_params.h | 3 + .../tests/rnn_bricked/test_rnn_bricked.c | 77 +++ 25 files changed, 2268 insertions(+), 3 deletions(-) create mode 100644 c_reference/include/conv1d.h create mode 100644 c_reference/include/dscnn.h create mode 100644 c_reference/include/rnn_bricked.h create mode 100644 c_reference/src/conv1d.c create mode 100644 c_reference/src/dscnn.c create mode 100644 c_reference/src/rnn_bricked.c create mode 100644 c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h create mode 100644 c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h create mode 100644 c_reference/tests/conv1d/conv1d_regular/conv_param.h create mode 100644 c_reference/tests/conv1d/test_conv1d.c create mode 100644 c_reference/tests/kws/keyword_spotting_io_1.h create mode 100644 c_reference/tests/kws/keyword_spotting_io_2.h create mode 100644 c_reference/tests/kws/keyword_spotting_io_3.h create mode 100644 c_reference/tests/kws/postcnn_params.h create mode 100644 c_reference/tests/kws/precnn_params.h create mode 100644 c_reference/tests/kws/rnn_params.h create mode 100644 c_reference/tests/kws/test_phoneme_det_cnn_rnn.c create mode 100644 c_reference/tests/rnn_bricked/rnn_bricked_io.h create mode 100644 c_reference/tests/rnn_bricked/rnn_params.h create mode 100644 c_reference/tests/rnn_bricked/test_rnn_bricked.c diff --git a/.gitattributes b/.gitattributes index e5b821f51..dda5bfc74 100644 --- a/.gitattributes +++ b/.gitattributes @@ -60,3 +60,14 @@ c_reference/models/q_scut_head_b_face4_model/mbconv2.h filter=lfs diff=lfs merge c_reference/models/q_scut_head_b_face4_model/mbconv4.h filter=lfs diff=lfs merge=lfs -text c_reference/models/q_scut_head_b_face4_model/rnn2.h filter=lfs diff=lfs merge=lfs -text c_reference/models/q_scut_head_b_face4_model/detection2.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_1.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_2.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/keyword_spotting_io_3.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_regular/conv_param.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/precnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/postcnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/kws/rnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/rnn_bricked/rnn_params.h filter=lfs diff=lfs merge=lfs -text +c_reference/tests/rnn_bricked/rnn_bricked_io.h filter=lfs diff=lfs merge=lfs -text diff --git a/c_reference/include/conv1d.h b/c_reference/include/conv1d.h new file mode 100644 index 000000000..a7ed49315 --- /dev/null +++ b/c_reference/include/conv1d.h @@ -0,0 +1,243 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __CONV1D_H__ +#define __CONV1D_H__ + +/* All the matrices/tensors are stored in the row major format + + NOTES for the conv layers +-> The conv1d & conv1d_lr layers work for all cases and can be used unconstrained. + There are no hard constraints for the parallel version, but a few points regarding its optimal usage are given below +-> Dilation = 1 (no dilation) for all cases +-> For the non-depthwise cases, store the matrices as described below. Permutation might be necessary +-> The low-rank decomposition cannot be applied to the depthwise weight matrices. This is due to the out_channels/in_channels = 0 constarint imposed by the depthwise convolution. + For full-rank this is satisfied since out_channels = in_channels + But, when the matrix is decomposed, the constarint is violated (since rank < out_channels ; rank is not divisible by in_channels) + Hence due to the decomposition being theoretically impossible, we have not provided the support + However we suggest a less-efficient alternative => First pre-compute the weights W = W2 * W1 and then use a regular conv +-> For the parallel cases, the non-overlapping cases of the convolution are computed parallelly using MatMul (since the blocked MatMul is faster) + This howver is only valid for when the filter is fully in the input. There would be no-overlapping for the edge cases + Hence the MatVec code(regular code) is used to calculate these cases + + Important points regarding parallel versions +-> Due to the above reason, the parallel layers is only recommended for large in_time inputs + This should typically be for in_time (without the padding) > 2 * num_steps_one_row + stride. Else there would not be enough time-steps to efficiently parallelise + We need at least 2 rows for a good a MatMul performace. In the worst case the starting time step would be (stride - 1). Hence we choose 2 * num_steps_one_row + stride as the threshold + For the short input cases, the code will skip the MatMul computation and use MatVec instead (but the MatMul-variable computation overhead would remain) + For such cases, the MatVec code (conv1d and conv1d_lr) would work more efficiently due to the lower RAM usage and lack of any major overheads +-> There is no support for depthwise for conv1d_parallel + The regular convolution acts on all the channels while the depthwise acts only on one channel at a time + This results in a non-contiguos memory access. MatMul would need to process multiple such time-steps, while the MatVec would only need to process one + Hence, the MatVec would be able to enter the next channel earlier and would work much faster + While the MatMul would have cache misses (when dealing with the small chache size of edge devices) +*/ + +/** + * @brief Model parameters for the 1D Convolution Layer + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] + * @var B pointer to the bias vector, original shape = [out_channels] + * @var depthwise flag for deciding between regular(=0) and depthwise(=1) conv + */ +typedef struct ConvLayers_Params { + const float* const W; + const float* const B; + unsigned depthwise; +} ConvLayers_Params; + +/** + * @brief Model definition for the 1D Convolution Layer. Currently only for dilation = 1 + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels + * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer + * NOTE: out_channels = in_channels for depthwise. This is set manually in the function + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) + * @param[in] kernel_size kernel size of the conv filter + * @param[in] params weights, bias and other essential parameters used to describe the layer + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Parallel Convolution Layer + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] + * @var B pointer to the bias vector, original shape = [out_channels] + * @var block_size block/tile size for the cache. Used for tiled MatMul + */ +typedef struct ConvLayers_Parallel_Params { + const float* const W; + const float* const B; + unsigned block_size; +} ConvLayers_Parallel_Params; + +/** + * @brief Model definition for the 1D Parallel Convolution Layer. Currently only for dilation = 1. No depthwise. + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels + * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) + * @param[in] kernel_size kernel size of the conv filter + * @param[in] params weights, bias and other essential parameters used to describe the layer + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Low Rank Convolution Layer. + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1] + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels] + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage + */ +typedef struct ConvLayers_LR_Params { + const float* const W1; + const float* const W2; + const float* const B; + unsigned rank; +} ConvLayers_LR_Params; + +/** + * @brief Model definition for the 1D Low-Rank Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels + * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) + * @param[in] kernel_size kernel size of the conv filter + * @param[in] params weights, bias and other essential parameters used to describe the layer + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +/** + * @brief Model parameters for the 1D Low Rank Parallel Convolution Layer. + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1] + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels] + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage + * @var block_size_to_lr block/tile size for the cache. Used for tiled MatMul. Used for the input -> low-rank computation + * @var block_size_from_lr block/tile size for the cache. Used for tiled MatMul. Used for the low-rank -> output computation + */ +typedef struct ConvLayers_LR_Parallel_Params { + const float* const W1; + const float* const W2; + const float* const B; + unsigned rank; + unsigned block_size_to_lr; + unsigned block_size_from_lr; +} ConvLayers_LR_Parallel_Params; + +/** + * @brief Model definition for the 1D Low-Rank Parallel Convolution Layer. Currently only for dilation = 1. + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels + * @param[in] out_time number of time steps in the output + * @param[in] out_channels number of output channels for the output of the conv layer + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) + * @param[in] kernel_size kernel size of the conv filter + * @param[in] params weights, bias and other essential parameters used to describe the layer + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation); + +// Auxiliary Layers +/** + * @brief Model definition for the 1D Average Pooling Layer. Currently only for dilation = 1 + * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation + * NOTE: out_channels == in_channels for avgpool + * @param[in] out_time number of time steps in the output + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels. The output will have the same number of channels + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) + * @param[in] kernel_size kernel size of the pool filter + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal, + unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation); + +/** + * @brief Model definition for the 1D batch Normalization Layer + * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels. The output will have the same number of channels + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0 + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0 + * @param[in] affine_config whether the affine operations are applied + * if affine_config = 0, then only mean and var are used + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0 + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0 + * @param[in] in_place in-place computation of the batchnorm i.e. the output is stored in-place of the input signal. Storage efficient + * @param[in] eps a very small +ve value to avoid division by 0. For the default value, assign = 0.00001 + */ +int batchnorm1d(float* output_signal, float* input_signal, + unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma , const float* const beta, + unsigned in_place, float eps); + +#endif diff --git a/c_reference/include/dscnn.h b/c_reference/include/dscnn.h new file mode 100644 index 000000000..541923056 --- /dev/null +++ b/c_reference/include/dscnn.h @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __DSCNN_H__ +#define __DSCNN_H__ + +// Function pointer for the Conv layer to be passed as a parameter. (conv1d or conv1d_lr only) +typedef int (*conv_layer)(float*, unsigned, unsigned, const float*, + unsigned, unsigned, unsigned, unsigned, + const void*, unsigned, unsigned); + +/** + * @brief Model definition for the 1D Convolution block applied before the RNN + * @brief sub-layers : batchnorm1d -> conv1d_lr + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params) + * @param[in] in_time number of time steps in the input_signal + * @param[in] in_channels number of input channels + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 + * @param[in] affine_config whether the affine operations are applied + * if affine_config = 0, then only mean and var are used + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 + * @param[in] in_place in-place computation check for the batchnorm. Storage efficient + * @param[in] cnn_hidden hidden state/out_channels dimensions for the low-rank CNN. The final channel size of this block + * @param[in] cnn_padding padding for the low-rank CNN layer. Note: applied to both sides of the input + * @param[in] cnn_kernel_size kernel size of the low-rank CNN + * @param[in] cnn_params weights, bias and other essential parameters for the low-rank CNN + * @param[in] cnn_stride stride factor for the low-rank CNN + * @param[in] cnn_activation an integer to choose the type of activation function. + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int phon_pred_lr_cnn(float* output_signal, float* input_signal, + conv_layer cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size, + const void* cnn_params, unsigned cnn_stride, unsigned cnn_activation); + +/** + * @brief Model definition for the 1D Convolution block applied after the RNN + * @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels + * @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params) + * @param[in] in_time number of time steps in the input + * @param[in] in_channels number of input channels + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 + * @param[in] affine_config whether the affine operations are applied + * if affine_config = 0, then only mean and var are used + * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 + * @param[in] in_place in-place computation of the batchnorm. Storage efficient + * @param[in] depth_cnn_padding padding for the depth CNN layer. Note: applied to both sides of the input to the depth CNN + * @param[in] depth_cnn_kernel_size kernel size of the depth CNN + * @param[in] depth_cnn_params weights, bias and other essential parameters used to describe the depth CNN + * @param[in] depth_cnn_stride stride factor for the depth CNN + * @param[in] depth_cnn_activation an integer to choose the type of activation function. + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + * @param[in] point_cnn_hidden hidden state/out_channels dimensions for the point CNN. The final channel size of this block + * @param[in] point_cnn_padding padding for the point CNN layer. Note: applied to both sides of the input to the point CNN + * @param[in] point_cnn_kernel_size kernel size of the point CNN + * @param[in] point_cnn_params weights, bias and other essential parameters used to describe the point CNN + * @param[in] point_cnn_stride stride factor for the point CNN + * @param[in] point_cnn_activation an integer to choose the type of activation function. + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + * @param[in] pool_padding padding for the pool layer. Note: applied to both sides of the input to the pool + * @param[in] pool_kernel_size kernel size of the pool + * @param[in] pool_stride stride factor for the pool + * @param[in] pool_activation an integer to choose the type of activation function. + * 0: none + * 1: sigmoid + * 2: tanh + * 3: relu + */ +int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, + conv_layer point_cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned depth_cnn_padding, unsigned depth_cnn_kernel_size, + const void* depth_cnn_params, unsigned depth_cnn_stride, unsigned depth_cnn_activation, + unsigned point_cnn_hidden, unsigned point_cnn_padding, unsigned point_cnn_kernel_size, + const void* point_cnn_params, unsigned point_cnn_stride, unsigned point_cnn_activation, + unsigned pool_padding, unsigned pool_kernel_size, unsigned pool_stride, unsigned pool_activation); + +#endif diff --git a/c_reference/include/rnn_bricked.h b/c_reference/include/rnn_bricked.h new file mode 100644 index 000000000..adc910d42 --- /dev/null +++ b/c_reference/include/rnn_bricked.h @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#ifndef __RNN_BRICKED_H__ +#define __RNN_BRICKED_H__ + +/* All the matrices are stored in the row major format + + NOTES for using the layers +-> Single-directional Computation + While using the bricked fastgrnn layers, the user needs to adhered to the two following constraints + 1) in_time % hop = 0 + 2) fwd_window % hop = 0 and bwd_window % hop = 0 + + Violation of the above two constraints (1 & 2), will cause segmentation faults + The layers first compute all the Wx steps and then compute Uh for all the windows parallelly + Hence, the user needs to adhered to the constraints 1 & 2 + +-> Bi-directional Computation + For bi-directional cases, there are 2 additionally constraints that would need to be followed + A) sample_first_brick and sample_last_brick = 1 + B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call + Each function will only process its given context(forward/backward). The other context will need to be called separately. + E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...) + 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...) + + The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used + Violating the conditions would cause index mis-matches or data corruption + If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result + If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only) +*/ + +/** + * @brief Model parameters for the 1D Convolution Layer + * @var W1 pointer to first low-rank component of W. shape = [rank * in_dims] + * @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank] + * @var wRank rank of W matrix + * @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden] + * @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank] + * @var uRank rank of U matrix + * @var Bg pointer to bias for sigmoid + * @var Bh pointer to bias for tanh + * @var sigmoid_zeta first weight parameter for update from input from next step + * @var sigmoid_nu second weight parameter for update from input from next step + * @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x + * @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x) + * @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h + * @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h) + */ +typedef struct BrickedFastGRNN_LR_Params { + float* W1; + float* W2; + unsigned wRank; + float* U1; + float* U2; + unsigned uRank; + float* Bg; + float* Bh; + float sigmoid_zeta; + float sigmoid_nu; + unsigned block_size_w_to_lr; + unsigned block_size_w_from_lr; + unsigned block_size_u_to_lr; + unsigned block_size_u_from_lr; +} BrickedFastGRNN_LR_Params; + +/** Forward Bricking and application of the forward RNN for an input signal + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden + * @param[in] rnn_hidden output dimension for the current cell + * @param[in] input_signal pointer to input signal. size = in_time * in_dims + * @param[in] in_time number of input time steps. + * @param[in] in_dims input dimensions + * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick) + * @param[in] hop hop distance for between bricks + * @param[in] params pointer to the parameters for the RNN + * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. + * @param[in] sample_first_brick determine if the 1st brick should also be sampled + * -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 + * -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hiddens state. out_time = in_time/hop + 1 + */ +int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_first_brick); + +/** Backward Bricking and application of the backward RNN for an input signal + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden + * @param[in] rnn_hidden output dimension for the current cell + * @param[in] input_signal pointer to input signal. size = in_time * in_dims + * @param[in] in_time number of input time steps. + * @param[in] in_dims input dimensions + * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick) + * @param[in] hop hop distance for between bricks + * @param[in] params pointer to the parameters for the RNN + * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. + * @param[in] sample_last_brick determine if the last brick should also be sampled + * -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 + * -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hiddens state(last in reverse). out_time = in_time/hop + 1 + */ +int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_last_brick); + +#endif diff --git a/c_reference/include/utils.h b/c_reference/include/utils.h index 26d5242ba..07438d134 100644 --- a/c_reference/include/utils.h +++ b/c_reference/include/utils.h @@ -31,6 +31,84 @@ void matVec(const float* const mat, const float* const vec, float alpha, float beta, float* const ret); +/* + Matrix-vector multiplication with a row offset + This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis + ret is of size nrows, vec is of size ncols + mat is of size nrows * ncols, stored in row major + depthwise is to change the matVec to depthwise specific convolutions + row_stride is the offset factor between two adjacent rows + Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped + For a normal matVec case, this value will be ncols + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. + For this eg ncols will be 100 and row_stride will be 400 + vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals + For a normal matVec case, this value will be 1 + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + So it's possible to enter a 400 length vector and consider every 4th element. + So it's possible to enter a 400 length vector and consider every 4th element. + So it's possible to enter a 400 length vector and consider every 4th element. + For this ncols will be 100 and vec_stride will be 4 +*/ +void offset_matVec_conv1d(const float* mat, const float* vec, + unsigned nrows, unsigned ncols, + unsigned row_stride, unsigned vec_stride, + unsigned depthwise, float* ret); + +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication + Note: If only the MatMul output is needed, then please use calloc to initialize the output + An alternative is to use malloc, followed by memset 0 + There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly + This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed + matA first matrix; shape = [nrows, ncommon] + matB second matrix; shape = [ncommon, ncols] + nrows number of rows in the first matrix + ncommon number of columns in the first matrix/number of rows in the second matrix + ncols number of columns in the second matrix + total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored + total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + ret matrix multiplication output. shape = [nrows, ncols] + block_size tile/block size for optimal cache performance. A hardware specific parameter +*/ +void tiledMatMul_float(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size); + +/* + Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format + The result will the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage) + Note: If only the MatMul output is needed, then please use calloc to initialize the output + An alternative is to use malloc, followed by memset 0 + There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly + This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed + matA first matrix; shape = [nrows, ncommon] + matB second matrix; shape = [ncols, ncommon] + nrows number of rows in the first matrix + ncommon number of columns in the first matrix/number of rows in the second matrix + ncols number of columns in the second matrix + total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored + total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + Since matB is transposed the columns are now the ncomm axis + ret matrix multiplication output. shape = [nrows, ncols] + block_size tile/block size for optimal cache performance. A hardware specific parameter +*/ +void transposed_tiledMatMul(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size); + // scaled vector addition: ret = scalar1 * vec1 + scalar2 * vector2 void v_add(float scalar1, const float* const vec1, float scalar2, const float* const vec2, @@ -54,4 +132,14 @@ unsigned argmax(const float* const vec, unsigned len); // ret[i] = exp(input[i]) / \sum_i exp(input[i]) void softmax(const float* const input, unsigned len, float* const ret); +/* Custom non-linear layer for the phoneme detection model. It can be used for other time-series problems if necessary + output_signal pointer to the output signal, size = out_time * (in_channels / 2) + input_signal pointer to the input signal. size = in_time * in_channels + in_time number of input time steps + in_channels number of input channels. The output will have the half the number of input channels. + Necessary for in_channels % 2 == 0 + */ +void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, + unsigned in_time, unsigned in_channels); + #endif diff --git a/c_reference/src/Makefile b/c_reference/src/Makefile index 8fc27bd65..7f7e79941 100644 --- a/c_reference/src/Makefile +++ b/c_reference/src/Makefile @@ -6,7 +6,13 @@ include ../config.mk INCLUDE_DIR=../include IFLAGS = -I $(INCLUDE_DIR) -all: utils.o fastgrnn.o classifier.o rnnpool.o quantized_utils.o quantized_fastgrnn.o quantized_rnnpool.o quantized_mbconv.o +all: dscnn.o conv1d.o utils.o fastgrnn.o classifier.o rnnpool.o quantized_utils.o quantized_fastgrnn.o quantized_rnnpool.o quantized_mbconv.o rnn_bricked.o + +dscnn.o : dscnn.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ + +conv1d.o : conv1d.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ utils.o: utils.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ @@ -20,6 +26,9 @@ classifier.o: classifier.c rnnpool.o: rnnpool.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ +rnn_bricked.o: rnn_bricked.c + $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ + quantized_utils.o: quantized_utils.c $(CC) -o $@ $(IFLAGS) $(CFLAGS) -c $^ diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c new file mode 100644 index 000000000..2ab5b7f30 --- /dev/null +++ b/c_reference/src/conv1d.c @@ -0,0 +1,610 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include "conv1d.h" +#include "utils.h" + +int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params; + + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + unsigned rank = tparams->rank; + // Buffer for W2 out + float* temp_rank_out = (float*)malloc(rank * sizeof(float)); + // Buffer for W1 out + float* temp_out = (float*)malloc(out_channels * sizeof(float)); + for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + unsigned t_index = t_out * out_channels; + + if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { + // Filter fully inside the input. Kept as the initial condition, since this is the most common one + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if ((t_in_start < padding) && (t_in_end >= padding)) { + // Filter partially entered the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, + input_signal, rank, + (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely in the padding region + // The filter is either fully outside the input or has not yet entered the input + memset(output_signal + t_index, 0, out_channels * sizeof(float)); + } + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + free(temp_rank_out); + return 0; +} + +int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; + + const ConvLayers_LR_Parallel_Params* tparams = (ConvLayers_LR_Parallel_Params*) params; + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Buffer to hold the output. For corner cases, this will be realtively big. + // But will be needed for the central condition (filter inside input). + // If there are not enough time steps to linearise into one row, then allocate only 1 time step + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + unsigned rank = tparams->rank; + // Buffer for W2 out + float* temp_rank_out = (float*)malloc(buffer_steps * rank * sizeof(float)); + // Buffer for W1 out + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here + for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_in_start < padding && t_out < out_time; + t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_end < padding) { + // Filter outside the input region and in the padded region + memset(output_signal + t_out * out_channels, 0, + out_channels * sizeof(float)); + } + else { //(t_in_end >= padding) + // Filter partially entered the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, + input_signal, rank, (t_in_end - padding + 1) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + } + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases + // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row + // Using the above logic, we can convert the MatVec opeartion into a MatMul operation + // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied) + // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria + // The MatVec will be used for the computation in-case the following block is skipped + if (in_time > ((num_steps_one_row << 1) + stride)) { + t_in_start -= padding; // remove the padding offset temporarily + t_in_end -= padding; // Used to keep track of the final processed index + for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels); + iter++, t_in_start += stride, t_out++) { + unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; + memset(temp_rank_out, 0, buffer_steps * rank * sizeof(float)); + memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); + if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { + // t_in_end is used to find the furthest time step was used in the MatMul calculation + // This value will be used for calculating the index for the final section of the processing + t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; + } + transposed_tiledMatMul(input_signal + t_in_start * in_channels , tparams->W2, + in_rows, ncols, rank, total_in_cols, ncols, + temp_rank_out, tparams->block_size_to_lr); + transposed_tiledMatMul(temp_rank_out , tparams->W1, + in_rows, rank, out_channels, rank, rank, + temp_out, tparams->block_size_from_lr); + // Copy all the data into the output + float* output_offset = (float*)output_signal + t_out * out_channels; + float* temp_offset = (float*)temp_out; + unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; + while (t_iter--) { + memcpy(output_offset, temp_offset, out_channels * sizeof(float)); + output_offset += offset_factor_for_out; + temp_offset += out_channels; + } + } + // Initialize the time iterators + // Use the stored value in t_in_end to calculate the iterators + t_in_start = t_in_end + padding; // Add the padding and stride offsets again + t_in_end = t_in_start + kernel_size - 1; + t_out = t_in_start / stride; + } + for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { + // Filter fully in the input but very close to the edges. + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped + // Incase the MatMul is skipped, this block will be used to compute the results + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W2, + input_signal + (t_in_start - padding) * in_channels, + rank, (in_time + padding - t_in_start) * in_channels, + kernel_size * in_channels, 1, 0, temp_rank_out); + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, + rank, rank, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely outside the input and in the padding region + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); + } + } + // Bias and activation + for (t_out = 0; t_out < out_time; t_out++) { + unsigned t_index = t_out * out_channels; + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + return 0; +} + +int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + const ConvLayers_Params* tparams= (ConvLayers_Params*) params; + unsigned vec_stride = 1, cols_scale = in_channels; + if (tparams->depthwise) { + vec_stride = in_channels; + out_channels = in_channels; + cols_scale = 1; + } + + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding + float* temp_out = (float*)malloc(out_channels * sizeof(float)); + for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + unsigned t_index = t_out * out_channels; + + if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { + // Filter fully inside the input. Kept as the initial condition, since this is the most common one + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, kernel_size * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if ((t_in_start < padding) && (t_in_end >= padding)) { + // Filter partially entered the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W + (padding - t_in_start) * cols_scale, + input_signal, out_channels, (t_in_end - padding + 1) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, (in_time + padding - t_in_start) * cols_scale, + kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); + memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely in the padding region + // The filter is either fully outside the input or has not yet entered the input + memset(output_signal + t_index, 0, out_channels * sizeof(float)); + } + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + return 0; +} + +int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, + const float* input_signal, unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, + const void* params, unsigned stride, unsigned activation) { + + unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; + // Calculate the number of time steps in one row for the first non-overlapping instance + while (num_steps_one_row < kernel_size) { + num_steps_one_row += stride; + num_iter++; + } + unsigned total_in_cols = num_steps_one_row * in_channels; + + const ConvLayers_Parallel_Params* tparams = (ConvLayers_Parallel_Params*) params; + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Buffer to hold the output. For corner cases, this will be realtively big. + // But will be needed for the central condition (filter inside input). + // If there are not enough time steps to linearise into one row, then allocate only 1 time step + unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? + in_time / num_steps_one_row : 1; + float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here + for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; + t_in_start < padding && t_out < out_time; + t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_end < padding) { + // Filter outside the input region and in the padded region + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); + } + else { //(t_in_end >= padding) + // Filter partially entered the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels, + input_signal, out_channels, (t_in_end - padding + 1) * in_channels, + ncols, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + } + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases + // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row + // Using the above logic, we can convert the MatVec opeartion into a MatMul operation + // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied) + // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria + // The MatVec will be used for the computation in-case the following block is skipped + if (in_time > ((num_steps_one_row << 1) + stride)) { + t_in_start -= padding; // remove the padding offset temporarily + t_in_end -= padding; // Used to keep track of the final processed index + for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels); + iter++, t_in_start += stride, t_out++) { + unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; + memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); + if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { + // t_in_end is used to find the furthest time step was used in the MatMul calculation + // This value will be used for calculating the index for the final section of the processing + t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; + } + transposed_tiledMatMul(input_signal + t_in_start * in_channels , tparams->W, + in_rows, ncols, out_channels, total_in_cols, ncols, + temp_out, tparams->block_size); + // Copy all the data into the output + float* output_offset = (float*)output_signal + t_out * out_channels; + float* temp_offset = (float*)temp_out; + unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; + while (t_iter--) { + memcpy(output_offset, temp_offset, out_channels * sizeof(float)); + output_offset += offset_factor_for_out; + temp_offset += out_channels; + } + } + // Initialize the time iterators + // Use the stored value in t_in_end to calculate the iterators + t_in_start = t_in_end + padding; // Add the padding and stride offsets again + t_in_end = t_in_start + kernel_size - 1; + t_out = t_in_start / stride; + } + for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { + if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { + // Filter fully in the input but very close to the edges. + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped + // Incase the MatMul is skipped, this block will be used to compute the results + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, kernel_size * in_channels, + kernel_size * in_channels, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { + // Filter partially exited the input + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + offset_matVec_conv1d(tparams->W, + input_signal + (t_in_start - padding) * in_channels, + out_channels, (in_time + padding - t_in_start) * in_channels, + ncols, 1, 0, temp_out); + memcpy(output_signal + t_out * out_channels, + temp_out, out_channels * sizeof(float)); + } + else { + // Filter completely outside the input and in the padding region + memset(output_signal + t_out * out_channels, + 0, out_channels * sizeof(float)); + } + } + // Bias and activation + for (t_out = 0; t_out < out_time; t_out++) { + unsigned t_index = t_out * out_channels; + for (unsigned co = 0; co < out_channels; co++) { + // Post-Conv activation. More activation functions can be added should the necessity arise + switch (activation) { + case 1 : + output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 2 : + output_signal[t_index + co] = tanh(output_signal[t_index + co] + + tparams->B[co]); + break; + + case 3 : + output_signal[t_index + co] = relu(output_signal[t_index + co] + + tparams->B[co]); + break; + + default : + output_signal[t_index + co] += tparams->B[co]; + break; + } + } + } + free(temp_out); + return 0; +} + +int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal, + unsigned in_time, unsigned in_channels, + unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation) { + + // Iterate over the time steps and average them + float scale = 1.0/(float)kernel_size; // To avoid divisions + for (unsigned t_in = 0, t_out = 0; t_out < out_time; t_out++, t_in += stride) { + for (unsigned ci = 0; ci < in_channels; ci++) { + float sum = 0; + for (unsigned tf = 0; tf < kernel_size; tf++) { + if (((t_in + tf) < padding) || ((t_in + tf) >= (in_time + padding))) { + continue; + } + else { + sum += (input_signal[((tf + t_in) - padding) * in_channels + ci]); + } + } + switch (activation) { + case 1 : + output_signal[t_out * in_channels + ci] = sigmoid(sum * scale); + break; + + case 2 : + output_signal[t_out * in_channels + ci] = tanh(sum * scale); + break; + + case 3 : + output_signal[t_out * in_channels + ci] = relu(sum * scale); + break; + + default : + output_signal[t_out * in_channels + ci] = sum * scale; + break; + } + } + } + return 0; +} + +int batchnorm1d(float* output_signal, float* input_signal, + unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma , const float* const beta, + unsigned in_place, float eps) { + float* ret = in_place ? (float*)input_signal : (float*)output_signal; + + // Check for affine_config + // = 1 ; Use gamma, beta, mean and var + // = 2 ; Use only gamma and beta + // = 3 ; Use only mean and var + if (affine_config == 1) { + while (in_time--) { + float* gamma_offset = (float*)gamma; + float* beta_offset = (float*)beta; + float* mean_offset = (float*)mean; + float* var_offset = (float*)var; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + } + #endif + + while (channels--) { + *ret++ = (*gamma_offset++) * (((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps)) + (*beta_offset++); + } + } + } + else if (affine_config == 2) { + while (in_time--) { + float* gamma_offset = (float*)gamma; + float* beta_offset = (float*)beta; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + } + #endif + + while (channels--) { + *ret++ = ((*gamma_offset++) * (*input_signal++)) + (*beta_offset++); + } + } + } + else { + while (in_time--) { + float* mean_offset = (float*)mean; + float* var_offset = (float*)var; + unsigned channels = in_channels; + + #ifdef LOOP_UNROLL + unsigned len_unroll = channels >> 2; + channels %= 4; + while (len_unroll--) { + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + } + #endif + + while (channels--) { + *ret++ = ((*input_signal++) - (*mean_offset++)) / + sqrt((*var_offset++) + eps); + } + } + } + return 0; +} diff --git a/c_reference/src/dscnn.c b/c_reference/src/dscnn.c new file mode 100644 index 000000000..a304ff54f --- /dev/null +++ b/c_reference/src/dscnn.c @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include "dscnn.h" +#include "conv1d.h" +#include "utils.h" + +int phon_pred_lr_cnn(float* output_signal, float* input_signal, + conv_layer cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned cnn_hidden, unsigned cnn_padding, unsigned cnn_kernel_size, + const void* cnn_params, unsigned cnn_stride, unsigned cnn_activation) { + + unsigned out_time = in_time - cnn_kernel_size + 2 * cnn_padding + 1; + if (in_place) { + // BatchNorm + batchnorm1d(0, input_signal, + in_time, in_channels, + mean, var, affine_config, gamma, beta, + in_place, 0.00001); + // CNN + cnn(output_signal, out_time, cnn_hidden, input_signal, + in_time, in_channels, cnn_padding, cnn_kernel_size, + cnn_params, cnn_stride, cnn_activation); + } + else { + // BatchNorm + float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); + batchnorm1d(norm_out, input_signal, + in_time, in_channels, + mean, var, affine_config, gamma, beta, + in_place, 0.00001); + // CNN + cnn(output_signal, out_time, cnn_hidden, norm_out, + in_time, in_channels, cnn_padding, cnn_kernel_size, + cnn_params, cnn_stride, cnn_activation); + free(norm_out); + } + return 0; +} + +int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, + conv_layer point_cnn, unsigned in_time, unsigned in_channels, + const float* const mean, const float* const var, + unsigned affine_config, const float* const gamma, const float* const beta, unsigned in_place, + unsigned depth_cnn_padding, unsigned depth_cnn_kernel_size, + const void* depth_cnn_params, unsigned depth_cnn_stride, unsigned depth_cnn_activation, + unsigned point_cnn_hidden, unsigned point_cnn_padding, unsigned point_cnn_kernel_size, + const void* point_cnn_params, unsigned point_cnn_stride, unsigned point_cnn_activation, + unsigned pool_padding, unsigned pool_kernel_size, unsigned pool_stride, unsigned pool_activation) { + + // Activation + + float* act_out= (float*)malloc(in_time * (in_channels >> 1) * sizeof(float)); + semi_sigmoid_tanh(act_out, input_signal, in_time, in_channels); + + in_channels >>= 1; + float* depth_out; + unsigned out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1; + if (in_place) { + // Norm + batchnorm1d(0, act_out, + in_time, in_channels, + mean, var, + affine_config, gamma, beta, + in_place, 0.00001); + // Depth CNN + depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); + conv1d(depth_out, out_time, 0, act_out, + in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, + depth_cnn_params, depth_cnn_stride, depth_cnn_activation); + free(act_out); + } + else { + // Norm + float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); + batchnorm1d(norm_out, act_out, + in_time, in_channels, + mean, var, + affine_config, gamma, beta, + in_place, 0.00001); + free(act_out); + // Depth CNN + depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); + conv1d(depth_out, out_time, 0, norm_out, + in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, + depth_cnn_params, depth_cnn_stride, depth_cnn_activation); + free(norm_out); + } + + // Point CNN + in_time = out_time; + out_time = in_time - point_cnn_kernel_size + 2 * point_cnn_padding + 1; + float* point_out = (float*)malloc(out_time * point_cnn_hidden * sizeof(float)); + point_cnn(point_out, out_time, point_cnn_hidden, depth_out, + in_time, in_channels, point_cnn_padding, point_cnn_kernel_size, + point_cnn_params, point_cnn_stride, point_cnn_activation); + free(depth_out); + + // Pool + in_time = out_time; + out_time = in_time - pool_kernel_size + 2 * pool_padding + 1; + avgpool1d(output_signal, out_time, point_out, + in_time, point_cnn_hidden, + pool_padding, pool_kernel_size, pool_stride, pool_activation); + free(point_out); + return 0; +} diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c new file mode 100644 index 000000000..041ae8f05 --- /dev/null +++ b/c_reference/src/rnn_bricked.c @@ -0,0 +1,303 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include "rnn_bricked.h" +#include "utils.h" + +// Forward Pass +int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_first_brick) { + + // Buffers and params + const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; + + unsigned rnn_assign_offset = rnn_hidden, out_index = 0; + unsigned num_bricks = (in_time - window) / hop + 1; + // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden + // This function only processes the forward context + if (bi_direction) { + rnn_assign_offset <<= 1; + } + + // Compute W1 * W2 * X + float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); + float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); + float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims, + tparams->wRank, in_dims, in_dims, + tempLR, tparams->block_size_w_to_lr); + transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank, + rnn_hidden, tparams->wRank, tparams->wRank, + inputMulW, tparams->block_size_w_from_lr); + free(tempLR); + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch + // memset is used. Hence, malloc can be used here for matMul result initialization + tempLR = (float*)malloc(num_bricks * tparams->uRank * sizeof(float)); + for (unsigned t = 0; t < window; t++) { + // From higher dims to lower dims + memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); + transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, + tparams->uRank, rnn_hidden, rnn_hidden, + tempLR, tparams->block_size_u_to_lr); + // From lower dims to higher dims + // Add Wx with Uh + // The tiled MatMuls are codes such that they yield result += matA * matB + // Hence we use calloc and memset to equate the result to 0 + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + float* preComp_offset = (float*)preComp; + for (unsigned n = 0; n < num_bricks; n++) { + float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden %= 4; + while (len_unroll--) { + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + } + #endif + + while (hidden--) { + *preComp_offset++ = *inputMulW_offset++; + } + } + transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank, + rnn_hidden, tparams->uRank, tparams->uRank, + preComp, tparams->block_size_u_from_lr); + + // Apply the gating + float* hiddenState_offset = (float*)hiddenState; + preComp_offset = (float*)preComp; + unsigned bricks = num_bricks; + while (bricks--) { + float* gateBias = (float*)tparams->Bg; + float* hiddenBias = (float*)tparams->Bh; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden = rnn_hidden % 4; + float gate, update; + while (len_unroll--) { + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + #endif + + while (hidden--) { + float gate = sigmoid((*preComp_offset) + (*gateBias++)); + float update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + } + // Sample first block if necessary + if (sample_first_brick) { + if (t % hop == 0) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState, rnn_hidden * sizeof(float)); + } + } + } + if (bi_direction) { + // If bi-directional then a gap would need to be left for the backward outputs + float* hiddenState_offset = hiddenState; + for (unsigned n = 0; n < num_bricks; n++) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState_offset, rnn_hidden * sizeof(float)); + hiddenState_offset += rnn_hidden; + } + } + else { + // If only forward is needed, the the whole block of memory can be copied without the loop + memcpy(output_signal + out_index * rnn_assign_offset, + hiddenState, num_bricks * rnn_hidden * sizeof(float)); + } + free(hiddenState); + free(inputMulW); + free(preComp); + free(tempLR); + return 0; +} + +// Backward Pass +int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, + float* input_signal, unsigned in_time, unsigned in_dims, + unsigned window, unsigned hop, const void* params, + unsigned bi_direction, unsigned sample_last_brick) { + + // Buffers and params + const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; + + unsigned rnn_assign_offset = rnn_hidden; + unsigned num_bricks = (in_time - window) / hop + 1; + unsigned out_index = in_time / hop; // = out_time - 1; + // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden + // This function only processes the forward context + if (bi_direction) { + rnn_assign_offset <<= 1; + } + + // Compute W1 * W2 * X + float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); + float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); + float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + float* preComp = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); + transposed_tiledMatMul(input_signal, tparams->W1, in_time, in_dims, + tparams->wRank, in_dims, in_dims, + tempLR, tparams->block_size_w_to_lr); + transposed_tiledMatMul(tempLR, tparams->W2, in_time, tparams->wRank, + rnn_hidden, tparams->wRank, tparams->wRank, + inputMulW, tparams->block_size_w_from_lr); + free(tempLR); + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch + tempLR = (float*)calloc(num_bricks * tparams->uRank, sizeof(float)); + for (int t = window - 1; t >= 0; t--) { + // From higher dims to lower dims + memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); + transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, + tparams->uRank, rnn_hidden, rnn_hidden, + tempLR, tparams->block_size_u_to_lr); + // From lower dims to higher dims + // Add Wx with Uh + // The tiled MatMuls are codes such that they yield result += matA * matB + // Hence we use calloc and memset to equate the result to 0 + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + float* preComp_offset = (float*)preComp; + for (unsigned n = 0; n < num_bricks; n++) { + float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden %= 4; + while (len_unroll--) { + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + *preComp_offset++ = *inputMulW_offset++; + } + #endif + + while (hidden--) { + *preComp_offset++ = *inputMulW_offset++; + } + } + transposed_tiledMatMul(tempLR, tparams->U2, num_bricks, tparams->uRank, + rnn_hidden, tparams->uRank, tparams->uRank, + preComp, tparams->block_size_u_from_lr); + + // Apply the gating + float* hiddenState_offset = (float*)hiddenState; + preComp_offset = (float*)preComp; + unsigned bricks = num_bricks; + while (bricks--) { + float* gateBias = (float*)tparams->Bg; + float* hiddenBias = (float*)tparams->Bh; + unsigned hidden = rnn_hidden; + + #ifdef LOOP_UNROLL + unsigned len_unroll = hidden >> 2; + hidden = rnn_hidden % 4; + float gate, update; + while (len_unroll--) { + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + gate = sigmoid((*preComp_offset) + (*gateBias++)); + update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + #endif + + while (hidden--) { + float gate = sigmoid((*preComp_offset) + (*gateBias++)); + float update = tanh((*preComp_offset++) + (*hiddenBias++)); + *hiddenState_offset = gate * (*hiddenState_offset) + + (tparams->sigmoid_zeta * (1.0 - gate) + + tparams->sigmoid_nu) * update; + hiddenState_offset++; + } + } + // Sample first block if necessary + if (sample_last_brick) { + if ((window - 1 - t) % hop == 0) { + // Iterate over the output in reverse + memcpy(output_signal + (out_index--) * rnn_assign_offset, + hiddenState + (num_bricks - 1) * rnn_hidden, rnn_hidden * sizeof(float)); + } + } + } + // Since the all first (final in reverse) hiddenstates are calculated, we assign the whole block + out_index = 0; + if (bi_direction) { + // If bi-directional then a gap would need to be left for the backward outputs + float* hiddenState_offset = hiddenState; + for (unsigned n = 0; n < num_bricks; n++) { + memcpy(output_signal + (out_index++) * rnn_assign_offset, + hiddenState_offset, rnn_hidden * sizeof(float)); + hiddenState_offset += rnn_hidden; + } + } + else { + // If only forward is needed, the the whole block of memory can be copied without the loop + memcpy(output_signal + out_index * rnn_assign_offset, + hiddenState, num_bricks * rnn_hidden * sizeof(float)); + } + free(hiddenState); + free(inputMulW); + free(preComp); + free(tempLR); + return 0; +} diff --git a/c_reference/src/utils.c b/c_reference/src/utils.c index dc58a5fa0..0373d0c0b 100644 --- a/c_reference/src/utils.c +++ b/c_reference/src/utils.c @@ -71,6 +71,128 @@ void matVec(const float* const mat, const float* const vec, } } +void offset_matVec_conv1d(const float* mat, const float* vec, + unsigned nrows, unsigned ncols, + unsigned row_stride, unsigned vec_stride, + unsigned depthwise, float* ret) { + + while (nrows--) { + // For depthwise, the vec(input) pointer is updated + // Since each row of the mat corresponds to a separate channel index + float* vec_offset = depthwise ? (float*)vec++ : (float*)vec; + float* mat_offset = (float*)mat; + float sum = 0.0f; + unsigned cols = ncols; + + #ifdef LOOP_UNROLL + unsigned len_unroll = cols >> 2; + cols %= 4; // ncols % 4 + while (len_unroll--) { + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + } + #endif + + while (cols--) { + sum += (*mat_offset++) * (*vec_offset); + vec_offset += vec_stride; + } + *ret++ = sum; + mat += row_stride; + } +} + +void tiledMatMul_float(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_cols_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + comm * total_cols_B + block_col; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4 + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset); + matB_offset += ncols; + } + *ret_offset++ += sum; + } + } + } + } + } +} + +void transposed_tiledMatMul(const float* const matA, const float* const matB, + unsigned nrows, unsigned ncommon, unsigned ncols, + unsigned total_comm_A, unsigned total_comm_B, + float* const ret, unsigned block_size) { + for (unsigned row = 0; row < nrows; row += block_size) { + unsigned row_block_size = (row + block_size < nrows) ? block_size : nrows - row; + for (unsigned col = 0; col < ncols; col += block_size) { + unsigned col_block_size = (col + block_size < ncols) ? block_size : ncols - col; + for (unsigned comm = 0; comm < ncommon; comm += block_size) { + unsigned comm_block_size = (comm + block_size < ncommon) ? block_size : ncommon - comm; + for (unsigned block_row = row; block_row < row + row_block_size; block_row++) { + float *ret_offset = (float *)ret + block_row * ncols + col; + for (unsigned block_col = col; block_col < col + col_block_size; block_col++) { + float sum = 0; + unsigned temp_block_size = comm_block_size; + const float *matA_offset = (const float*)matA + block_row * total_comm_A + comm; + const float *matB_offset = (const float*)matB + block_col * total_comm_B + comm; + + #ifdef LOOP_UNROLL + unsigned len_unroll = temp_block_size >> 2; + temp_block_size %= 4; // comm_block_size % 4 + while (len_unroll--) { + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + sum += (*matA_offset++) * (*matB_offset++); + } + #endif + + while (temp_block_size--) { + sum += (*matA_offset++) * (*matB_offset++); + } + *ret_offset++ += sum; + } + } + } + } + } +} + void v_add(float scalar1, const float* const vec1, float scalar2, const float* const vec2, unsigned len, float* const ret) { @@ -120,3 +242,34 @@ void softmax(const float* const input, unsigned len, float* const ret) { for (unsigned i = 0; i < len; i++) ret[i] = expf(input[i] - offset); } + +void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, + unsigned in_time, unsigned in_channels) { + unsigned time_step = 0; // used to avoid index multiplication + while (in_time--) { + unsigned pivot = in_channels >> 1; + float* input_sigmoid_offset = (float*)input_signal + time_step; + float* input_tanh_offset = (float*)input_signal + time_step + pivot; + + #ifdef LOOP_UNROLL + unsigned len_unroll = pivot >> 2; + pivot %= 4; + while (len_unroll--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + } + #endif + + while (pivot--) { + *output_signal++ = sigmoid(*input_sigmoid_offset++) * + tanh(*input_tanh_offset++); + } + time_step += in_channels; + } +} diff --git a/c_reference/tests/Makefile b/c_reference/tests/Makefile index 08f418286..4eb8c7d70 100644 --- a/c_reference/tests/Makefile +++ b/c_reference/tests/Makefile @@ -8,7 +8,11 @@ MODEL_DIR=../models SRC_DIR=../src IFLAGS = -I $(INCLUDE_DIR) -I $(MODEL_DIR) -all: test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse +all: test_fastgrnn_lr test_conv1d test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse test_rnn_bricked test_phoneme_det_cnn_rnn + +CONV1D_DIR=conv1d +test_conv1d: $(CONV1D_DIR)/test_conv1d.c $(SRC_DIR)/conv1d.o $(SRC_DIR)/utils.o + $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm FASTGRNN_DIR=fastgrnn test_fastgrnn_lr: $(FASTGRNN_DIR)/test_fastgrnn_lr.c $(SRC_DIR)/utils.o $(SRC_DIR)/fastgrnn.o $(SRC_DIR)/classifier.o @@ -38,10 +42,18 @@ test_quantized_face_detection_fast: $(FACE_DETECTION_DIR)/test_quantized_face_de test_quantized_face_detection_sparse: $(FACE_DETECTION_DIR)/test_quantized_face_detection_sparse.c $(SRC_DIR)/quantized_utils.o $(SRC_DIR)/quantized_fastgrnn.o $(SRC_DIR)/quantized_rnnpool.o $(SRC_DIR)/quantized_mbconv.o $(MODEL_DIR)/quantized_face_detection_sparse.o $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -Wno-unused-result -lm +RNNBRICKED_DIR=rnn_bricked +test_rnn_bricked: $(RNNBRICKED_DIR)/test_rnn_bricked.c $(SRC_DIR)/utils.o $(SRC_DIR)/rnn_bricked.o + $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm + +KWS_DIR=kws +test_phoneme_det_cnn_rnn: $(KWS_DIR)/test_phoneme_det_cnn_rnn.c $(SRC_DIR)/utils.o $(SRC_DIR)/conv1d.o $(SRC_DIR)/dscnn.o $(SRC_DIR)/rnn_bricked.o + $(CC) -o $@ $^ $(IFLAGS) $(CFLAGS) -lm + .PHONY: clean cleanest clean: - rm -f *.o *.gch test_fastgrnn_lr test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse + rm -f *.o *.gch test_fastgrnn_lr test_conv1d test_rnnpool test_quantized_utils test_quantized_fastgrnn test_quantized_rnnpool test_quantized_mbconv test_quantized_face_detection test_quantized_face_detection_fast test_quantized_face_detection_sparse test_rnn_bricked test_phoneme_det_cnn_rnn cleanest: clean rm *~ diff --git a/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h b/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h new file mode 100644 index 000000000..e9b3f68da --- /dev/null +++ b/c_reference/tests/conv1d/conv1d_depthwise/conv_param_depth.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d705c8b29a9eaf7255e15fb96314cc5b541d13e8a44494921fa0d00fbe46beee +size 39066 diff --git a/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h b/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h new file mode 100644 index 000000000..c936bb204 --- /dev/null +++ b/c_reference/tests/conv1d/conv1d_lr/conv_param_lr.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64e7cbd963bfe285df54cac1484b62560ddc8e9a0384392f033b92d3f3b3df1b +size 70398 diff --git a/c_reference/tests/conv1d/conv1d_regular/conv_param.h b/c_reference/tests/conv1d/conv1d_regular/conv_param.h new file mode 100644 index 000000000..6f2ca1edc --- /dev/null +++ b/c_reference/tests/conv1d/conv1d_regular/conv_param.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8de43d4e289ee507ee629a7a15adc4d914ab4f03f37eaf15bb1ecb8ccb97671c +size 108492 diff --git a/c_reference/tests/conv1d/test_conv1d.c b/c_reference/tests/conv1d/test_conv1d.c new file mode 100644 index 000000000..189b4f257 --- /dev/null +++ b/c_reference/tests/conv1d/test_conv1d.c @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include "conv1d.h" +#include "utils.h" + +#include "./conv1d_regular/conv_param.h" +#include "./conv1d_depthwise/conv_param_depth.h" +#include "./conv1d_lr/conv_param_lr.h" + +// Error Check +void errorCheck(float* pred, float* label, unsigned out_time, int out_features) { + float error = 0, denom = 0; + for (unsigned t = 0; t < out_time; t++) { + for (unsigned d = 0; d < out_features; d++) { + error += ((pred[t * out_features + d] - label[t * out_features + d]) + * (pred[t * out_features + d] - label[t * out_features + d])); + denom += label[t * out_features + d] * label[t * out_features + d]; + } + } + float avg_error = error / (out_time * out_features), rmse = error / denom; + printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse); +} + +void conv1d_check() { + ConvLayers_Params conv_params = { + .W = CONV1D_CONV_WEIGHT, + .B = CONV1D_CONV_BIAS, + .depthwise = 0, + }; + + float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); + conv1d(pred, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES, CONV1D_INPUT, + CONV1D_IN_TIME, CONV1D_IN_FEATURES, CONV1D_PAD, CONV1D_FILT, + &conv_params, CONV1D_STRIDE, CONV1D_ACT); + + printf("Testing Regular Convolution\n"); + errorCheck(pred, CONV1D_OUTPUT, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES); + free(pred); +} + +void conv1d_parallel_check() { + ConvLayers_Parallel_Params conv_params = { + .W = CONV1D_CONV_WEIGHT, + .B = CONV1D_CONV_BIAS, + .block_size = 100, + }; + + float* pred = (float*)malloc(CONV1D_OUT_TIME * CONV1D_OUT_FEATURES * sizeof(float)); + conv1d_parallel(pred, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES, CONV1D_INPUT, + CONV1D_IN_TIME, CONV1D_IN_FEATURES, CONV1D_PAD, CONV1D_FILT, + &conv_params, CONV1D_STRIDE, CONV1D_ACT); + + printf("Testing Parallel Convolution\n"); + errorCheck(pred, CONV1D_OUTPUT, CONV1D_OUT_TIME, CONV1D_OUT_FEATURES); + free(pred); +} + +void conv1d_depth_check() { + ConvLayers_Params conv_params = { + .W = CONV1D_DEPTH_CONV_WEIGHT, + .B = CONV1D_DEPTH_CONV_BIAS, + .depthwise = 1, + }; + + float* pred = (float*)malloc(CONV1D_DEPTH_OUT_TIME * CONV1D_DEPTH_OUT_FEATURES + * sizeof(float)); + conv1d(pred, CONV1D_DEPTH_OUT_TIME, 0, CONV1D_DEPTH_INPUT, + CONV1D_DEPTH_IN_TIME, CONV1D_DEPTH_IN_FEATURES, CONV1D_DEPTH_PAD, CONV1D_DEPTH_FILT, + &conv_params, CONV1D_DEPTH_STRIDE, CONV1D_DEPTH_ACT); + + printf("Testing Depthwise Convolution\n"); + errorCheck(pred, CONV1D_DEPTH_OUTPUT, + CONV1D_DEPTH_OUT_TIME, CONV1D_DEPTH_OUT_FEATURES); + free(pred); +} + +void conv1d_lr_check() { + ConvLayers_LR_Params conv_params = { + .W1 = CONV1D_LR_CONV_W1, + .W2 = CONV1D_LR_CONV_W2, + .B = CONV1D_LR_CONV_BIAS, + .rank = CONV1D_LR_LOW_RANK + }; + + float* pred = (float*)malloc(CONV1D_LR_OUT_TIME + * CONV1D_LR_OUT_FEATURES * sizeof(float)); + conv1d_lr(pred, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES, CONV1D_LR_INPUT, + CONV1D_LR_IN_TIME, CONV1D_LR_IN_FEATURES, CONV1D_LR_PAD, CONV1D_LR_FILT, + &conv_params, CONV1D_LR_STRIDE, CONV1D_LR_ACT); + + printf("Testing Low-Rank Convolution\n"); + errorCheck(pred, CONV1D_LR_OUTPUT, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES); + free(pred); +} + +void conv1d_lr_parallel_check() { + ConvLayers_LR_Parallel_Params conv_params = { + .W1 = CONV1D_LR_CONV_W1, + .W2 = CONV1D_LR_CONV_W2, + .B = CONV1D_LR_CONV_BIAS, + .rank = CONV1D_LR_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + float* pred = (float*)malloc(CONV1D_LR_OUT_TIME + * CONV1D_LR_OUT_FEATURES * sizeof(float)); + conv1d_lr_parallel(pred, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES, CONV1D_LR_INPUT, + CONV1D_LR_IN_TIME, CONV1D_LR_IN_FEATURES, CONV1D_LR_PAD, CONV1D_LR_FILT, + &conv_params, CONV1D_LR_STRIDE, CONV1D_LR_ACT); + + printf("Testing Low-Rank Parallel Convolution\n"); + errorCheck(pred, CONV1D_LR_OUTPUT, CONV1D_LR_OUT_TIME, CONV1D_LR_OUT_FEATURES); + free(pred); +} + +int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif + conv1d_check(); + conv1d_parallel_check(); + conv1d_lr_check(); + conv1d_depth_check(); + conv1d_lr_parallel_check(); + return 0; +} diff --git a/c_reference/tests/kws/keyword_spotting_io_1.h b/c_reference/tests/kws/keyword_spotting_io_1.h new file mode 100644 index 000000000..18517f20e --- /dev/null +++ b/c_reference/tests/kws/keyword_spotting_io_1.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1faf461aaadd548a9c9c6a8b3b62552e1cc41d268bbb6a5f2b3abf5d9e0bc575 +size 326949 diff --git a/c_reference/tests/kws/keyword_spotting_io_2.h b/c_reference/tests/kws/keyword_spotting_io_2.h new file mode 100644 index 000000000..293d4b379 --- /dev/null +++ b/c_reference/tests/kws/keyword_spotting_io_2.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:804e9c5d4053f61b993486957a8e55c2e9577f52f6297a180b0430da73290709 +size 306578 diff --git a/c_reference/tests/kws/keyword_spotting_io_3.h b/c_reference/tests/kws/keyword_spotting_io_3.h new file mode 100644 index 000000000..f6efbb000 --- /dev/null +++ b/c_reference/tests/kws/keyword_spotting_io_3.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccf1ba4ea4c53597f0c60f9a60e0031ceb45446d0c314944a1a9b6f286f1a921 +size 308150 diff --git a/c_reference/tests/kws/postcnn_params.h b/c_reference/tests/kws/postcnn_params.h new file mode 100644 index 000000000..9da921d22 --- /dev/null +++ b/c_reference/tests/kws/postcnn_params.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0adca3c3658d5860193ae73f62468742dbc2c1f8b6508abdb2efa9311111165d +size 1374545 diff --git a/c_reference/tests/kws/precnn_params.h b/c_reference/tests/kws/precnn_params.h new file mode 100644 index 000000000..fb1539736 --- /dev/null +++ b/c_reference/tests/kws/precnn_params.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79383c92269713907e1247ec39232f8b98e95850637ec5a1cce2a310ebf3e469 +size 520803 diff --git a/c_reference/tests/kws/rnn_params.h b/c_reference/tests/kws/rnn_params.h new file mode 100644 index 000000000..72a581918 --- /dev/null +++ b/c_reference/tests/kws/rnn_params.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77b63ebcde61cdec096e0be6d83d60c679c63a7608047317672c10d3e478e5fe +size 1302881 diff --git a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c new file mode 100644 index 000000000..ff204c8ef --- /dev/null +++ b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include "conv1d.h" +#include "dscnn.h" +#include "utils.h" +#include "rnn_bricked.h" + +#include "keyword_spotting_io_2.h" +#include "precnn_params.h" +#include "rnn_params.h" +#include "postcnn_params.h" + +// Check number of output time-steps with the number of label time-steps +int checkTime(unsigned out_time) { + if (out_time != KWS_OUT_TIME) { + printf("Error, estimated and actual ouput time-steps mismatch"); + return 1; + } + return 0; +} +// Error Check +void checkError(float* pred, float* label) { + float error = 0, denom = 0; + for (unsigned t = 0; t < KWS_OUT_TIME; t++) { + for (unsigned d = 0; d < POST_CNN_OUT_FEATURES; d++) { + error += ((pred[t * POST_CNN_OUT_FEATURES + d] + - label[t * POST_CNN_OUT_FEATURES + d]) + * (pred[t * POST_CNN_OUT_FEATURES + d] + - label[t * POST_CNN_OUT_FEATURES + d])); + denom += label[t * POST_CNN_OUT_FEATURES + d] + * label[t * POST_CNN_OUT_FEATURES + d]; + } + } + printf("Full Network\n"); + printf("Agg Squared Error : %f\n", error); + printf("MSE : %f\n", error / (KWS_OUT_TIME*POST_CNN_OUT_FEATURES)); + printf("RMSE : %f\n", error / denom); +} + +/* CNN-RNN based Phoneme Detection Model + + The phoneme detection model used consists of 6 blocks. + 1st block is a CNN, where kernel size is 5 and regular tanh activation + 2nd block is an RNN, which has a specified forward and a backward context running at a stride/hop of 3. + Hence it reduces the sequence length by a factor of 3. + Rest of the blocks(3rd, 4th, 5th and 6th) are a combination of CNNs + Each of the final 4 blocks consist of a depth cnn (kernel size of 5) and a point cnn (kernel size of 1) + + Input to the architecture is of the form (seq_len, feature_dim) where feature dim refers to n_mels (number of mel features/number of features from the featurizer). + Output is of the form (seq_len/3, 41) where 41 is the number of phonemes over which the classification is performed. + Phonemes are predicted for every 3rd time frame, operating under the assumption that they don't vary faster than that. + + NOTE: Before deployment for real-time streaming applications, we would need to make minor modification + These changes are subject to the input specs i.e fixing input buffer time steps, number of features from the deployed featurizer, method of reading the input into a buffer +*/ +void phoneme_prediction(float* mem_buf) { + ConvLayers_LR_Parallel_Params conv_params = { + .W1 = CNN1_W1, + .W2 = CNN1_W2, + .B = CNN1_BIAS, + .rank = PRE_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_2 = { + .W = CNN2_DEPTH_W, + .B = CNN2_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_2 = { + .W1 = CNN2_POINT_W1, + .W2 = CNN2_POINT_W2, + .B = CNN2_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_3 = { + .W = CNN3_DEPTH_W, + .B = CNN3_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_3 = { + .W1 = CNN3_POINT_W1, + .W2 = CNN3_POINT_W2, + .B = CNN3_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_4 = { + .W = CNN4_DEPTH_W, + .B = CNN4_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_4 = { + .W1 = CNN4_POINT_W1, + .W2 = CNN4_POINT_W2, + .B = CNN4_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + ConvLayers_Params depth_param_5 = { + .W = CNN5_DEPTH_W, + .B = CNN5_DEPTH_BIAS, + .depthwise = 1, + }; + + ConvLayers_LR_Parallel_Params point_param_5 = { + .W1 = CNN5_POINT_W1, + .W2 = CNN5_POINT_W2, + .B = CNN5_POINT_BIAS, + .rank = POST_CNN_LOW_RANK, + .block_size_to_lr = 100, + .block_size_from_lr = 100, + }; + + BrickedFastGRNN_LR_Params bwd_RNN_params = { + .W1 = B_W1, + .W2 = B_W2, + .wRank = RNN_LOW_RANK, + .U1 = B_U1, + .U2 = B_U2, + .uRank = RNN_LOW_RANK, + .Bg = B_BIAS_GATE, + .Bh = B_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(B_ZETA), + .sigmoid_nu = sigmoid(B_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + BrickedFastGRNN_LR_Params fwd_RNN_params = { + .W1 = F_W1, + .W2 = F_W2, + .wRank = RNN_LOW_RANK, + .U1 = F_U1, + .U2 = F_U2, + .uRank = RNN_LOW_RANK, + .Bg = F_BIAS_GATE, + .Bh = F_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(F_ZETA), + .sigmoid_nu = sigmoid(F_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + unsigned in_time, out_time; + + /* Pre-CNN */ + in_time = KWS_IN_TIME; + out_time = in_time - PRE_CNN_FILT + (PRE_CNN_FILT_PAD << 1) + 1; + float* cnn1_out = (float*)malloc(out_time * PRE_CNN_OUT_FEATURES * sizeof(float)); + // Since batchnorm1d is the first layer and in-place will alter the input. + // Use the in-place computation only if the input can be discarded/altered. Else avoid in-place computation for this layer + phon_pred_lr_cnn(cnn1_out, mem_buf, + conv1d_lr_parallel, in_time, PRE_CNN_IN_FEATURES, + 0, 0, PRE_CNN_BNORM_AFFINE, CNN1_SCALE, CNN1_OFFSET, PRE_CNN_BNORM_INPLACE, + PRE_CNN_OUT_FEATURES, PRE_CNN_FILT_PAD, PRE_CNN_FILT, + &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // regular tanh activation + + batchnorm1d(0, cnn1_out, in_time, RNN_IN_FEATURES, + 0, 0, RNN_BNORM_AFFINE, RNN_SCALE, RNN_OFFSET, 1, 0.00001); + + /* Bricked Bi-FastGRNN Block */ + out_time = in_time/RNN_HOP + 1; + float* rnn_out = (float*)malloc(out_time * RNN_OUT_FEATURES * sizeof(float)); + forward_bricked_fastgrnn_lr(rnn_out, RNN_OUT_FEATURES >> 1, cnn1_out, + in_time, RNN_IN_FEATURES, RNN_FWD_WINDOW, RNN_HOP, + &fwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_FIRST_BRICK); + + backward_bricked_fastgrnn_lr(rnn_out + (RNN_OUT_FEATURES >> 1), + RNN_OUT_FEATURES >> 1, cnn1_out, + in_time, RNN_IN_FEATURES, RNN_BWD_WINDOW, RNN_HOP, + &bwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_LAST_BRICK); + free(cnn1_out); + + /* Post-CNN */ + // Since all inputs to the subsequent layers are temporary, in-place batchnorm1d can be used without any input(initial buffer)/output(final layer) data alteration/corruption + // CNN2 + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn2_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn2_out, rnn_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN2_SCALE, CNN2_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_2, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_2, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(rnn_out); + + // CNN3 + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn3_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn3_out, cnn2_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN3_SCALE, CNN3_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_3, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_3, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn2_out); + + // CNN4 + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* cnn4_out = (float*)malloc(out_time * POST_CNN_INTER_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(cnn4_out, cnn3_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN4_SCALE, CNN4_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_4, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_INTER_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_4, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn3_out); + + // CNN5 + in_time = out_time; + out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; + out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; + float* pred = (float*)malloc(out_time * POST_CNN_OUT_FEATURES * sizeof(float)); + phon_pred_depth_point_lr_cnn(pred, cnn4_out, + conv1d_lr_parallel, in_time, POST_CNN_INTER_FEATURES, + 0, 0, POST_CNN_BNORM_AFFINE, CNN5_SCALE, CNN5_OFFSET, POST_CNN_BNORM_INPLACE, + POST_CNN_DEPTH_PAD, POST_CNN_DEPTH_FILT, + &depth_param_5, POST_CNN_DEPTH_STRIDE, POST_CNN_DEPTH_ACT, + POST_CNN_OUT_FEATURES, POST_CNN_POINT_PAD, POST_CNN_POINT_FILT, + &point_param_5, POST_CNN_POINT_STRIDE, POST_CNN_POINT_ACT, + POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); + free(cnn4_out); + + /* Output Time and Prediction Check. Created for Debugging */ + if (checkTime(out_time)) + return; + else + checkError(pred, OUTPUT); + free(pred); +} + +int main() { + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif + clock_t begin = clock(); + phoneme_prediction(INPUT); + clock_t end = clock(); + double time_spent = (float)(end - begin) / CLOCKS_PER_SEC; + printf("Time elapsed is %f seconds\n", time_spent); + return 0; +} diff --git a/c_reference/tests/rnn_bricked/rnn_bricked_io.h b/c_reference/tests/rnn_bricked/rnn_bricked_io.h new file mode 100644 index 000000000..a6d90e301 --- /dev/null +++ b/c_reference/tests/rnn_bricked/rnn_bricked_io.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0fc3181b35c0cfa858a5ff6415f4ad793915b2fce5c792e741a8f8e04b095349 +size 1908996 diff --git a/c_reference/tests/rnn_bricked/rnn_params.h b/c_reference/tests/rnn_bricked/rnn_params.h new file mode 100644 index 000000000..17060301a --- /dev/null +++ b/c_reference/tests/rnn_bricked/rnn_params.h @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f062722fe7b91fbb5f78af631d4c594d2b2af61b8dfdd266b799d39c133afa2 +size 1293672 diff --git a/c_reference/tests/rnn_bricked/test_rnn_bricked.c b/c_reference/tests/rnn_bricked/test_rnn_bricked.c new file mode 100644 index 000000000..701d73af4 --- /dev/null +++ b/c_reference/tests/rnn_bricked/test_rnn_bricked.c @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include "rnn_bricked.h" +#include "utils.h" + +#include "rnn_params.h" +#include "rnn_bricked_io.h" + +int main() { + + BrickedFastGRNN_LR_Params bwd_RNN_params = { + .W1 = B_W1, + .W2 = B_W2, + .wRank = RNN_LOW_RANK, + .U1 = B_U1, + .U2 = B_U2, + .uRank = RNN_LOW_RANK, + .Bg = B_BIAS_GATE, + .Bh = B_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(B_ZETA), + .sigmoid_nu = sigmoid(B_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + BrickedFastGRNN_LR_Params fwd_RNN_params = { + .W1 = F_W1, + .W2 = F_W2, + .wRank = RNN_LOW_RANK, + .U1 = F_U1, + .U2 = F_U2, + .uRank = RNN_LOW_RANK, + .Bg = F_BIAS_GATE, + .Bh = F_BIAS_UPDATE, + .sigmoid_zeta = sigmoid(F_ZETA), + .sigmoid_nu = sigmoid(F_NU), + .block_size_u_from_lr = 100, + .block_size_u_to_lr = 100, + .block_size_w_from_lr = 100, + .block_size_w_to_lr = 100, + }; + + float* pred = (float*)malloc(RNN_OUT_TIME * RNN_OUT_FEATURES * sizeof(float)); + + forward_bricked_fastgrnn_lr(pred, RNN_OUT_FEATURES >> 1, INPUT, + RNN_IN_TIME, RNN_IN_FEATURES, FWD_WINDOW, HOP, + &fwd_RNN_params, 1, 1); + + backward_bricked_fastgrnn_lr(pred + (RNN_OUT_FEATURES >> 1), RNN_OUT_FEATURES >> 1, INPUT, + RNN_IN_TIME, RNN_IN_FEATURES, BWD_WINDOW, HOP, + &bwd_RNN_params, 1, 1); + + float error = 0; + float denom = 0; + for (int t = 0; t < RNN_OUT_TIME; t++) { + for (int d = 0; d < RNN_OUT_FEATURES; d++) { + error += ((pred[t * RNN_OUT_FEATURES + d] - OUTPUT[t * RNN_OUT_FEATURES + d]) + * (pred[t * RNN_OUT_FEATURES + d] - OUTPUT[t * RNN_OUT_FEATURES + d])); + denom += OUTPUT[t * RNN_OUT_FEATURES + d] * OUTPUT[t * RNN_OUT_FEATURES + d]; + } + } + float avg_error = error / (RNN_OUT_TIME * RNN_OUT_FEATURES); + float rmse = error / denom; + + #ifdef LOOP_UNROLL + printf("Loop Unrolling Active\n"); + #endif + printf("Testing Bricked RNNs Bi-Directional\n"); + printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse); + free(pred); + return 0; +} From 43b1786a3110baf1ee49b78638b6c4ab3d62172b Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Tue, 12 Oct 2021 13:55:16 -0700 Subject: [PATCH 2/4] Add punctuations and explain rmse --- c_reference/include/conv1d.h | 56 ++--- c_reference/include/rnn_bricked.h | 38 ++-- c_reference/src/conv1d.c | 198 +++++++++--------- c_reference/src/dscnn.c | 23 +- c_reference/src/rnn_bricked.c | 70 +++---- c_reference/src/utils.c | 10 +- c_reference/tests/conv1d/test_conv1d.c | 2 + .../tests/kws/test_phoneme_det_cnn_rnn.c | 30 +-- .../tests/rnn_bricked/test_rnn_bricked.c | 2 + 9 files changed, 217 insertions(+), 212 deletions(-) diff --git a/c_reference/include/conv1d.h b/c_reference/include/conv1d.h index a7ed49315..2547d8a03 100644 --- a/c_reference/include/conv1d.h +++ b/c_reference/include/conv1d.h @@ -6,31 +6,31 @@ /* All the matrices/tensors are stored in the row major format - NOTES for the conv layers + NOTES for the conv layers. -> The conv1d & conv1d_lr layers work for all cases and can be used unconstrained. - There are no hard constraints for the parallel version, but a few points regarding its optimal usage are given below --> Dilation = 1 (no dilation) for all cases --> For the non-depthwise cases, store the matrices as described below. Permutation might be necessary --> The low-rank decomposition cannot be applied to the depthwise weight matrices. This is due to the out_channels/in_channels = 0 constarint imposed by the depthwise convolution. - For full-rank this is satisfied since out_channels = in_channels - But, when the matrix is decomposed, the constarint is violated (since rank < out_channels ; rank is not divisible by in_channels) - Hence due to the decomposition being theoretically impossible, we have not provided the support - However we suggest a less-efficient alternative => First pre-compute the weights W = W2 * W1 and then use a regular conv --> For the parallel cases, the non-overlapping cases of the convolution are computed parallelly using MatMul (since the blocked MatMul is faster) - This howver is only valid for when the filter is fully in the input. There would be no-overlapping for the edge cases - Hence the MatVec code(regular code) is used to calculate these cases + There are no hard constraints for the parallel version, but a few points regarding its optimal usage are given below. +-> Dilation = 1 (no dilation) for all cases. +-> For the non-depthwise cases, store the matrices as described below. Permutation might be necessary. +-> The low-rank decomposition cannot be applied to the depthwise weight matrices. This is due to the out_channels/in_channels = 0 constarint imposed by the depthwise convolution. + For full-rank this is satisfied since out_channels = in_channels. + But, when the matrix is decomposed, the constarint is violated (since rank < out_channels ; rank is not divisible by in_channels). + Hence due to the decomposition being theoretically impossible, we have not provided the support. + However we suggest a less-efficient alternative => First pre-compute the weights W = W2 * W1 and then use a regular conv. +-> For the parallel cases, the non-overlapping cases of the convolution are computed parallelly using MatMul (since the blocked MatMul is faster). + This howver is only valid for when the filter is fully in the input. There would be no-overlapping for the edge cases. + Hence the MatVec code(regular code) is used to calculate these cases. - Important points regarding parallel versions --> Due to the above reason, the parallel layers is only recommended for large in_time inputs - This should typically be for in_time (without the padding) > 2 * num_steps_one_row + stride. Else there would not be enough time-steps to efficiently parallelise - We need at least 2 rows for a good a MatMul performace. In the worst case the starting time step would be (stride - 1). Hence we choose 2 * num_steps_one_row + stride as the threshold - For the short input cases, the code will skip the MatMul computation and use MatVec instead (but the MatMul-variable computation overhead would remain) - For such cases, the MatVec code (conv1d and conv1d_lr) would work more efficiently due to the lower RAM usage and lack of any major overheads --> There is no support for depthwise for conv1d_parallel - The regular convolution acts on all the channels while the depthwise acts only on one channel at a time - This results in a non-contiguos memory access. MatMul would need to process multiple such time-steps, while the MatVec would only need to process one - Hence, the MatVec would be able to enter the next channel earlier and would work much faster - While the MatMul would have cache misses (when dealing with the small chache size of edge devices) + Important points regarding parallel versions. +-> Due to the above reason, the parallel layers is only recommended for large in_time inputs. + This should typically be for in_time (without the padding) > 2 * num_steps_one_row + stride. Else there would not be enough time-steps to efficiently parallelise. + We need at least 2 rows for a good a MatMul performace. In the worst case the starting time step would be (stride - 1). Hence we choose 2 * num_steps_one_row + stride as the threshold. + For the short input cases, the code will skip the MatMul computation and use MatVec instead (but the MatMul-variable computation overhead would remain). + For such cases, the MatVec code (conv1d and conv1d_lr) would work more efficiently due to the lower RAM usage and lack of any major overheads. +-> There is no support for depthwise for conv1d_parallel. + The regular convolution acts on all the channels while the depthwise acts only on one channel at a time. + This results in a non-contiguos memory access. MatMul would need to process multiple such time-steps, while the MatVec would only need to process one. + Hence, the MatVec would be able to enter the next channel earlier and would work much faster. + While the MatMul would have cache misses (when dealing with the small chache size of edge devices). */ /** @@ -54,7 +54,7 @@ typedef struct ConvLayers_Params { * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed. + * @param[in] padding padding applied to the input before the conv is performed * Note: padding is applied to both the starting and ending of the input, along the time axis * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter @@ -91,7 +91,7 @@ typedef struct ConvLayers_Parallel_Params { * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed. + * @param[in] padding padding applied to the input before the conv is performed * Note: padding is applied to both the starting and ending of the input, along the time axis * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter @@ -131,7 +131,7 @@ typedef struct ConvLayers_LR_Params { * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed. + * @param[in] padding padding applied to the input before the conv is performed * Note: padding is applied to both the starting and ending of the input, along the time axis * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter @@ -175,7 +175,7 @@ typedef struct ConvLayers_LR_Parallel_Params { * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed. + * @param[in] padding padding applied to the input before the conv is performed * Note: padding is applied to both the starting and ending of the input, along the time axis * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the conv filter @@ -201,7 +201,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha * @param[in] input_signal pointer to the input signal. size = in_time * in_channels * @param[in] in_time number of time steps in the input * @param[in] in_channels number of input channels. The output will have the same number of channels - * @param[in] padding padding applied to the input before the conv is performed. + * @param[in] padding padding applied to the input before the conv is performed * Note: padding is applied to both the starting and ending of the input, along the time axis * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) * @param[in] kernel_size kernel size of the pool filter diff --git a/c_reference/include/rnn_bricked.h b/c_reference/include/rnn_bricked.h index adc910d42..a1f4bb696 100644 --- a/c_reference/include/rnn_bricked.h +++ b/c_reference/include/rnn_bricked.h @@ -4,30 +4,30 @@ #ifndef __RNN_BRICKED_H__ #define __RNN_BRICKED_H__ -/* All the matrices are stored in the row major format +/* All the matrices are stored in the row major format. - NOTES for using the layers --> Single-directional Computation - While using the bricked fastgrnn layers, the user needs to adhered to the two following constraints - 1) in_time % hop = 0 - 2) fwd_window % hop = 0 and bwd_window % hop = 0 + NOTES for using the layers. +-> Single-directional Computation. + While using the bricked fastgrnn layers, the user needs to adhered to the two following constraints. + 1) in_time % hop = 0. + 2) fwd_window % hop = 0 and bwd_window % hop = 0. - Violation of the above two constraints (1 & 2), will cause segmentation faults - The layers first compute all the Wx steps and then compute Uh for all the windows parallelly - Hence, the user needs to adhered to the constraints 1 & 2 + Violation of the above two constraints (1 & 2), will cause segmentation faults. + The layers first compute all the Wx steps and then compute Uh for all the windows parallelly. + Hence, the user needs to adhered to the constraints 1 & 2. --> Bi-directional Computation - For bi-directional cases, there are 2 additionally constraints that would need to be followed - A) sample_first_brick and sample_last_brick = 1 - B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call +-> Bi-directional Computation. + For bi-directional cases, there are 2 additionally constraints that would need to be followed. + A) sample_first_brick and sample_last_brick = 1. + B) An offset of rnn_hidden would need to be given to the output_signal pointer during the backward function call. Each function will only process its given context(forward/backward). The other context will need to be called separately. - E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...) - 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...) + E.g : 1st step -> forward(output, ..., input, ..., bi-direction=1, ...). + 2nd step -> backward(output + rnn_hidden, ..., input, ..., bi-direction=1, ...). - The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used - Violating the conditions would cause index mis-matches or data corruption - If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result - If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only) + The two extra constraints (A & B) are only for bi-directional cases and can be ignored if only forward (or only backward) is used. + Violating the conditions would cause index mis-matches or data corruption. + If the first (last) brick is not sampled, the first few (last few) time steps would be missing in the forward (backward) result . + If the offset is not passed during the backward function call, the backward pass will overwrite the forward result (bi-directional case only). */ /** diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c index 2ab5b7f30..f918fdae0 100644 --- a/c_reference/src/conv1d.c +++ b/c_reference/src/conv1d.c @@ -14,62 +14,62 @@ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, const ConvLayers_LR_Params* tparams= (ConvLayers_LR_Params*) params; - // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding. unsigned rank = tparams->rank; - // Buffer for W2 out + // Buffer for W2 out. float* temp_rank_out = (float*)malloc(rank * sizeof(float)); - // Buffer for W1 out + // Buffer for W1 out. float* temp_out = (float*)malloc(out_channels * sizeof(float)); for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { unsigned t_index = t_out * out_channels; if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { - // Filter fully inside the input. Kept as the initial condition, since this is the most common one + // Filter fully inside the input. Kept as the initial condition, since this is the most common one. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, kernel_size * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if ((t_in_start < padding) && (t_in_end >= padding)) { - // Filter partially entered the input + // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, input_signal, rank, (t_in_end - padding + 1) * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { - // Filter partially exited the input + // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, (in_time + padding - t_in_start) * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else { - // Filter completely in the padding region - // The filter is either fully outside the input or has not yet entered the input + // Filter completely in the padding region. + // The filter is either fully outside the input or has not yet entered the input. memset(output_signal + t_index, 0, out_channels * sizeof(float)); } for (unsigned co = 0; co < out_channels; co++) { - // Post-Conv activation. More activation functions can be added should the necessity arise + // Post-Conv activation. More activation functions can be added should the necessity arise. switch (activation) { case 1 : output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + @@ -103,7 +103,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha const void* params, unsigned stride, unsigned activation) { unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; - // Calculate the number of time steps in one row for the first non-overlapping instance + // Calculate the number of time steps in one row for the first non-overlapping instance. while (num_steps_one_row < kernel_size) { num_steps_one_row += stride; num_iter++; @@ -111,10 +111,10 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha unsigned total_in_cols = num_steps_one_row * in_channels; const ConvLayers_LR_Parallel_Params* tparams = (ConvLayers_LR_Parallel_Params*) params; - // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding. // Buffer to hold the output. For corner cases, this will be realtively big. // But will be needed for the central condition (filter inside input). - // If there are not enough time steps to linearise into one row, then allocate only 1 time step + // If there are not enough time steps to linearise into one row, then allocate only 1 time step. unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? in_time / num_steps_one_row : 1; unsigned rank = tparams->rank; @@ -123,48 +123,48 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha // Buffer for W1 out float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); - unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here. for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; t_in_start < padding && t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { if (t_in_end < padding) { - // Filter outside the input region and in the padded region + // Filter outside the input region and in the padded region. memset(output_signal + t_out * out_channels, 0, out_channels * sizeof(float)); } - else { //(t_in_end >= padding) - // Filter partially entered the input - // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + else { //(t_in_end >= padding). + // Filter partially entered the input. + // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, input_signal, rank, (t_in_end - padding + 1) * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_out * out_channels, temp_out, out_channels * sizeof(float)); } } - // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases - // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region - // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row - // Using the above logic, we can convert the MatVec opeartion into a MatMul operation - // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied) - // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria - // The MatVec will be used for the computation in-case the following block is skipped + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases. + // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region. + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row. + // Using the above logic, we can convert the MatVec opeartion into a MatMul operation. + // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied). + // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria. + // The MatVec will be used for the computation in-case the following block is skipped. if (in_time > ((num_steps_one_row << 1) + stride)) { - t_in_start -= padding; // remove the padding offset temporarily - t_in_end -= padding; // Used to keep track of the final processed index + t_in_start -= padding; // remove the padding offset temporarily. + t_in_end -= padding; // Used to keep track of the final processed index. for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels); iter++, t_in_start += stride, t_out++) { unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; memset(temp_rank_out, 0, buffer_steps * rank * sizeof(float)); memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { - // t_in_end is used to find the furthest time step was used in the MatMul calculation - // This value will be used for calculating the index for the final section of the processing + // t_in_end is used to find the furthest time step was used in the MatMul calculation. + // This value will be used for calculating the index for the final section of the processing. t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; } transposed_tiledMatMul(input_signal + t_in_start * in_channels , tparams->W2, @@ -173,7 +173,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha transposed_tiledMatMul(temp_rank_out , tparams->W1, in_rows, rank, out_channels, rank, rank, temp_out, tparams->block_size_from_lr); - // Copy all the data into the output + // Copy all the data into the output. float* output_offset = (float*)output_signal + t_out * out_channels; float* temp_offset = (float*)temp_out; unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; @@ -183,44 +183,44 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha temp_offset += out_channels; } } - // Initialize the time iterators - // Use the stored value in t_in_end to calculate the iterators - t_in_start = t_in_end + padding; // Add the padding and stride offsets again + // Initialize the time iterators. + // Use the stored value in t_in_end to calculate the iterators. + t_in_start = t_in_end + padding; // Add the padding and stride offsets again. t_in_end = t_in_start + kernel_size - 1; t_out = t_in_start / stride; } for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { // Filter fully in the input but very close to the edges. - // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped - // Incase the MatMul is skipped, this block will be used to compute the results + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped. + // Incase the MatMul is skipped, this block will be used to compute the results. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, kernel_size * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_out * out_channels, temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { - // Filter partially exited the input + // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, (in_time + padding - t_in_start) * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); - // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling) + // row_stride = ncols, vec_stride = 1, depthwise = 0. Hence the call is identical to a regular MatVec (without scaling). offset_matVec_conv1d(tparams->W1, temp_rank_out, out_channels, rank, rank, 1, 0, temp_out); memcpy(output_signal + t_out * out_channels, temp_out, out_channels * sizeof(float)); } else { - // Filter completely outside the input and in the padding region + // Filter completely outside the input and in the padding region. memset(output_signal + t_out * out_channels, 0, out_channels * sizeof(float)); } @@ -229,7 +229,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha for (t_out = 0; t_out < out_time; t_out++) { unsigned t_index = t_out * out_channels; for (unsigned co = 0; co < out_channels; co++) { - // Post-Conv activation. More activation functions can be added should the necessity arise + // Post-Conv activation. More activation functions can be added should the necessity arise. switch (activation) { case 1 : output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + @@ -269,14 +269,14 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, cols_scale = 1; } - // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding. float* temp_out = (float*)malloc(out_channels * sizeof(float)); for (unsigned t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { unsigned t_index = t_out * out_channels; if ((t_in_start >= padding) && (t_in_end < (in_time + padding))) { - // Filter fully inside the input. Kept as the initial condition, since this is the most common one + // Filter fully inside the input. Kept as the initial condition, since this is the most common one. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, kernel_size * cols_scale, @@ -284,20 +284,20 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if ((t_in_start < padding) && (t_in_end >= padding)) { - // Filter partially entered the input + // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W + (padding - t_in_start) * cols_scale, input_signal, out_channels, (t_in_end - padding + 1) * cols_scale, kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { - // Filter partially exited the input + // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, (in_time + padding - t_in_start) * cols_scale, @@ -305,12 +305,12 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, memcpy(output_signal + t_index, temp_out, out_channels * sizeof(float)); } else { - // Filter completely in the padding region - // The filter is either fully outside the input or has not yet entered the input + // Filter completely in the padding region. + // The filter is either fully outside the input or has not yet entered the input. memset(output_signal + t_index, 0, out_channels * sizeof(float)); } for (unsigned co = 0; co < out_channels; co++) { - // Post-Conv activation. More activation functions can be added should the necessity arise + // Post-Conv activation. More activation functions can be added should the necessity arise. switch (activation) { case 1 : output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + @@ -343,7 +343,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe const void* params, unsigned stride, unsigned activation) { unsigned ncols = kernel_size * in_channels, num_iter = 0, num_steps_one_row = 0; - // Calculate the number of time steps in one row for the first non-overlapping instance + // Calculate the number of time steps in one row for the first non-overlapping instance. while (num_steps_one_row < kernel_size) { num_steps_one_row += stride; num_iter++; @@ -351,27 +351,27 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe unsigned total_in_cols = num_steps_one_row * in_channels; const ConvLayers_Parallel_Params* tparams = (ConvLayers_Parallel_Params*) params; - // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding + // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding. // Buffer to hold the output. For corner cases, this will be realtively big. // But will be needed for the central condition (filter inside input). - // If there are not enough time steps to linearise into one row, then allocate only 1 time step + // If there are not enough time steps to linearise into one row, then allocate only 1 time step. unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? in_time / num_steps_one_row : 1; float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); - unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here + unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here. for (t_in_start = 0, t_in_end = kernel_size - 1, t_out = 0; t_in_start < padding && t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { if (t_in_end < padding) { - // Filter outside the input region and in the padded region + // Filter outside the input region and in the padded region. memset(output_signal + t_out * out_channels, 0, out_channels * sizeof(float)); } - else { //(t_in_end >= padding) - // Filter partially entered the input + else { //(t_in_end >= padding). + // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels, input_signal, out_channels, (t_in_end - padding + 1) * in_channels, ncols, 1, 0, temp_out); @@ -379,29 +379,29 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe temp_out, out_channels * sizeof(float)); } } - // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases - // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region - // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row - // Using the above logic, we can convert the MatVec opeartion into a MatMul operation - // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied) - // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria - // The MatVec will be used for the computation in-case the following block is skipped + // The main part => the filter is fully inside the input. We can think of the non-overlapping cases as parallel cases. + // Each of the iterations are for the kernel striding to the next point till the filter is out of the overlapping region. + // Hence we use the num_steps_one_row for calculating the number of time steps to be linearized in one row. + // Using the above logic, we can convert the MatVec opeartion into a MatMul operation. + // Ideally both implementation would be the same. However for edge devices the matMul was found to be faster matVec (both tilied). + // Skip if atleast 2 rows cannot be formed. The condition 2 * num_steps_one_row + stride is the worst case criteria. + // The MatVec will be used for the computation in-case the following block is skipped. if (in_time > ((num_steps_one_row << 1) + stride)) { - t_in_start -= padding; // remove the padding offset temporarily - t_in_end -= padding; // Used to keep track of the final processed index + t_in_start -= padding; // remove the padding offset temporarily. + t_in_end -= padding; // Used to keep track of the final processed index. for (unsigned iter = 0; (iter < num_iter) && (t_out < out_channels); iter++, t_in_start += stride, t_out++) { unsigned in_rows = (in_time - t_in_start) / num_steps_one_row; memset(temp_out, 0, buffer_steps * out_channels * sizeof(float)); if (t_in_end < (t_in_start + ((in_rows - 1) * num_steps_one_row))) { - // t_in_end is used to find the furthest time step was used in the MatMul calculation - // This value will be used for calculating the index for the final section of the processing + // t_in_end is used to find the furthest time step was used in the MatMul calculation. + // This value will be used for calculating the index for the final section of the processing. t_in_end = ((in_rows - 1) * num_steps_one_row) + t_in_start + stride; } transposed_tiledMatMul(input_signal + t_in_start * in_channels , tparams->W, in_rows, ncols, out_channels, total_in_cols, ncols, temp_out, tparams->block_size); - // Copy all the data into the output + // Copy all the data into the output. float* output_offset = (float*)output_signal + t_out * out_channels; float* temp_offset = (float*)temp_out; unsigned t_iter = in_rows, offset_factor_for_out = num_iter * out_channels; @@ -411,17 +411,17 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe temp_offset += out_channels; } } - // Initialize the time iterators - // Use the stored value in t_in_end to calculate the iterators - t_in_start = t_in_end + padding; // Add the padding and stride offsets again + // Initialize the time iterators. + // Use the stored value in t_in_end to calculate the iterators. + t_in_start = t_in_end + padding; // Add the padding and stride offsets again. t_in_end = t_in_start + kernel_size - 1; t_out = t_in_start / stride; } for (; t_out < out_time; t_out++, t_in_start += stride, t_in_end += stride) { if (t_in_start < (in_time + padding) && (t_in_end < (in_time + padding))) { // Filter fully in the input but very close to the edges. - // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped - // Incase the MatMul is skipped, this block will be used to compute the results + // Due to the num_steps_one_row divisibility usage in the parallel step, some computations would be skipped. + // Incase the MatMul is skipped, this block will be used to compute the results. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, kernel_size * in_channels, @@ -430,10 +430,10 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe temp_out, out_channels * sizeof(float)); } else if (t_in_start < (in_time + padding) && (t_in_end >= (in_time + padding))) { - // Filter partially exited the input + // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. - // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels) - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix + // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). + // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, (in_time + padding - t_in_start) * in_channels, @@ -442,7 +442,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe temp_out, out_channels * sizeof(float)); } else { - // Filter completely outside the input and in the padding region + // Filter completely outside the input and in the padding region. memset(output_signal + t_out * out_channels, 0, out_channels * sizeof(float)); } @@ -451,7 +451,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe for (t_out = 0; t_out < out_time; t_out++) { unsigned t_index = t_out * out_channels; for (unsigned co = 0; co < out_channels; co++) { - // Post-Conv activation. More activation functions can be added should the necessity arise + // Post-Conv activation. More activation functions can be added should the necessity arise. switch (activation) { case 1 : output_signal[t_index + co] = sigmoid(output_signal[t_index + co] + @@ -482,8 +482,8 @@ int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation) { - // Iterate over the time steps and average them - float scale = 1.0/(float)kernel_size; // To avoid divisions + // Iterate over the time steps and average them. + float scale = 1.0/(float)kernel_size; // To avoid divisions. for (unsigned t_in = 0, t_out = 0; t_out < out_time; t_out++, t_in += stride) { for (unsigned ci = 0; ci < in_channels; ci++) { float sum = 0; @@ -524,10 +524,10 @@ int batchnorm1d(float* output_signal, float* input_signal, unsigned in_place, float eps) { float* ret = in_place ? (float*)input_signal : (float*)output_signal; - // Check for affine_config - // = 1 ; Use gamma, beta, mean and var - // = 2 ; Use only gamma and beta - // = 3 ; Use only mean and var + // Check for affine_config. + // = 1 ; Use gamma, beta, mean and var. + // = 2 ; Use only gamma and beta. + // = 3 ; Use only mean and var. if (affine_config == 1) { while (in_time--) { float* gamma_offset = (float*)gamma; diff --git a/c_reference/src/dscnn.c b/c_reference/src/dscnn.c index a304ff54f..ef245837a 100644 --- a/c_reference/src/dscnn.c +++ b/c_reference/src/dscnn.c @@ -17,24 +17,24 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal, unsigned out_time = in_time - cnn_kernel_size + 2 * cnn_padding + 1; if (in_place) { - // BatchNorm + // BatchNorm. batchnorm1d(0, input_signal, in_time, in_channels, mean, var, affine_config, gamma, beta, in_place, 0.00001); - // CNN + // CNN. cnn(output_signal, out_time, cnn_hidden, input_signal, in_time, in_channels, cnn_padding, cnn_kernel_size, cnn_params, cnn_stride, cnn_activation); } else { - // BatchNorm + // BatchNorm. float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); batchnorm1d(norm_out, input_signal, in_time, in_channels, mean, var, affine_config, gamma, beta, in_place, 0.00001); - // CNN + // CNN. cnn(output_signal, out_time, cnn_hidden, norm_out, in_time, in_channels, cnn_padding, cnn_kernel_size, cnn_params, cnn_stride, cnn_activation); @@ -53,8 +53,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, const void* point_cnn_params, unsigned point_cnn_stride, unsigned point_cnn_activation, unsigned pool_padding, unsigned pool_kernel_size, unsigned pool_stride, unsigned pool_activation) { - // Activation - + // Activation. float* act_out= (float*)malloc(in_time * (in_channels >> 1) * sizeof(float)); semi_sigmoid_tanh(act_out, input_signal, in_time, in_channels); @@ -62,13 +61,13 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, float* depth_out; unsigned out_time = in_time - depth_cnn_kernel_size + 2 * depth_cnn_padding + 1; if (in_place) { - // Norm + // Norm. batchnorm1d(0, act_out, in_time, in_channels, mean, var, affine_config, gamma, beta, in_place, 0.00001); - // Depth CNN + // Depth CNN. depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); conv1d(depth_out, out_time, 0, act_out, in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, @@ -76,7 +75,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, free(act_out); } else { - // Norm + // Norm. float* norm_out = (float*)malloc(in_time * in_channels * sizeof(float)); batchnorm1d(norm_out, act_out, in_time, in_channels, @@ -84,7 +83,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, affine_config, gamma, beta, in_place, 0.00001); free(act_out); - // Depth CNN + // Depth CNN. depth_out = (float*)malloc(out_time * in_channels * sizeof(float)); conv1d(depth_out, out_time, 0, norm_out, in_time, in_channels, depth_cnn_padding, depth_cnn_kernel_size, @@ -92,7 +91,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, free(norm_out); } - // Point CNN + // Point CNN. in_time = out_time; out_time = in_time - point_cnn_kernel_size + 2 * point_cnn_padding + 1; float* point_out = (float*)malloc(out_time * point_cnn_hidden * sizeof(float)); @@ -101,7 +100,7 @@ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, point_cnn_params, point_cnn_stride, point_cnn_activation); free(depth_out); - // Pool + // Pool. in_time = out_time; out_time = in_time - pool_kernel_size + 2 * pool_padding + 1; avgpool1d(output_signal, out_time, point_out, diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c index 041ae8f05..e2da02995 100644 --- a/c_reference/src/rnn_bricked.c +++ b/c_reference/src/rnn_bricked.c @@ -7,24 +7,24 @@ #include "rnn_bricked.h" #include "utils.h" -// Forward Pass +// Forward Pass. int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, float* input_signal, unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, const void* params, unsigned bi_direction, unsigned sample_first_brick) { - // Buffers and params + // Buffers and params. const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; unsigned rnn_assign_offset = rnn_hidden, out_index = 0; unsigned num_bricks = (in_time - window) / hop + 1; - // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden - // This function only processes the forward context + // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden. + // This function only processes the forward context. if (bi_direction) { rnn_assign_offset <<= 1; } - // Compute W1 * W2 * X + // Compute W1 * W2 * X. float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); @@ -36,20 +36,20 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, rnn_hidden, tparams->wRank, tparams->wRank, inputMulW, tparams->block_size_w_from_lr); free(tempLR); - // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch - // memset is used. Hence, malloc can be used here for matMul result initialization + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch. + // memset is used. Hence, malloc can be used here for matMul result initialization. tempLR = (float*)malloc(num_bricks * tparams->uRank * sizeof(float)); for (unsigned t = 0; t < window; t++) { - // From higher dims to lower dims + // From higher dims to lower dims. memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, tparams->uRank, rnn_hidden, rnn_hidden, tempLR, tparams->block_size_u_to_lr); - // From lower dims to higher dims - // Add Wx with Uh - // The tiled MatMuls are codes such that they yield result += matA * matB - // Hence we use calloc and memset to equate the result to 0 - // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + // From lower dims to higher dims. + // Add Wx with Uh. + // The tiled MatMuls are codes such that they yield result += matA * matB. + // Hence we use calloc and memset to equate the result to 0. + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input. float* preComp_offset = (float*)preComp; for (unsigned n = 0; n < num_bricks; n++) { float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; @@ -74,7 +74,7 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, rnn_hidden, tparams->uRank, tparams->uRank, preComp, tparams->block_size_u_from_lr); - // Apply the gating + // Apply the gating. float* hiddenState_offset = (float*)hiddenState; preComp_offset = (float*)preComp; unsigned bricks = num_bricks; @@ -124,7 +124,7 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, hiddenState_offset++; } } - // Sample first block if necessary + // Sample first block if necessary. if (sample_first_brick) { if (t % hop == 0) { memcpy(output_signal + (out_index++) * rnn_assign_offset, @@ -133,7 +133,7 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, } } if (bi_direction) { - // If bi-directional then a gap would need to be left for the backward outputs + // If bi-directional then a gap would need to be left for the backward outputs. float* hiddenState_offset = hiddenState; for (unsigned n = 0; n < num_bricks; n++) { memcpy(output_signal + (out_index++) * rnn_assign_offset, @@ -142,7 +142,7 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, } } else { - // If only forward is needed, the the whole block of memory can be copied without the loop + // If only forward is needed, the the whole block of memory can be copied without the loop. memcpy(output_signal + out_index * rnn_assign_offset, hiddenState, num_bricks * rnn_hidden * sizeof(float)); } @@ -153,25 +153,25 @@ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, return 0; } -// Backward Pass +// Backward Pass. int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, float* input_signal, unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, const void* params, unsigned bi_direction, unsigned sample_last_brick) { - // Buffers and params + // Buffers and params. const BrickedFastGRNN_LR_Params* tparams = (const BrickedFastGRNN_LR_Params*)params; unsigned rnn_assign_offset = rnn_hidden; unsigned num_bricks = (in_time - window) / hop + 1; unsigned out_index = in_time / hop; // = out_time - 1; - // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden - // This function only processes the forward context + // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden. + // This function only processes the forward context. if (bi_direction) { rnn_assign_offset <<= 1; } - // Compute W1 * W2 * X + // Compute W1 * W2 * X. float* inputMulW = (float*)calloc(in_time * rnn_hidden, sizeof(float)); float* tempLR = (float*)calloc(in_time * tparams->wRank, sizeof(float)); float* hiddenState = (float*)calloc(num_bricks * rnn_hidden, sizeof(float)); @@ -183,19 +183,19 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, rnn_hidden, tparams->wRank, tparams->wRank, inputMulW, tparams->block_size_w_from_lr); free(tempLR); - // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch + // We can reuse the low-rank buffer from Wx to Uh, since Wx is computed at one stretch. tempLR = (float*)calloc(num_bricks * tparams->uRank, sizeof(float)); for (int t = window - 1; t >= 0; t--) { - // From higher dims to lower dims + // From higher dims to lower dims. memset(tempLR, 0, num_bricks * tparams->uRank * sizeof(float)); transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, tparams->uRank, rnn_hidden, rnn_hidden, tempLR, tparams->block_size_u_to_lr); - // From lower dims to higher dims - // Add Wx with Uh - // The tiled MatMuls are codes such that they yield result += matA * matB - // Hence we use calloc and memset to equate the result to 0 - // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input + // From lower dims to higher dims. + // Add Wx with Uh. + // The tiled MatMuls are codes such that they yield result += matA * matB. + // Hence we use calloc and memset to equate the result to 0. + // But since we want Wx + Uh, we can store Wx and use the MatMul to add the result over the input. float* preComp_offset = (float*)preComp; for (unsigned n = 0; n < num_bricks; n++) { float* inputMulW_offset = (float*)inputMulW + (n * hop + t) * rnn_hidden; @@ -220,7 +220,7 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, rnn_hidden, tparams->uRank, tparams->uRank, preComp, tparams->block_size_u_from_lr); - // Apply the gating + // Apply the gating. float* hiddenState_offset = (float*)hiddenState; preComp_offset = (float*)preComp; unsigned bricks = num_bricks; @@ -270,19 +270,19 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, hiddenState_offset++; } } - // Sample first block if necessary + // Sample first block if necessary. if (sample_last_brick) { if ((window - 1 - t) % hop == 0) { - // Iterate over the output in reverse + // Iterate over the output in reverse. memcpy(output_signal + (out_index--) * rnn_assign_offset, hiddenState + (num_bricks - 1) * rnn_hidden, rnn_hidden * sizeof(float)); } } } - // Since the all first (final in reverse) hiddenstates are calculated, we assign the whole block + // Since the all first (final in reverse) hiddenstates are calculated, we assign the whole block. out_index = 0; if (bi_direction) { - // If bi-directional then a gap would need to be left for the backward outputs + // If bi-directional then a gap would need to be left for the backward outputs. float* hiddenState_offset = hiddenState; for (unsigned n = 0; n < num_bricks; n++) { memcpy(output_signal + (out_index++) * rnn_assign_offset, @@ -291,7 +291,7 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, } } else { - // If only forward is needed, the the whole block of memory can be copied without the loop + // If only forward is needed, the the whole block of memory can be copied without the loop. memcpy(output_signal + out_index * rnn_assign_offset, hiddenState, num_bricks * rnn_hidden * sizeof(float)); } diff --git a/c_reference/src/utils.c b/c_reference/src/utils.c index 0373d0c0b..b9b9edaa3 100644 --- a/c_reference/src/utils.c +++ b/c_reference/src/utils.c @@ -77,8 +77,8 @@ void offset_matVec_conv1d(const float* mat, const float* vec, unsigned depthwise, float* ret) { while (nrows--) { - // For depthwise, the vec(input) pointer is updated - // Since each row of the mat corresponds to a separate channel index + // For depthwise, the vec(input) pointer is updated. + // Since each row of the mat corresponds to a separate channel index. float* vec_offset = depthwise ? (float*)vec++ : (float*)vec; float* mat_offset = (float*)mat; float sum = 0.0f; @@ -128,7 +128,7 @@ void tiledMatMul_float(const float* const matA, const float* const matB, #ifdef LOOP_UNROLL unsigned len_unroll = temp_block_size >> 2; - temp_block_size %= 4; // comm_block_size % 4 + temp_block_size %= 4; // comm_block_size % 4. while (len_unroll--) { sum += (*matA_offset++) * (*matB_offset); matB_offset += ncols; @@ -173,7 +173,7 @@ void transposed_tiledMatMul(const float* const matA, const float* const matB, #ifdef LOOP_UNROLL unsigned len_unroll = temp_block_size >> 2; - temp_block_size %= 4; // comm_block_size % 4 + temp_block_size %= 4; // comm_block_size % 4. while (len_unroll--) { sum += (*matA_offset++) * (*matB_offset++); sum += (*matA_offset++) * (*matB_offset++); @@ -245,7 +245,7 @@ void softmax(const float* const input, unsigned len, float* const ret) { void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, unsigned in_time, unsigned in_channels) { - unsigned time_step = 0; // used to avoid index multiplication + unsigned time_step = 0; // used to avoid index multiplication. while (in_time--) { unsigned pivot = in_channels >> 1; float* input_sigmoid_offset = (float*)input_signal + time_step; diff --git a/c_reference/tests/conv1d/test_conv1d.c b/c_reference/tests/conv1d/test_conv1d.c index 189b4f257..969ebb3bc 100644 --- a/c_reference/tests/conv1d/test_conv1d.c +++ b/c_reference/tests/conv1d/test_conv1d.c @@ -20,6 +20,8 @@ void errorCheck(float* pred, float* label, unsigned out_time, int out_features) denom += label[t * out_features + d] * label[t * out_features + d]; } } + // RMSE - Relative Mean Squared Error. + // The ratio of the Squared Error to the Squared Summation of the Signal. float avg_error = error / (out_time * out_features), rmse = error / denom; printf("Agg Squared Error: %f ; MSE: %f ; RMSE: %f\n", error, avg_error, rmse); } diff --git a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c index ff204c8ef..18104d0e8 100644 --- a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c +++ b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c @@ -15,7 +15,7 @@ #include "rnn_params.h" #include "postcnn_params.h" -// Check number of output time-steps with the number of label time-steps +// Check number of output time-steps with the number of label time-steps. int checkTime(unsigned out_time) { if (out_time != KWS_OUT_TIME) { printf("Error, estimated and actual ouput time-steps mismatch"); @@ -39,24 +39,26 @@ void checkError(float* pred, float* label) { printf("Full Network\n"); printf("Agg Squared Error : %f\n", error); printf("MSE : %f\n", error / (KWS_OUT_TIME*POST_CNN_OUT_FEATURES)); + // RMSE - Relative Mean Squared Error. + // The ratio of the Squared Error to the Squared Summation of the Signal. printf("RMSE : %f\n", error / denom); } -/* CNN-RNN based Phoneme Detection Model +/* CNN-RNN based Phoneme Detection Model. The phoneme detection model used consists of 6 blocks. - 1st block is a CNN, where kernel size is 5 and regular tanh activation + 1st block is a CNN, where kernel size is 5 and regular tanh activation. 2nd block is an RNN, which has a specified forward and a backward context running at a stride/hop of 3. Hence it reduces the sequence length by a factor of 3. - Rest of the blocks(3rd, 4th, 5th and 6th) are a combination of CNNs - Each of the final 4 blocks consist of a depth cnn (kernel size of 5) and a point cnn (kernel size of 1) + Rest of the blocks(3rd, 4th, 5th and 6th) are a combination of CNNs. + Each of the final 4 blocks consist of a depth cnn (kernel size of 5) and a point cnn (kernel size of 1). Input to the architecture is of the form (seq_len, feature_dim) where feature dim refers to n_mels (number of mel features/number of features from the featurizer). Output is of the form (seq_len/3, 41) where 41 is the number of phonemes over which the classification is performed. Phonemes are predicted for every 3rd time frame, operating under the assumption that they don't vary faster than that. - NOTE: Before deployment for real-time streaming applications, we would need to make minor modification - These changes are subject to the input specs i.e fixing input buffer time steps, number of features from the deployed featurizer, method of reading the input into a buffer + NOTE: Before deployment for real-time streaming applications, we would need to make minor modification. + These changes are subject to the input specs i.e fixing input buffer time steps, number of features from the deployed featurizer, method of reading the input into a buffer. */ void phoneme_prediction(float* mem_buf) { ConvLayers_LR_Parallel_Params conv_params = { @@ -169,12 +171,12 @@ void phoneme_prediction(float* mem_buf) { out_time = in_time - PRE_CNN_FILT + (PRE_CNN_FILT_PAD << 1) + 1; float* cnn1_out = (float*)malloc(out_time * PRE_CNN_OUT_FEATURES * sizeof(float)); // Since batchnorm1d is the first layer and in-place will alter the input. - // Use the in-place computation only if the input can be discarded/altered. Else avoid in-place computation for this layer + // Use the in-place computation only if the input can be discarded/altered. Else avoid in-place computation for this layer. phon_pred_lr_cnn(cnn1_out, mem_buf, conv1d_lr_parallel, in_time, PRE_CNN_IN_FEATURES, 0, 0, PRE_CNN_BNORM_AFFINE, CNN1_SCALE, CNN1_OFFSET, PRE_CNN_BNORM_INPLACE, PRE_CNN_OUT_FEATURES, PRE_CNN_FILT_PAD, PRE_CNN_FILT, - &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // regular tanh activation + &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // regular tanh activation. batchnorm1d(0, cnn1_out, in_time, RNN_IN_FEATURES, 0, 0, RNN_BNORM_AFFINE, RNN_SCALE, RNN_OFFSET, 1, 0.00001); @@ -193,8 +195,8 @@ void phoneme_prediction(float* mem_buf) { free(cnn1_out); /* Post-CNN */ - // Since all inputs to the subsequent layers are temporary, in-place batchnorm1d can be used without any input(initial buffer)/output(final layer) data alteration/corruption - // CNN2 + // Since all inputs to the subsequent layers are temporary, in-place batchnorm1d can be used without any input(initial buffer)/output(final layer) data alteration/corruption. + // CNN2. in_time = out_time; out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; @@ -209,7 +211,7 @@ void phoneme_prediction(float* mem_buf) { POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); free(rnn_out); - // CNN3 + // CNN3. in_time = out_time; out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; @@ -224,7 +226,7 @@ void phoneme_prediction(float* mem_buf) { POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); free(cnn2_out); - // CNN4 + // CNN4. in_time = out_time; out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; @@ -239,7 +241,7 @@ void phoneme_prediction(float* mem_buf) { POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); free(cnn3_out); - // CNN5 + // CNN5. in_time = out_time; out_time = in_time - POST_CNN_DEPTH_FILT + (POST_CNN_DEPTH_PAD << 1) + 1; out_time = out_time - POST_CNN_POOL + (POST_CNN_POOL_PAD << 1) + 1; diff --git a/c_reference/tests/rnn_bricked/test_rnn_bricked.c b/c_reference/tests/rnn_bricked/test_rnn_bricked.c index 701d73af4..b2f03696d 100644 --- a/c_reference/tests/rnn_bricked/test_rnn_bricked.c +++ b/c_reference/tests/rnn_bricked/test_rnn_bricked.c @@ -64,6 +64,8 @@ int main() { denom += OUTPUT[t * RNN_OUT_FEATURES + d] * OUTPUT[t * RNN_OUT_FEATURES + d]; } } + // RMSE - Relative Mean Squared Error. + // The ratio of the Squared Error to the Squared Summation of the Signal. float avg_error = error / (RNN_OUT_TIME * RNN_OUT_FEATURES); float rmse = error / denom; From 85943f877ed024244f92ab8948e67879aab64a8d Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Wed, 20 Oct 2021 17:06:40 -0700 Subject: [PATCH 3/4] Correcting the punctuations --- c_reference/include/conv1d.h | 246 +++++++++--------- c_reference/include/dscnn.h | 132 +++++----- c_reference/include/rnn_bricked.h | 72 ++--- c_reference/include/utils.h | 112 ++++---- c_reference/src/conv1d.c | 24 +- c_reference/src/rnn_bricked.c | 2 +- c_reference/src/utils.c | 4 +- c_reference/tests/conv1d/test_conv1d.c | 2 +- .../tests/kws/test_phoneme_det_cnn_rnn.c | 15 +- 9 files changed, 303 insertions(+), 306 deletions(-) diff --git a/c_reference/include/conv1d.h b/c_reference/include/conv1d.h index 2547d8a03..e92f78727 100644 --- a/c_reference/include/conv1d.h +++ b/c_reference/include/conv1d.h @@ -4,7 +4,7 @@ #ifndef __CONV1D_H__ #define __CONV1D_H__ -/* All the matrices/tensors are stored in the row major format +/* All the matrices/tensors are stored in the row major format. NOTES for the conv layers. -> The conv1d & conv1d_lr layers work for all cases and can be used unconstrained. @@ -34,10 +34,10 @@ */ /** - * @brief Model parameters for the 1D Convolution Layer - * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] - * @var B pointer to the bias vector, original shape = [out_channels] - * @var depthwise flag for deciding between regular(=0) and depthwise(=1) conv + * @brief Model parameters for the 1D Convolution Layer. + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1]. + * @var B pointer to the bias vector, original shape = [out_channels]. + * @var depthwise flag for deciding between regular(=0) and depthwise(=1) conv. */ typedef struct ConvLayers_Params { const float* const W; @@ -46,25 +46,25 @@ typedef struct ConvLayers_Params { } ConvLayers_Params; /** - * @brief Model definition for the 1D Convolution Layer. Currently only for dilation = 1 - * @param[out] output_signal pointer to the output signal, size = out_time * out_channels - * @param[in] out_time number of time steps in the output - * @param[in] out_channels number of output channels for the output of the conv layer - * NOTE: out_channels = in_channels for depthwise. This is set manually in the function - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed - * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) - * @param[in] kernel_size kernel size of the conv filter - * @param[in] params weights, bias and other essential parameters used to describe the layer - * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * @brief Model definition for the 1D Convolution Layer. Currently only for dilation = 1. + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * NOTE: out_channels = in_channels for depthwise. This is set manually in the function. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * NOTE: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, unsigned in_time, unsigned in_channels, @@ -72,10 +72,10 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, const void* params, unsigned stride, unsigned activation); /** - * @brief Model parameters for the 1D Parallel Convolution Layer - * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1] - * @var B pointer to the bias vector, original shape = [out_channels] - * @var block_size block/tile size for the cache. Used for tiled MatMul + * @brief Model parameters for the 1D Parallel Convolution Layer. + * @var W pointer to the flattened conv weights, original shape for regular = [out_channels, kernel_size, in_channels], shape for depthwise = [in_channels, kernel_size, 1]. + * @var B pointer to the bias vector, original shape = [out_channels]. + * @var block_size block/tile size for the cache. Used for tiled MatMul. */ typedef struct ConvLayers_Parallel_Params { const float* const W; @@ -85,23 +85,23 @@ typedef struct ConvLayers_Parallel_Params { /** * @brief Model definition for the 1D Parallel Convolution Layer. Currently only for dilation = 1. No depthwise. - * @param[out] output_signal pointer to the output signal, size = out_time * out_channels - * @param[in] out_time number of time steps in the output - * @param[in] out_channels number of output channels for the output of the conv layer - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed - * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) - * @param[in] kernel_size kernel size of the conv filter - * @param[in] params weights, bias and other essential parameters used to describe the layer - * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, unsigned in_time, unsigned in_channels, @@ -110,10 +110,10 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe /** * @brief Model parameters for the 1D Low Rank Convolution Layer. - * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels - * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1] - * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels] - * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels. + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1]. + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels]. + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage. */ typedef struct ConvLayers_LR_Params { const float* const W1; @@ -124,24 +124,24 @@ typedef struct ConvLayers_LR_Params { /** * @brief Model definition for the 1D Low-Rank Convolution Layer. Currently only for dilation = 1. - * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) - * @param[out] output_signal pointer to the output signal, size = out_time * out_channels - * @param[in] out_time number of time steps in the output - * @param[in] out_channels number of output channels for the output of the conv layer - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed - * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) - * @param[in] kernel_size kernel size of the conv filter - * @param[in] params weights, bias and other essential parameters used to describe the layer - * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels). + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, unsigned in_time, unsigned in_channels, @@ -150,12 +150,12 @@ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, /** * @brief Model parameters for the 1D Low Rank Parallel Convolution Layer. - * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels - * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1] - * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels] - * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage - * @var block_size_to_lr block/tile size for the cache. Used for tiled MatMul. Used for the input -> low-rank computation - * @var block_size_from_lr block/tile size for the cache. Used for tiled MatMul. Used for the low-rank -> output computation + * @var W1 pointer to the flattened 1st low-rank component of the weights, original shape = [out_channels, rank]. For depthwise out_channels = in_channels. + * @var W2 pointer to the flattened 2nd low-rank component of the weights, original shape for regular = [rank, kernel_size, in_channels], shape for depthwise = [rank, kernel_size, 1]. + * @var B pointer to the flattened bias vector for the convolution, original shape = [out_channels]. + * @var rank rank of the weight tensor. A low-rank decomposition typically used to reduce computation and storage. + * @var block_size_to_lr block/tile size for the cache. Used for tiled MatMul. Used for the input -> low-rank computation. + * @var block_size_from_lr block/tile size for the cache. Used for tiled MatMul. Used for the low-rank -> output computation. */ typedef struct ConvLayers_LR_Parallel_Params { const float* const W1; @@ -168,71 +168,71 @@ typedef struct ConvLayers_LR_Parallel_Params { /** * @brief Model definition for the 1D Low-Rank Parallel Convolution Layer. Currently only for dilation = 1. - * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels) - * @param[out] output_signal pointer to the output signal, size = out_time * out_channels - * @param[in] out_time number of time steps in the output - * @param[in] out_channels number of output channels for the output of the conv layer - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels - * @param[in] padding padding applied to the input before the conv is performed - * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) - * @param[in] kernel_size kernel size of the conv filter - * @param[in] params weights, bias and other essential parameters used to describe the layer - * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * @brief Low-Rank and depthwise are incompatible as the low-rank decomposition of the weight matrix violates the depthwise conditions (out_channels % groups = 0, where groups = in_channels). + * @param[out] output_signal pointer to the output signal, size = out_time * out_channels. + * @param[in] out_time number of time steps in the output. + * @param[in] out_channels number of output channels for the output of the conv layer. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the conv filter. + * @param[in] params weights, bias and other essential parameters used to describe the layer. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_channels, const float* input_signal, unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, const void* params, unsigned stride, unsigned activation); -// Auxiliary Layers +// Auxiliary Layers. /** - * @brief Model definition for the 1D Average Pooling Layer. Currently only for dilation = 1 - * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation - * NOTE: out_channels == in_channels for avgpool - * @param[in] out_time number of time steps in the output - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels. The output will have the same number of channels - * @param[in] padding padding applied to the input before the conv is performed - * Note: padding is applied to both the starting and ending of the input, along the time axis - * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1) - * @param[in] kernel_size kernel size of the pool filter - * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1 - * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity - * 0: none + * @brief Model definition for the 1D Average Pooling Layer. Currently only for dilation = 1. + * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation. + * NOTE: out_channels == in_channels for avgpool. + * @param[in] out_time number of time steps in the output. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. The output will have the same number of channels. + * @param[in] padding padding applied to the input before the conv is performed. + * Note: padding is applied to both the starting and ending of the input, along the time axis. + * E.g : padding = 3, the input is padded with zeros(for 3 time steps), both before the input_signal(time step 0) and after the input_signal(time step in_time-1). + * @param[in] kernel_size kernel size of the pool filter. + * @param[in] stride stride length for the layer. input_time_iterator += stride for output_time_iterator +=1. + * @param[in] activation an integer to choose the type of activation function. More can be added as per the necessity. + * 0: none. * 1: sigmoid - * 2: tanh - * 3: relu + * 2: tanh. + * 3: relu. */ int avgpool1d(float* output_signal, unsigned out_time, const float* input_signal, unsigned in_time, unsigned in_channels, unsigned padding, unsigned kernel_size, unsigned stride, unsigned activation); /** - * @brief Model definition for the 1D batch Normalization Layer - * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels. The output will have the same number of channels - * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0 - * @param[in] var pointer to the variance for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0 - * @param[in] affine_config whether the affine operations are applied - * if affine_config = 0, then only mean and var are used + * @brief Model definition for the 1D batch Normalization Layer. + * @param[out] output_signal pointer to the output signal, size = out_time * in_channels. Provide Null/0 in case of in-place computation. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. The output will have the same number of channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. if affine_config = 2, then pass a NULL/0. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. - * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) - * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed - * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0 - * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0 - * @param[in] in_place in-place computation of the batchnorm i.e. the output is stored in-place of the input signal. Storage efficient - * @param[in] eps a very small +ve value to avoid division by 0. For the default value, assign = 0.00001 + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Provide Null/0 if affine_config is 0. + * @param[in] in_place in-place computation of the batchnorm i.e. the output is stored in-place of the input signal. Storage efficient. + * @param[in] eps a very small +ve value to avoid division by 0. For the default value, assign = 0.00001. */ int batchnorm1d(float* output_signal, float* input_signal, unsigned in_time, unsigned in_channels, diff --git a/c_reference/include/dscnn.h b/c_reference/include/dscnn.h index 541923056..1833d0813 100644 --- a/c_reference/include/dscnn.h +++ b/c_reference/include/dscnn.h @@ -4,39 +4,39 @@ #ifndef __DSCNN_H__ #define __DSCNN_H__ -// Function pointer for the Conv layer to be passed as a parameter. (conv1d or conv1d_lr only) +// Function pointer for the Conv layer to be passed as a parameter. (conv1d or conv1d_lr only). typedef int (*conv_layer)(float*, unsigned, unsigned, const float*, unsigned, unsigned, unsigned, unsigned, const void*, unsigned, unsigned); /** - * @brief Model definition for the 1D Convolution block applied before the RNN - * @brief sub-layers : batchnorm1d -> conv1d_lr - * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params) - * @param[in] in_time number of time steps in the input_signal - * @param[in] in_channels number of input channels - * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 - * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 - * @param[in] affine_config whether the affine operations are applied - * if affine_config = 0, then only mean and var are used + * @brief Model definition for the 1D Convolution block applied before the RNN. + * @brief sub-layers : batchnorm1d -> conv1d_lr. + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] cnn function pointer for the CNN layer. (any of the conv layers can be passed with appropriate params). + * @param[in] in_time number of time steps in the input_signal. + * @param[in] in_channels number of input channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. - * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) - * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed - * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 - * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 - * @param[in] in_place in-place computation check for the batchnorm. Storage efficient - * @param[in] cnn_hidden hidden state/out_channels dimensions for the low-rank CNN. The final channel size of this block - * @param[in] cnn_padding padding for the low-rank CNN layer. Note: applied to both sides of the input - * @param[in] cnn_kernel_size kernel size of the low-rank CNN - * @param[in] cnn_params weights, bias and other essential parameters for the low-rank CNN - * @param[in] cnn_stride stride factor for the low-rank CNN + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] in_place in-place computation check for the batchnorm. Storage efficient. + * @param[in] cnn_hidden hidden state/out_channels dimensions for the low-rank CNN. The final channel size of this block. + * @param[in] cnn_padding padding for the low-rank CNN layer. Note: applied to both sides of the input. + * @param[in] cnn_kernel_size kernel size of the low-rank CNN. + * @param[in] cnn_params weights, bias and other essential parameters for the low-rank CNN. + * @param[in] cnn_stride stride factor for the low-rank CNN. * @param[in] cnn_activation an integer to choose the type of activation function. - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int phon_pred_lr_cnn(float* output_signal, float* input_signal, conv_layer cnn, unsigned in_time, unsigned in_channels, @@ -46,50 +46,50 @@ int phon_pred_lr_cnn(float* output_signal, float* input_signal, const void* cnn_params, unsigned cnn_stride, unsigned cnn_activation); /** - * @brief Model definition for the 1D Convolution block applied after the RNN - * @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d - * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers - * @param[in] input_signal pointer to the input signal. size = in_time * in_channels - * @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params) - * @param[in] in_time number of time steps in the input - * @param[in] in_channels number of input channels - * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 - * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2 - * @param[in] affine_config whether the affine operations are applied - * if affine_config = 0, then only mean and var are used + * @brief Model definition for the 1D Convolution block applied after the RNN. + * @brief sub-layers : custom nonlinearity(semi_sigmoid_tanh) -> batchnorm1d -> conv1d_depth -> conv1d_lr -> avgpool1d. + * @param[out] output_signal pointer to the final output signal, minimum size = out_time * in_channels. out_time has to be calculated based on the reduction from all the conv and pool layers. + * @param[in] input_signal pointer to the input signal. size = in_time * in_channels. + * @param[in] point_cnn function pointer for the point-wise CNN. (any of the conv layers can be passed with appropriate params). + * @param[in] in_time number of time steps in the input. + * @param[in] in_channels number of input channels. + * @param[in] mean pointer to the mean for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] var pointer to the variance for the batch normalization, size = in_channels. Pass NULL/0 for affine_config = 2. + * @param[in] affine_config whether the affine operations are applied. + * if affine_config = 0, then only mean and var are used. * if affine_config = 1, then mean, var, gamma and beta are used for the final computation. - * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var) - * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed - * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 - * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0 - * @param[in] in_place in-place computation of the batchnorm. Storage efficient - * @param[in] depth_cnn_padding padding for the depth CNN layer. Note: applied to both sides of the input to the depth CNN - * @param[in] depth_cnn_kernel_size kernel size of the depth CNN - * @param[in] depth_cnn_params weights, bias and other essential parameters used to describe the depth CNN - * @param[in] depth_cnn_stride stride factor for the depth CNN + * if affine_config = 2, then only the gamma and beta are used. gamma = original_gamma/sqrt(var), beta = original_beta - gamma * mean/sqrt(var). + * Note: Use affine_config = 2 for faster calculations. The new gamma and beta would need to be pre-computed, stored and passed. + * @param[in] gamma pointer to the scaling factors for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] beta pointer to the offsets for the post-norm affine operation, size = in_channels. Pass NULL/0 for affine_config = 0. + * @param[in] in_place in-place computation of the batchnorm. Storage efficient. + * @param[in] depth_cnn_padding padding for the depth CNN layer. Note: applied to both sides of the input to the depth CNN. + * @param[in] depth_cnn_kernel_size kernel size of the depth CNN. + * @param[in] depth_cnn_params weights, bias and other essential parameters used to describe the depth CNN. + * @param[in] depth_cnn_stride stride factor for the depth CNN. * @param[in] depth_cnn_activation an integer to choose the type of activation function. - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu - * @param[in] point_cnn_hidden hidden state/out_channels dimensions for the point CNN. The final channel size of this block - * @param[in] point_cnn_padding padding for the point CNN layer. Note: applied to both sides of the input to the point CNN - * @param[in] point_cnn_kernel_size kernel size of the point CNN - * @param[in] point_cnn_params weights, bias and other essential parameters used to describe the point CNN - * @param[in] point_cnn_stride stride factor for the point CNN + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + * @param[in] point_cnn_hidden hidden state/out_channels dimensions for the point CNN. The final channel size of this block. + * @param[in] point_cnn_padding padding for the point CNN layer. Note: applied to both sides of the input to the point CNN. + * @param[in] point_cnn_kernel_size kernel size of the point CNN. + * @param[in] point_cnn_params weights, bias and other essential parameters used to describe the point CNN. + * @param[in] point_cnn_stride stride factor for the point CNN. * @param[in] point_cnn_activation an integer to choose the type of activation function. - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu - * @param[in] pool_padding padding for the pool layer. Note: applied to both sides of the input to the pool - * @param[in] pool_kernel_size kernel size of the pool - * @param[in] pool_stride stride factor for the pool + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. + * @param[in] pool_padding padding for the pool layer. Note: applied to both sides of the input to the pool. + * @param[in] pool_kernel_size kernel size of the pool. + * @param[in] pool_stride stride factor for the pool. * @param[in] pool_activation an integer to choose the type of activation function. - * 0: none - * 1: sigmoid - * 2: tanh - * 3: relu + * 0: none. + * 1: sigmoid. + * 2: tanh. + * 3: relu. */ int phon_pred_depth_point_lr_cnn(float* output_signal, float* input_signal, conv_layer point_cnn, unsigned in_time, unsigned in_channels, diff --git a/c_reference/include/rnn_bricked.h b/c_reference/include/rnn_bricked.h index a1f4bb696..a7e2d2658 100644 --- a/c_reference/include/rnn_bricked.h +++ b/c_reference/include/rnn_bricked.h @@ -31,21 +31,21 @@ */ /** - * @brief Model parameters for the 1D Convolution Layer - * @var W1 pointer to first low-rank component of W. shape = [rank * in_dims] - * @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank] - * @var wRank rank of W matrix - * @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden] - * @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank] - * @var uRank rank of U matrix - * @var Bg pointer to bias for sigmoid - * @var Bh pointer to bias for tanh - * @var sigmoid_zeta first weight parameter for update from input from next step - * @var sigmoid_nu second weight parameter for update from input from next step - * @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x - * @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x) - * @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h - * @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h) + * @brief Model parameters for the 1D Convolution Layer. + * @var W1 pointer to first low-rank component of W. shape = [rank * in_dims]. + * @var W2 pointer to second low-rank component of W. shape = [rnn_hidden * rank]. + * @var wRank rank of W matrix. + * @var U1 pointer to first low-rank component of U. shape = [rank * rnn_hidden]. + * @var U2 pointer to second low-rank component of U. shape = [rnn_hidden * rank]. + * @var uRank rank of U matrix. + * @var Bg pointer to bias for sigmoid. + * @var Bh pointer to bias for tanh. + * @var sigmoid_zeta first weight parameter for update from input from next step. + * @var sigmoid_nu second weight parameter for update from input from next step. + * @var block_size_w_to_lr block/tile size for the cache. Used for tiled MatMul. For W1 * x. + * @var block_size_w_from_lr block/tile size for the cache. Used for tiled MatMul. For W2 * result(W1 * x). + * @var block_size_u_to_lr block/tile size for the cache. Used for tiled MatMul. For U1 * h. + * @var block_size_u_from_lr block/tile size for the cache. Used for tiled MatMul. For U2 * result(U1 * h). */ typedef struct BrickedFastGRNN_LR_Params { float* W1; @@ -64,38 +64,38 @@ typedef struct BrickedFastGRNN_LR_Params { unsigned block_size_u_from_lr; } BrickedFastGRNN_LR_Params; -/** Forward Bricking and application of the forward RNN for an input signal - * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden - * @param[in] rnn_hidden output dimension for the current cell - * @param[in] input_signal pointer to input signal. size = in_time * in_dims +/** Forward Bricking and application of the forward RNN for an input signal. + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden. + * @param[in] rnn_hidden output dimension for the current cell. + * @param[in] input_signal pointer to input signal. size = in_time * in_dims. * @param[in] in_time number of input time steps. - * @param[in] in_dims input dimensions - * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick) - * @param[in] hop hop distance for between bricks - * @param[in] params pointer to the parameters for the RNN + * @param[in] in_dims input dimensions. + * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick). + * @param[in] hop hop distance for between bricks. + * @param[in] params pointer to the parameters for the RNN. * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. - * @param[in] sample_first_brick determine if the 1st brick should also be sampled - * -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 - * -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hiddens state. out_time = in_time/hop + 1 + * @param[in] sample_first_brick determine if the 1st brick should also be sampled. + * -> if = 0, only the last hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1. + * -> if = 1, for the 1st brick, we sample every hop index(similar to ::hop). For all the bricks(including the 1st) we sample the final hiddens state. out_time = in_time/hop + 1. */ int forward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, float* input_signal, unsigned in_time, unsigned in_dims, unsigned window, unsigned hop, const void* params, unsigned bi_direction, unsigned sample_first_brick); -/** Backward Bricking and application of the backward RNN for an input signal - * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden - * @param[in] rnn_hidden output dimension for the current cell - * @param[in] input_signal pointer to input signal. size = in_time * in_dims +/** Backward Bricking and application of the backward RNN for an input signal. + * @param[out] output_signal pointer to output signal. size = out_time * rnn_hidden. + * @param[in] rnn_hidden output dimension for the current cell. + * @param[in] input_signal pointer to input signal. size = in_time * in_dims. * @param[in] in_time number of input time steps. - * @param[in] in_dims input dimensions - * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick) - * @param[in] hop hop distance for between bricks - * @param[in] params pointer to the parameters for the RNN + * @param[in] in_dims input dimensions. + * @param[in] window window length for each brick. For the final brick, the left over time steps are used(need not be window in length for the last brick). + * @param[in] hop hop distance for between bricks. + * @param[in] params pointer to the parameters for the RNN. * @param[in] bi_direction determine if the ouput if for a bi-directional RNN. * @param[in] sample_last_brick determine if the last brick should also be sampled - * -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1 - * -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hiddens state(last in reverse). out_time = in_time/hop + 1 + * -> if = 0, only the first(last in reverse) hidden state of each brick is sampled. out_time = (in_time-window)/hop + 1. + * -> if = 1, for the last brick, we sample every hop index in reverse(similar to ::hop in reverse). For all the bricks(including the last) we sample the first hiddens state(last in reverse). out_time = in_time/hop + 1. */ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, float* input_signal, unsigned in_time, unsigned in_dims, diff --git a/c_reference/include/utils.h b/c_reference/include/utils.h index 07438d134..37a821a90 100644 --- a/c_reference/include/utils.h +++ b/c_reference/include/utils.h @@ -32,27 +32,27 @@ void matVec(const float* const mat, const float* const vec, float* const ret); /* - Matrix-vector multiplication with a row offset - This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis - ret is of size nrows, vec is of size ncols - mat is of size nrows * ncols, stored in row major - depthwise is to change the matVec to depthwise specific convolutions - row_stride is the offset factor between two adjacent rows - Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped - For a normal matVec case, this value will be ncols + Matrix-vector multiplication with a row offset. + This function was developed primarily for the conv1d function. This helps bypass the permutation of the time and channel axis. + ret is of size nrows, vec is of size ncols. + mat is of size nrows * ncols, stored in row major. + depthwise is to change the matVec to depthwise specific convolutions. + row_stride is the offset factor between two adjacent rows. + Note : This matrix-vector multiplication is useful for matrices where a certain number of columns are dropped. + For a normal matVec case, this value will be ncols. + Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. - Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. Eg : For a 400 x 400 matrix and a 100 length vector, we can consider the top 400 x 100 elements for the multiplication. - For this eg ncols will be 100 and row_stride will be 400 - vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals - For a normal matVec case, this value will be 1 + For this eg ncols will be 100 and row_stride will be 400. + vec_stride is the offset fector between 2 elements in a vector i.e. the elements of a vector are placed at "n" intervals. + For a normal matVec case, this value will be 1. Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. - Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. - So it's possible to enter a 400 length vector and consider every 4th element. - So it's possible to enter a 400 length vector and consider every 4th element. - So it's possible to enter a 400 length vector and consider every 4th element. - For this ncols will be 100 and vec_stride will be 4 + Eg : For matVec with a 400 x 100 matrix a vector of length 100 is needed. + So it's possible to enter a 400 length vector and consider every 4th element. + So it's possible to enter a 400 length vector and consider every 4th element. + So it's possible to enter a 400 length vector and consider every 4th element. + For this ncols will be 100 and vec_stride will be 4. */ void offset_matVec_conv1d(const float* mat, const float* vec, unsigned nrows, unsigned ncols, @@ -60,23 +60,21 @@ void offset_matVec_conv1d(const float* mat, const float* vec, unsigned depthwise, float* ret); /* - Tiled (cache-blocked) implementation of the Matrix Multiplication - Note: If only the MatMul output is needed, then please use calloc to initialize the output - An alternative is to use malloc, followed by memset 0 - There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix - If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly - This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed - matA first matrix; shape = [nrows, ncommon] - matB second matrix; shape = [ncommon, ncols] - nrows number of rows in the first matrix - ncommon number of columns in the first matrix/number of rows in the second matrix - ncols number of columns in the second matrix - total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored - total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - total_cols_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - ret matrix multiplication output. shape = [nrows, ncols] - block_size tile/block size for optimal cache performance. A hardware specific parameter + Tiled (cache-blocked) implementation of the Matrix Multiplication. + Note: If only the MatMul output is needed, then please use calloc to initialize the output. + An alternative is to use malloc, followed by memset 0. + There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix. + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly. + This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed. + matA first matrix; shape = [nrows, ncommon]. + matB second matrix; shape = [ncommon, ncols]. + nrows number of rows in the first matrix. + ncommon number of columns in the first matrix/number of rows in the second matrix. + ncols number of columns in the second matrix. + total_comm_A the actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored. + total_cols_B the actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + ret matrix multiplication output. shape = [nrows, ncols]. + block_size tile/block size for optimal cache performance. A hardware specific parameter. */ void tiledMatMul_float(const float* const matA, const float* const matB, unsigned nrows, unsigned ncommon, unsigned ncols, @@ -84,25 +82,23 @@ void tiledMatMul_float(const float* const matA, const float* const matB, float* const ret, unsigned block_size); /* - Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format - The result will the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage) - Note: If only the MatMul output is needed, then please use calloc to initialize the output - An alternative is to use malloc, followed by memset 0 - There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix - If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly - This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed - matA first matrix; shape = [nrows, ncommon] - matB second matrix; shape = [ncols, ncommon] - nrows number of rows in the first matrix - ncommon number of columns in the first matrix/number of rows in the second matrix - ncols number of columns in the second matrix - total_comm_A The actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored - total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - total_comm_B The actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. - Since matB is transposed the columns are now the ncomm axis - ret matrix multiplication output. shape = [nrows, ncols] - block_size tile/block size for optimal cache performance. A hardware specific parameter + Tiled (cache-blocked) implementation of the Matrix Multiplication, but with matB stored in the transposed format. + The result will the same as the regular MatMul but the matrix B provided will be pre-transposed (before the storage or usage). + Note: If only the MatMul output is needed, then please use calloc to initialize the output. + An alternative is to use malloc, followed by memset 0. + There is second way to use this function. This is for adding the result of the MatMul to a pre-existing matrix. + If there is a pre-existing [nrows, ncols] matrix that needs to be added to the MatMul output, then pass that matrix directly. + This MatMul adds the result on the pre-existing values in ret. Hence either a zero initialized or a pre-existing mat is needed. + matA first matrix; shape = [nrows, ncommon]. + matB second matrix; shape = [ncols, ncommon]. + nrows number of rows in the first matrix. + ncommon number of columns in the first matrix/number of rows in the second matrix. + ncols number of columns in the second matrix. + total_comm_A the actual offset factor between 2 rows for matA. Used if we need fewer columns than the actual number stored. + total_comm_B the actual offset factor between 2 rows for matB. Used if we need fewer columns than the actual number stored. + since matB is transposed the columns are now the ncomm axis. + ret matrix multiplication output. shape = [nrows, ncols]. + block_size tile/block size for optimal cache performance. A hardware specific parameter. */ void transposed_tiledMatMul(const float* const matA, const float* const matB, unsigned nrows, unsigned ncommon, unsigned ncols, @@ -132,12 +128,12 @@ unsigned argmax(const float* const vec, unsigned len); // ret[i] = exp(input[i]) / \sum_i exp(input[i]) void softmax(const float* const input, unsigned len, float* const ret); -/* Custom non-linear layer for the phoneme detection model. It can be used for other time-series problems if necessary - output_signal pointer to the output signal, size = out_time * (in_channels / 2) - input_signal pointer to the input signal. size = in_time * in_channels - in_time number of input time steps +/* Custom non-linear layer for the phoneme detection model. It can be used for other time-series problems if necessary. + output_signal pointer to the output signal, size = out_time * (in_channels / 2). + input_signal pointer to the input signal. size = in_time * in_channels. + in_time number of input time steps. in_channels number of input channels. The output will have the half the number of input channels. - Necessary for in_channels % 2 == 0 + Necessary for in_channels % 2 == 0. */ void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, unsigned in_time, unsigned in_channels); diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c index f918fdae0..afaa8db37 100644 --- a/c_reference/src/conv1d.c +++ b/c_reference/src/conv1d.c @@ -39,7 +39,7 @@ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, input_signal, rank, (t_in_end - padding + 1) * in_channels, @@ -53,7 +53,7 @@ int conv1d_lr(float* output_signal, unsigned out_time, unsigned out_channels, // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, (in_time + padding - t_in_start) * in_channels, @@ -112,7 +112,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha const ConvLayers_LR_Parallel_Params* tparams = (ConvLayers_LR_Parallel_Params*) params; // Perform the convolution. Zero-pad is from 0 to padding and in_time + padding to in_time + 2 * padding. - // Buffer to hold the output. For corner cases, this will be realtively big. + // Buffer to hold the output. For corner cases, this will be relatively big. // But will be needed for the central condition (filter inside input). // If there are not enough time steps to linearise into one row, then allocate only 1 time step. unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? @@ -136,7 +136,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2 + (padding - t_in_start) * in_channels, input_signal, rank, (t_in_end - padding + 1) * in_channels, kernel_size * in_channels, 1, 0, temp_rank_out); @@ -208,7 +208,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W2, input_signal + (t_in_start - padding) * in_channels, rank, (in_time + padding - t_in_start) * in_channels, @@ -225,7 +225,7 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha 0, out_channels * sizeof(float)); } } - // Bias and activation + // Bias and activation. for (t_out = 0; t_out < out_time; t_out++) { unsigned t_index = t_out * out_channels; for (unsigned co = 0; co < out_channels; co++) { @@ -287,7 +287,7 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W + (padding - t_in_start) * cols_scale, input_signal, out_channels, (t_in_end - padding + 1) * cols_scale, kernel_size * cols_scale, vec_stride, tparams->depthwise, temp_out); @@ -297,7 +297,7 @@ int conv1d(float* output_signal, unsigned out_time, unsigned out_channels, // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, (in_time + padding - t_in_start) * cols_scale, @@ -352,7 +352,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe const ConvLayers_Parallel_Params* tparams = (ConvLayers_Parallel_Params*) params; // Perform the Convolution. Pad is from 0 to padding and in_time + padding to in_time + 2 * padding. - // Buffer to hold the output. For corner cases, this will be realtively big. + // Buffer to hold the output. For corner cases, this will be relatively big. // But will be needed for the central condition (filter inside input). // If there are not enough time steps to linearise into one row, then allocate only 1 time step. unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? @@ -371,7 +371,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe // Filter partially entered the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W + (padding - t_in_start) * in_channels, input_signal, out_channels, (t_in_end - padding + 1) * in_channels, ncols, 1, 0, temp_out); @@ -433,7 +433,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe // Filter partially exited the input. // As a part of the filter is outside the input, we need less than "kernel_size" time-steps. // We will only be using a part of the weight matrix(assuming shape = out_channels, kernel_size * in_channels). - // Hence we provide a separate row_stride paramemter to discard/skip certain columns in the weight matrix. + // Hence we provide a separate row_stride parameter to discard/skip certain columns in the weight matrix. offset_matVec_conv1d(tparams->W, input_signal + (t_in_start - padding) * in_channels, out_channels, (in_time + padding - t_in_start) * in_channels, @@ -447,7 +447,7 @@ int conv1d_parallel(float* output_signal, unsigned out_time, unsigned out_channe 0, out_channels * sizeof(float)); } } - // Bias and activation + // Bias and activation. for (t_out = 0; t_out < out_time; t_out++) { unsigned t_index = t_out * out_channels; for (unsigned co = 0; co < out_channels; co++) { diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c index e2da02995..2bfcc5635 100644 --- a/c_reference/src/rnn_bricked.c +++ b/c_reference/src/rnn_bricked.c @@ -164,7 +164,7 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, unsigned rnn_assign_offset = rnn_hidden; unsigned num_bricks = (in_time - window) / hop + 1; - unsigned out_index = in_time / hop; // = out_time - 1; + unsigned out_index = in_time / hop; // If bi-directional is True(non-zero) then the actual output hidden state(allocated space) is twice rnn_hidden. // This function only processes the forward context. if (bi_direction) { diff --git a/c_reference/src/utils.c b/c_reference/src/utils.c index b9b9edaa3..259b1d85a 100644 --- a/c_reference/src/utils.c +++ b/c_reference/src/utils.c @@ -86,7 +86,7 @@ void offset_matVec_conv1d(const float* mat, const float* vec, #ifdef LOOP_UNROLL unsigned len_unroll = cols >> 2; - cols %= 4; // ncols % 4 + cols %= 4; // ncols % 4. while (len_unroll--) { sum += (*mat_offset++) * (*vec_offset); vec_offset += vec_stride; @@ -245,7 +245,7 @@ void softmax(const float* const input, unsigned len, float* const ret) { void semi_sigmoid_tanh(float* output_signal, const float* const input_signal, unsigned in_time, unsigned in_channels) { - unsigned time_step = 0; // used to avoid index multiplication. + unsigned time_step = 0; // Used to avoid index multiplication. while (in_time--) { unsigned pivot = in_channels >> 1; float* input_sigmoid_offset = (float*)input_signal + time_step; diff --git a/c_reference/tests/conv1d/test_conv1d.c b/c_reference/tests/conv1d/test_conv1d.c index 969ebb3bc..d42aea786 100644 --- a/c_reference/tests/conv1d/test_conv1d.c +++ b/c_reference/tests/conv1d/test_conv1d.c @@ -10,7 +10,7 @@ #include "./conv1d_depthwise/conv_param_depth.h" #include "./conv1d_lr/conv_param_lr.h" -// Error Check +// Error Check. void errorCheck(float* pred, float* label, unsigned out_time, int out_features) { float error = 0, denom = 0; for (unsigned t = 0; t < out_time; t++) { diff --git a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c index 18104d0e8..c9cbc6658 100644 --- a/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c +++ b/c_reference/tests/kws/test_phoneme_det_cnn_rnn.c @@ -23,7 +23,8 @@ int checkTime(unsigned out_time) { } return 0; } -// Error Check + +// Error Check. void checkError(float* pred, float* label) { float error = 0, denom = 0; for (unsigned t = 0; t < KWS_OUT_TIME; t++) { @@ -51,7 +52,7 @@ void checkError(float* pred, float* label) { 2nd block is an RNN, which has a specified forward and a backward context running at a stride/hop of 3. Hence it reduces the sequence length by a factor of 3. Rest of the blocks(3rd, 4th, 5th and 6th) are a combination of CNNs. - Each of the final 4 blocks consist of a depth cnn (kernel size of 5) and a point cnn (kernel size of 1). + Each of the final 4 blocks consist of a depth-CNN (kernel size of 5) and a point-CNN (kernel size of 1). Input to the architecture is of the form (seq_len, feature_dim) where feature dim refers to n_mels (number of mel features/number of features from the featurizer). Output is of the form (seq_len/3, 41) where 41 is the number of phonemes over which the classification is performed. @@ -166,7 +167,7 @@ void phoneme_prediction(float* mem_buf) { unsigned in_time, out_time; - /* Pre-CNN */ + /* Pre-CNN. */ in_time = KWS_IN_TIME; out_time = in_time - PRE_CNN_FILT + (PRE_CNN_FILT_PAD << 1) + 1; float* cnn1_out = (float*)malloc(out_time * PRE_CNN_OUT_FEATURES * sizeof(float)); @@ -176,12 +177,12 @@ void phoneme_prediction(float* mem_buf) { conv1d_lr_parallel, in_time, PRE_CNN_IN_FEATURES, 0, 0, PRE_CNN_BNORM_AFFINE, CNN1_SCALE, CNN1_OFFSET, PRE_CNN_BNORM_INPLACE, PRE_CNN_OUT_FEATURES, PRE_CNN_FILT_PAD, PRE_CNN_FILT, - &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // regular tanh activation. + &conv_params, PRE_CNN_STRIDE, PRE_CNN_FILT_ACT); // Regular tanh activation. batchnorm1d(0, cnn1_out, in_time, RNN_IN_FEATURES, 0, 0, RNN_BNORM_AFFINE, RNN_SCALE, RNN_OFFSET, 1, 0.00001); - /* Bricked Bi-FastGRNN Block */ + /* Bricked Bi-FastGRNN Block. */ out_time = in_time/RNN_HOP + 1; float* rnn_out = (float*)malloc(out_time * RNN_OUT_FEATURES * sizeof(float)); forward_bricked_fastgrnn_lr(rnn_out, RNN_OUT_FEATURES >> 1, cnn1_out, @@ -194,7 +195,7 @@ void phoneme_prediction(float* mem_buf) { &bwd_RNN_params, RNN_BI_DIR, RNN_SAMPLE_LAST_BRICK); free(cnn1_out); - /* Post-CNN */ + /* Post-CNN. */ // Since all inputs to the subsequent layers are temporary, in-place batchnorm1d can be used without any input(initial buffer)/output(final layer) data alteration/corruption. // CNN2. in_time = out_time; @@ -256,7 +257,7 @@ void phoneme_prediction(float* mem_buf) { POST_CNN_POOL_PAD, POST_CNN_POOL, POST_CNN_POOL_STRIDE, POST_CNN_POOL_ACT); free(cnn4_out); - /* Output Time and Prediction Check. Created for Debugging */ + /* Output Time and Prediction Check. Created for Debugging. */ if (checkTime(out_time)) return; else From 6cd964d68e6629e5e88e4888f19eb9be5c61482b Mon Sep 17 00:00:00 2001 From: Anirudh0707 Date: Wed, 20 Oct 2021 17:12:47 -0700 Subject: [PATCH 4/4] Adding a period to the comments --- c_reference/src/conv1d.c | 4 ++-- c_reference/src/rnn_bricked.c | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/c_reference/src/conv1d.c b/c_reference/src/conv1d.c index afaa8db37..552abb78a 100644 --- a/c_reference/src/conv1d.c +++ b/c_reference/src/conv1d.c @@ -118,9 +118,9 @@ int conv1d_lr_parallel(float* output_signal, unsigned out_time, unsigned out_cha unsigned buffer_steps = ((in_time / num_steps_one_row) > 1) ? in_time / num_steps_one_row : 1; unsigned rank = tparams->rank; - // Buffer for W2 out + // Buffer for W2 out. float* temp_rank_out = (float*)malloc(buffer_steps * rank * sizeof(float)); - // Buffer for W1 out + // Buffer for W1 out. float* temp_out = (float*)malloc(buffer_steps * out_channels * sizeof(float)); unsigned t_in_start, t_in_end, t_out; // Values are needed outside the loops. Hence declared here. diff --git a/c_reference/src/rnn_bricked.c b/c_reference/src/rnn_bricked.c index 2bfcc5635..c09f7cf90 100644 --- a/c_reference/src/rnn_bricked.c +++ b/c_reference/src/rnn_bricked.c @@ -191,7 +191,7 @@ int backward_bricked_fastgrnn_lr(float* output_signal, unsigned rnn_hidden, transposed_tiledMatMul(hiddenState, tparams->U1, num_bricks, rnn_hidden, tparams->uRank, rnn_hidden, rnn_hidden, tempLR, tparams->block_size_u_to_lr); - // From lower dims to higher dims. + // From lower dims to higher dims. // Add Wx with Uh. // The tiled MatMuls are codes such that they yield result += matA * matB. // Hence we use calloc and memset to equate the result to 0.