_BaseGRF.oob_predict does not make use of multi-core processing #940

Open
winston-zillow opened this issue Dec 20, 2024 · 0 comments

I found that despite setting n_jobs to -1 (or to the number of CPU cores), CausalForestDML is still very slow to train. It turns out that it gets stuck in _BaseGRF.oob_predict(), since that method uses the threading joblib backend and therefore cannot take advantage of multiple cores.
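For context, the slowdown is a consequence of Python's GIL: under joblib's threading backend, CPU-bound work that holds the GIL runs on one core at a time no matter what n_jobs is set to. A minimal standalone sketch (not econml code; a pure-Python loop stands in for the tree traversal) illustrates the difference between the two backends:

    import time
    from joblib import Parallel, delayed

    def cpu_bound(n):
        # Pure-Python arithmetic never releases the GIL, so threads
        # cannot execute it concurrently.
        total = 0
        for i in range(n):
            total += i * i
        return total

    for backend in ('threading', 'loky'):
        t0 = time.time()
        Parallel(n_jobs=-1, backend=backend)(
            delayed(cpu_bound)(2_000_000) for _ in range(16))
        # 'loky' should scale with the number of cores; 'threading' won't.
        print(f'{backend}: {time.time() - t0:.2f}s')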

I can fix it with the following:

import os
import tempfile

import numpy as np
from joblib import Parallel, delayed


def oob_predict(self, Xtrain: np.ndarray):
    ...

    # Parallel loop
    ## ORIGINAL CODE SNIPPET responsible for the sluggishness:
    # lock = threading.Lock()
    # Parallel(n_jobs=self.n_jobs, verbose=self.verbose, backend='threading', require="sharedmem")(
    #     delayed(_accumulate_oob_preds)(tree, Xtrain, sinds, alpha_hat, jac_hat, counts, lock)
    #     for tree, sinds in zip(self.estimators_, subsample_inds))

    # Dump Xtrain to a memory-mapped file so that loky worker processes can
    # read it directly instead of receiving a pickled copy per task.
    temp_folder = tempfile.mkdtemp()
    filename = os.path.join(temp_folder, 'joblib_test.mmap')
    try:
        # WARNING: this is unfortunate. Xtrain.dtype == `object`, which can't
        # be memory-mapped; for us all columns are int/float/bool, so the
        # float32 cast is safe.
        _X = Xtrain.astype(np.float32)
        _X.tofile(filename)
        X_memmap = np.memmap(filename, dtype=_X.dtype, mode='r', shape=_X.shape)

        def _accumulate_oob_preds_fast(tree, subsample_inds):
            # Predict only on the rows that are out-of-bag for this tree.
            mask = np.ones(X_memmap.shape[0], dtype=bool)
            mask[subsample_inds] = False
            alpha, jac = tree.predict_alpha_and_jac(X_memmap[mask])
            # pid is returned only to verify the work spreads across processes.
            return mask, alpha, jac, os.getpid()

        # return_as='generator' (joblib >= 1.3) accumulates results as they
        # complete instead of materializing all of them at once.
        job = Parallel(n_jobs=self.n_jobs, backend='loky', return_as='generator')
        for mask, alpha, jac, pid in job(
                delayed(_accumulate_oob_preds_fast)(tree, sinds)
                for tree, sinds in zip(self.estimators_, subsample_inds)):
            alpha_hat[mask] += alpha
            jac_hat[mask] += jac
            counts[mask] += 1
    finally:
        # Clean up the temporary memmap file and its folder.
        if os.path.exists(filename):
            os.unlink(filename)
        os.rmdir(temp_folder)

Note that memory-mapping the large Xtrain is required to take advantage of all cores; if Xtrain is instead captured by reference in the closure, only 3-4 cores stay busy concurrently. Unfortunately, this requires Xtrain.astype(np.float32), since an object-dtype array cannot be memory-mapped. So a general fix may need other changes to this method or its caller.
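As a possible alternative to the manual tofile/np.memmap dance, joblib can do the memmapping itself: joblib.dump followed by joblib.load(mmap_mode='r') reopens the array as a read-only memmap (and the loky backend also auto-memmaps large plain-array arguments above Parallel's max_nbytes threshold). A sketch of that variant, untested against econml and still subject to the object-dtype cast:

    import os
    import tempfile

    import joblib
    import numpy as np

    temp_folder = tempfile.mkdtemp()
    filename = os.path.join(temp_folder, 'Xtrain.mmap')
    try:
        # The float32 cast is still required: object arrays can't be memory-mapped.
        joblib.dump(Xtrain.astype(np.float32), filename)
        X_memmap = joblib.load(filename, mmap_mode='r')
        ...  # run the loky Parallel loop above against X_memmap
    finally:
        os.unlink(filename)
        os.rmdir(temp_folder)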
