
Pickle error when saving the lr scheduler defined after fabric.setup or by LambdaLR in Fabric #18688

Closed
hiyyg opened this issue Oct 2, 2023 · 10 comments
Labels: bug (Something isn't working) · fabric (lightning.fabric.Fabric) · repro needed (The issue is missing a reproducible example) · ver: 2.0.x

hiyyg commented Oct 2, 2023

### Bug description

_pickle.PicklingError: Can't pickle <class 'lightning.fabric.wrappers.FabricRanger'>: attribute lookup FabricRanger on lightning.fabric.wrappers failed

### What version are you seeing the problem on?

v2.0

### How to reproduce the bug

I followed this approach to save a Ranger optimizer:

# Define the state of your program/loop
state = {"model1": model1, "model2": model2, "optimizer": optimizer, "iteration": iteration, "hparams": ...}

fabric.save("path/to/checkpoint.ckpt", state)


### Error messages and logs


_pickle.PicklingError: Can't pickle <class 'lightning.fabric.wrappers.FabricRanger'>: attribute lookup FabricRanger on lightning.fabric.wrappers failed



### Environment

_No response_


### More info

_No response_

cc @carmocca @justusschock @awaelchli
hiyyg added the bug (Something isn't working) and needs triage (Waiting to be triaged by maintainers) labels on Oct 2, 2023

hiyyg commented Oct 2, 2023

The same happens for AdamW:

_pickle.PicklingError: Can't pickle <class 'lightning.fabric.wrappers.FabricAdamW'>: attribute lookup FabricAdamW on lightning.fabric.wrappers failed


awaelchli commented Oct 2, 2023

@hiyyg Is this the implementation for Ranger you are referring to?

I tried to reproduce your error but couldn't. Here is what I tried (by copying the ranger definition):

import math
import torch
from torch.optim.optimizer import Optimizer

from lightning import Fabric


class Ranger(Optimizer):

    def __init__(self, params, lr=1e-3,                       # lr
                 alpha=0.5, k=6, N_sma_threshhold=5,           # Ranger options
                 betas=(.95, 0.999), eps=1e-5, weight_decay=0,  # Adam options
                 # Gradient centralization on or off, applied to conv layers only or conv + fc layers
                 use_gc=True, gc_conv_only=False
                 ):

        # parameter checks
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        if not lr > 0:
            raise ValueError(f'Invalid Learning Rate: {lr}')
        if not eps > 0:
            raise ValueError(f'Invalid eps: {eps}')

        # parameter comments:
        # beta1 (momentum) of .95 seems to work better than .90...
        # N_sma_threshold of 5 seems better in testing than 4.
        # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.

        # prep defaults and init torch.optim base
        defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas,
                        N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # adjustable threshold
        self.N_sma_threshhold = N_sma_threshhold

        # look ahead params

        self.alpha = alpha
        self.k = k

        # radam buffer for state
        self.radam_buffer = [[None, None, None] for ind in range(10)]

        # gc on or off
        self.use_gc = use_gc

        # level of gradient centralization
        self.gc_gradient_threshold = 3 if gc_conv_only else 1

        print(
            f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}")
        if (self.use_gc and self.gc_gradient_threshold == 1):
            print(f"GC applied to both conv and fc layers")
        elif (self.use_gc and self.gc_gradient_threshold == 3):
            print(f"GC applied to conv layers only")

    def __setstate__(self, state):
        print("set state called")
        super(Ranger, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
        # Uncomment if you need to use the actual closure...

        # if closure is not None:
        #loss = closure()

        # Evaluate averages and grad, update param tensors
        for group in self.param_groups:

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()

                if grad.is_sparse:
                    raise RuntimeError(
                        'Ranger optimizer does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]  # get state dict for this param

                if len(state) == 0:  # if first time to run...init dictionary with our desired entries
                    # if self.first_run_check==0:
                    # self.first_run_check=1
                    #print("Initializing slow buffer...should not see this at load from saved model!")
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)

                    # look ahead weight storage now in state dict
                    state['slow_buffer'] = torch.empty_like(p.data)
                    state['slow_buffer'].copy_(p.data)

                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
                        p_data_fp32)

                # begin computations
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # GC operation for Conv layers and FC layers
                if grad.dim() > self.gc_gradient_threshold:
                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

                state['step'] += 1

                # compute variance mov avg
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                # compute mean moving avg
                exp_avg.mul_(beta1).add_(1 - beta1, grad)

                buffered = self.radam_buffer[int(state['step'] % 10)]

                if state['step'] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state['step']
                    beta2_t = beta2 ** state['step']
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * \
                        state['step'] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma
                    if N_sma > self.N_sma_threshhold:
                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (
                            N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    else:
                        step_size = 1.0 / (1 - beta1 ** state['step'])
                    buffered[2] = step_size

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(-group['weight_decay']
                                     * group['lr'], p_data_fp32)

                # apply lr
                if N_sma > self.N_sma_threshhold:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(-step_size *
                                         group['lr'], exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)

                p.data.copy_(p_data_fp32)

                # integrated look ahead...
                # we do it at the param level instead of group level
                if state['step'] % group['k'] == 0:
                    # get access to slow param tensor
                    slow_p = state['slow_buffer']
                    # (fast weights - slow weights) * alpha
                    slow_p.add_(self.alpha, p.data - slow_p)
                    # copy interpolated weights to RAdam param tensor
                    p.data.copy_(slow_p)

        return loss



def main():

    fabric = Fabric(accelerator="cpu")

    model = torch.nn.Linear(2, 2)
    optimizer = Ranger(model.parameters())

    model, optimizer = fabric.setup(model, optimizer)

    model(torch.randn(2, 2)).sum().backward()
    optimizer.step()

    state = {"model1": model, "optimizer": optimizer}

    fabric.save("state.pt", state)



if __name__ == "__main__":
    main()

Could you please provide a code example that we can study? Feel free to modify my example here. Thanks

awaelchli added the fabric (lightning.fabric.Fabric) and repro needed (The issue is missing a reproducible example) labels and removed the needs triage label on Oct 2, 2023
awaelchli changed the title from "fabric failed to save optimizer" to "Pickle error when saving the Ranger optimizer state in Fabric" on Oct 2, 2023

hiyyg commented Oct 3, 2023

@awaelchli This is not an error with Ranger. The problem is that I defined an lr scheduler after fabric.setup(model, optimizer). If I define the scheduler before fabric.setup, the error goes away.
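
A minimal sketch of that failing case, adapted from the reproduction above (StepLR is just a stand-in scheduler; the assumption is that fabric.setup wraps the optimizer in a dynamically generated class such as FabricAdamW, which the scheduler then holds a reference to and which pickle cannot look up when the scheduler object is saved):

import torch
from lightning import Fabric


def main():
    fabric = Fabric(accelerator="cpu")

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    model, optimizer = fabric.setup(model, optimizer)

    # Scheduler created *after* setup, so it references the wrapped optimizer
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    model(torch.randn(2, 2)).sum().backward()
    optimizer.step()
    scheduler.step()

    # Saving the scheduler object itself is what triggers the
    # _pickle.PicklingError for the dynamically created wrapper class
    state = {"model": model, "optimizer": optimizer, "scheduler": scheduler}
    fabric.save("state.pt", state)


if __name__ == "__main__":
    main()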


hiyyg commented Oct 3, 2023

Another problem: if the scheduler is defined with torch.optim.lr_scheduler.LambdaLR, for example

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / 1000) ** 0.9)

state = {"scheduler": scheduler}

then fabric.save('state.pt', state) raises the following error:

AttributeError: Can't pickle local object 'main.<locals>.<lambda>'

So I have to change the code to:

state = {"scheduler": scheduler.state_dict()}
fabric.save('state.pt', state)

fabric.load("state.pt", state)
scheduler.load_state_dict(state['scheduler'])

This way I have to call load_state_dict manually, but it seems to be the only way to bypass the error.

I wonder whether it is always safe to save/load like this for all objects that have .state_dict() and .load_state_dict(). Directly saving the objects is buggy right now, and mixing objects and state_dicts is not a good idea either.
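
Side note: the "Can't pickle local object" part is plain Python pickle behaviour, independent of Fabric: lambdas and functions defined inside main() cannot be pickled. A small sketch of avoiding just that part, assuming the same polynomial decay as above, is to use a module-level function with functools.partial, which is picklable. It does not make pickling the whole scheduler object a good idea, though:

import functools

import torch


def poly_decay(step, total_steps=1000, power=0.9):
    # Module-level function: picklable, unlike main.<locals>.<lambda>
    return (1 - step / total_steps) ** power


model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, functools.partial(poly_decay, total_steps=1000)
)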

hiyyg changed the title from "Pickle error when saving the Ranger optimizer state in Fabric" to "Pickle error when saving the scheduler defined after fabric.setup or by LambdaLR in Fabric" on Oct 3, 2023
hiyyg changed the title from "Pickle error when saving the scheduler defined after fabric.setup or by LambdaLR in Fabric" to "Pickle error when saving the lr scheduler defined after fabric.setup or by LambdaLR in Fabric" on Oct 3, 2023

hiyyg commented Oct 3, 2023

It seems another way is to use:

state = {"scheduler": scheduler}
save_state = {"scheduler": scheduler.state_dict()}
fabric.save('state.pt', save_state)

fabric.load("state.pt", state)

Could this way safely load the state_dicts to the objects?


hiyyg commented Oct 3, 2023

@awaelchli

awaelchli commented Oct 3, 2023

@hiyyg We added support for saving and loading stateful objects automatically here: #18513

But yes, the user can always decide to do .state_dict() and .load_state_dict() themselves. Fabric remains flexible here. Modules, schedulers and optimizers are not pickleable in general, so saving the whole object is never recommended. With #18513, users won't need to remember this and can rely on Fabric handling the state dicts. Does this cover your use case?
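
For 2.0.x, a minimal sketch of the mixed pattern discussed above: model and optimizer are passed to fabric.save as objects (Fabric stores their state_dicts), while the scheduler is handled manually via its state_dict. This assumes, as in the workaround above, that fabric.load replaces plain (non-module, non-optimizer) entries of the state dict with the loaded values:

import torch
from lightning import Fabric

fabric = Fabric(accelerator="cpu")

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save: objects for model/optimizer, a plain state_dict for the scheduler
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler.state_dict()}
fabric.save("checkpoint.ckpt", state)

# Resume: Fabric restores model/optimizer in place; the scheduler entry is
# replaced with the loaded state_dict, which is then applied manually
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler.state_dict()}
fabric.load("checkpoint.ckpt", state)
scheduler.load_state_dict(state["scheduler"])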


hiyyg commented Oct 3, 2023

Is this not released yet? Should I install from source to have this feature?

awaelchli commented Oct 3, 2023

Yes, you'd have to install from source to get this feature. But it's only a nice-to-have. As I wrote, and as you figured out yourself, calling .state_dict() and .load_state_dict() manually works too and is equivalent (you'd do the same in raw PyTorch).

So feel free to install from source and test it out, but it's not mandatory :)
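
For comparison, a sketch of the same pattern in raw PyTorch, using only standard torch.save / torch.load:

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Save: only state_dicts go into the checkpoint, never the objects themselves
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    },
    "checkpoint.pt",
)

# Load: recreate the objects first, then restore their state
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])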


hiyyg commented Oct 3, 2023

> It seems another way is to use:
>
> state = {"scheduler": scheduler}
> save_state = {"scheduler": scheduler.state_dict()}
> fabric.save('state.pt', save_state)
>
> fabric.load("state.pt", state)
>
> Could this way safely load the state_dicts to the objects?

Thanks.

It seems the above method removes the need for .load_state_dict(), but I still need to call .state_dict() manually just for saving.

@hiyyg hiyyg closed this as completed Oct 3, 2023