-
Notifications
You must be signed in to change notification settings - Fork 1
/
gradient_checkpointing.py
61 lines (49 loc) · 1.85 KB
/
gradient_checkpointing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.utils.checkpoint import checkpoint
from torch.autograd import Variable
import torch
from torch import nn
import numpy as np
class NewForward:
def __init__(self, layer):
self.layer = layer
self.apply_patch()
def apply_patch(self):
self.layer.old_forward_for_cp = self.layer.forward
self.layer.forward = self.new_forward
def new_forward(self, *args, **kwargs):
def func(*args):
return self.layer.old_forward_for_cp(*args, **kwargs)
output = checkpoint(func, *args)
return output
class VarWrapper:
def __init__(self, model):
self.model = model
self.apply_patch()
print('Var Wrapper Patch Applied')
def apply_patch(self):
self.model.old_forward_for_cp = self.model.forward
self.model.forward = self.new_forward
def new_forward(self, *args, **kwargs):
out = self.model.old_forward_for_cp(*args, **kwargs)
out = Variable(out.data, requires_grad=True)
return out
def apply_gradient_checkpointing(model, checkpoint_ratio=1):
new_forwards = []
modules = []
for n, m in model.named_modules():
if isinstance(m, LlamaDecoderLayer):
modules.append(m)
if checkpoint_ratio < 1 and checkpoint_ratio > 0:
checkpoint_locs = np.array((np.linspace(0, 1, int(len(modules) * checkpoint_ratio)) * (len(modules)-1)).round(), dtype=int)
else:
checkpoint_locs = np.arange(len(modules))
for i in checkpoint_locs:
m = modules[i]
new_forwards.append(NewForward(m))
print('Forward Patch Applied For Block {}'.format(i))
for n, m in model.named_modules():
if isinstance(m, torch.nn.Embedding):
wrapper = VarWrapper(m)
break
return new_forwards, wrapper