Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slight fix for GSL interpolation parallelization #195

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 24 additions & 32 deletions src/array_ops.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -894,26 +894,22 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
int y_data_stride = y_buf->strides[0] / sizeof(double);
int y_interp_data_stride = y_interp_buf->strides[0] / sizeof(double);

#pragma omp parallel
{
#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
// Create one accel and spline per thread
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;

int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;
T* y_row = y_data + y_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;

T* y_row = y_data + y_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;
interp_func(x_data, y_row, x_interp_data, y_interp_row,
n_x, n_x_interp, spline, acc);

interp_func(x_data, y_row, x_interp_data, y_interp_row,
n_x, n_x_interp, spline, acc);
}

// Free gsl objects
gsl_spline_free(spline);
gsl_interp_accel_free(acc);
Expand All @@ -933,31 +929,27 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
std::transform(x_interp_data, x_interp_data + n_x_interp, x_interp_dbl,
[](float value) { return static_cast<double>(value); });

#pragma omp parallel
{
#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
// Create one accel and spline per thread
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {

int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;
int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;

// Transform y row to double array for gsl
double y_dbl[n_x];
// Transform y row to double array for gsl
double y_dbl[n_x];

std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
[](float value) { return static_cast<double>(value); });
std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
[](float value) { return static_cast<double>(value); });

T* y_interp_row = y_interp_data + y_interp_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;

// Don't copy y_interp to doubles as it is cast during assignment
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
n_x, n_x_interp, spline, acc);
}
// Don't copy y_interp to doubles as it is cast during assignment
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
n_x, n_x_interp, spline, acc);

// Free gsl objects
gsl_spline_free(spline);
Expand All @@ -977,15 +969,15 @@ void interp1d_linear(const bp::object & x, const bp::object & y,
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<float> interp_func = &_linear_interp<float>;

_interp1d<float>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else if (dtype == NPY_DOUBLE) {
// GSL interpolation type
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<double> interp_func = &_linear_interp<double>;

_interp1d<double>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else {
Expand Down
Loading