# Evaluation code snippet: restore a saved DEE checkpoint and evaluate it on a data split.
import os
from dee.helper import sent_seg
from dee.tasks import DEETask, DEETaskSetting
# Data splits this script can evaluate. Each one selects the matching
# ``load_<split>`` flag of DEETask and the ``<split>_features`` /
# ``<split>_dataset`` attributes read back from the task.
_EVAL_SPLITS = ("dev", "test")


def _check_split(split):
    """Raise ValueError unless *split* is one of ``_EVAL_SPLITS``."""
    if split not in _EVAL_SPLITS:
        raise ValueError(f"split must be one of {_EVAL_SPLITS}, got {split!r}")


def pred_span_dump_names(split, epoch):
    """Return the ``(pkl_name, json_name)`` dump file names for one evaluation.

    Names follow ``rep.<split>.pred_span.<epoch>.{pkl,json}``, e.g.
    ``("rep.dev.pred_span.99.pkl", "rep.dev.pred_span.99.json")``.

    Raises:
        ValueError: If *split* is not ``"dev"`` or ``"test"``.
    """
    _check_split(split)
    stem = f"rep.{split}.pred_span.{epoch}"
    return f"{stem}.pkl", f"{stem}.json"


def build_eval_task(
    task_dir,
    cpt_file_name,
    bert_model_dir,
    *,
    split="dev",
    filtered_data_types="o2o,o2m,m2m,unk",
):
    """Build a DEETask from a saved task setting, loading only *split* data.

    Args:
        task_dir: Experiment directory containing
            ``<cpt_file_name>.task_setting.json``.
        cpt_file_name: Prefix of the saved task-setting file.
        bert_model_dir: Assigned to ``bert_model`` on the setting. It is used
            for tokenization, so ``vocab.txt`` must be in this dir. Pass
            ``"bert-base-chinese"`` to use the Hugging Face online cache.
        split: ``"dev"`` or ``"test"``; only that split is loaded.
        filtered_data_types: Assigned to ``filtered_data_types`` on the setting.

    Returns:
        The constructed DEETask. Train and inference data are not loaded, and
        ``parallel_decorate=False``.

    Raises:
        ValueError: If *split* is not ``"dev"`` or ``"test"``.
    """
    # Validate before touching the filesystem / model loading.
    _check_split(split)
    dee_setting = DEETaskSetting.from_pretrained(
        os.path.join(task_dir, f"{cpt_file_name}.task_setting.json")
    )
    # NOTE(review): local_rank = -1 presumably disables distributed mode;
    # confirm against DEETaskSetting.
    dee_setting.local_rank = -1
    dee_setting.filtered_data_types = filtered_data_types
    dee_setting.bert_model = bert_model_dir
    return DEETask(
        dee_setting,
        load_train=False,
        load_dev=(split == "dev"),
        load_test=(split == "test"),
        load_inference=False,
        parallel_decorate=False,
    )


def evaluate_checkpoint(dee_task, split, epoch):
    """Restore checkpoint *epoch* into *dee_task* and evaluate it on *split*.

    Calls ``dee_task.eval`` with ``use_gold_span=False`` and
    ``heuristic_type=None``. Output is dumped under the names returned by
    ``pred_span_dump_names(split, epoch)``.

    Args:
        dee_task: A task built with the same *split* loaded (see
            ``build_eval_task``).
        split: ``"dev"`` or ``"test"``.
        epoch: Checkpoint index passed to ``resume_cpt_at``. Presumably this
            is the training epoch; confirm against DEETask.

    Returns:
        Whatever ``dee_task.eval`` returns.

    Raises:
        ValueError: If *split* is not ``"dev"`` or ``"test"``.
    """
    # Also validates *split* before the checkpoint is loaded.
    pkl_name, json_name = pred_span_dump_names(split, epoch)
    # load PTPCG parameters
    dee_task.resume_cpt_at(epoch)
    if split == "dev":
        features, dataset = dee_task.dev_features, dee_task.dev_dataset
    else:
        features, dataset = dee_task.test_features, dee_task.test_dataset
    return dee_task.eval(
        features,
        dataset,
        use_gold_span=False,
        heuristic_type=None,
        dump_decode_pkl_name=pkl_name,
        dump_eval_json_name=json_name,
    )


if __name__ == "__main__":
    # Alternative configuration used previously (kept for reference):
    #   task_dir="Exps/sct-Tp1CG-with_left_trigger-OtherType-comp_ents-bs64_8",
    #   split="test", epoch=57
    task_dir = "Exps/sct-Tp1CG-luge_without_trigger"
    cpt_file_name = "TriggerAwarePrunedCompleteGraph"
    # bert_model_dir is for tokenization use, `vocab.txt` must be included in this dir
    # change this to `bert-base-chinese` to use the huggingface online cache
    bert_model_dir = "/data4/tzhu/pretrained_model/bert-base-chinese"
    split = "dev"
    epoch = 99

    dee_task = build_eval_task(task_dir, cpt_file_name, bert_model_dir, split=split)
    evaluate_checkpoint(dee_task, split, epoch)