Skip to content

Commit

Permalink
fix for interpolation parallelization
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael McCrackan committed Dec 17, 2024
1 parent e30cca2 commit 9717ec2
Showing 1 changed file with 24 additions and 32 deletions.
56 changes: 24 additions & 32 deletions src/array_ops.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -894,26 +894,22 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
int y_data_stride = y_buf->strides[0] / sizeof(double);
int y_interp_data_stride = y_interp_buf->strides[0] / sizeof(double);

#pragma omp parallel
{
#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
// Create one accel and spline per thread
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;

int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;
T* y_row = y_data + y_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;

T* y_row = y_data + y_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;
interp_func(x_data, y_row, x_interp_data, y_interp_row,
n_x, n_x_interp, spline, acc);

interp_func(x_data, y_row, x_interp_data, y_interp_row,
n_x, n_x_interp, spline, acc);
}

// Free gsl objects
gsl_spline_free(spline);
gsl_interp_accel_free(acc);
Expand All @@ -933,31 +929,27 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
std::transform(x_interp_data, x_interp_data + n_x_interp, x_interp_dbl,
[](float value) { return static_cast<double>(value); });

#pragma omp parallel
{
#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {
// Create one accel and spline per thread
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
for (int row = 0; row < n_rows; ++row) {

int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;
int y_row_start = row * y_data_stride;
int y_row_end = y_row_start + n_x;
int y_interp_row_start = row * y_interp_data_stride;

// Transform y row to double array for gsl
double y_dbl[n_x];
// Transform y row to double array for gsl
double y_dbl[n_x];

std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
[](float value) { return static_cast<double>(value); });
std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
[](float value) { return static_cast<double>(value); });

T* y_interp_row = y_interp_data + y_interp_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;

// Don't copy y_interp to doubles as it is cast during assignment
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
n_x, n_x_interp, spline, acc);
}
// Don't copy y_interp to doubles as it is cast during assignment
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
n_x, n_x_interp, spline, acc);

// Free gsl objects
gsl_spline_free(spline);
Expand All @@ -977,15 +969,15 @@ void interp1d_linear(const bp::object & x, const bp::object & y,
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<float> interp_func = &_linear_interp<float>;

_interp1d<float>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else if (dtype == NPY_DOUBLE) {
// GSL interpolation type
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<double> interp_func = &_linear_interp<double>;

_interp1d<double>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else {
Expand Down

0 comments on commit 9717ec2

Please sign in to comment.