-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
72 lines (60 loc) · 2.6 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
import argparse
import pytorch_lightning as pl
from safetensors import safe_open
from main import process_deprecated
from pytorch_lightning.loggers import WandbLogger
from custom_data import NbhoodDataModule
from kgt5_model import KGT5_Model
from omegaconf import DictConfig, OmegaConf, open_dict
def run(checkpoint_path: str, config_path: str, split: str) -> None:
    """Evaluate a KGT5 checkpoint on the test or validation split.

    Loads the OmegaConf config at ``config_path``, derives ``config.output_dir``
    from that path, builds the data module, restores the model from
    ``checkpoint_path`` and runs a single-device Lightning evaluation.

    Args:
        checkpoint_path: Path to the checkpoint. A path ending in
            ``model.safetensors`` or ``pytorch_model.bin`` is loaded through
            the corresponding branch below; anything else is passed to
            ``KGT5_Model.load_from_checkpoint``.
        config_path: Path to the config file that is loaded with OmegaConf.
        split: ``"test"`` runs ``trainer.test``; any other value runs
            ``trainer.validate``.
    """
    config = OmegaConf.load(config_path)
    print(OmegaConf.to_yaml(config))

    # The output directory is whatever precedes ".hydra" or "config.yaml" in
    # the config path (any trailing separator is kept, as before).
    if ".hydra" in config_path:
        config.output_dir = config_path.split(".hydra")[0]
    elif "config.yaml" in config_path:
        config.output_dir = config_path.split("config.yaml")[0]
    else:
        # Neither marker is present: splitting would return the full file
        # path, so fall back to the directory that contains the config file.
        config.output_dir = os.path.dirname(config_path)
    print("output written to", config.output_dir)

    config = process_deprecated(config)
    dm = NbhoodDataModule(config=config)

    if checkpoint_path.endswith("model.safetensors"):  # huggingface model with safetensors
        model = KGT5_Model(config, data_module=dm)
        with safe_open(checkpoint_path, framework="pt") as f:
            tensors = {k: f.get_tensor(k) for k in f.keys()}
        model.load_state_dict(tensors, strict=False)  # strict false due to shared weights in T5
    elif checkpoint_path.endswith("pytorch_model.bin"):  # huggingface model
        # from_pretrained is given the directory holding the weights file.
        checkpoint_path = os.path.dirname(checkpoint_path)
        # NOTE(review): if from_pretrained is a classmethod, this first
        # instantiation is redundant -- confirm against KGT5_Model.
        model = KGT5_Model(config, data_module=dm)
        model = model.from_pretrained(checkpoint_path, local_files_only=True, config=config, data_module=dm)
    else:
        model = KGT5_Model.load_from_checkpoint(
            checkpoint_path, config=config, data_module=dm
        )

    train_options = {
        'accelerator': config.train.accelerator,
        'devices': 1,
        'max_epochs': 1,
        'default_root_dir': config.output_dir,
        'strategy': config.train.strategy,
        'precision': config.train.precision,
        'check_val_every_n_epoch': config.valid.every,
    }
    trainer = pl.Trainer(**train_options)

    if split == "test":
        trainer.test(model, datamodule=dm)
    else:
        trainer.validate(model, datamodule=dm)
if __name__ == '__main__':
    # Command-line entry point: parse the config path, checkpoint path and
    # split name, then hand them to run().
    cli = argparse.ArgumentParser(description='Test kgt5 model')
    cli.add_argument('-c', '--config', required=True, help='Path to config')
    cli.add_argument('-m', '--model', required=True, help='Path to checkpoint')
    cli.add_argument('-s', '--split', default="test", help='Split to evaluate on')
    options = cli.parse_args()

    # Set torch's float32 matmul precision to 'medium' before evaluating.
    torch.set_float32_matmul_precision('medium')
    run(options.model, options.config, split=options.split)