Skip to content

PTPCG DuEE-Fin Task Dump

Latest
Compare
Choose a tag to compare
@Spico197 Spico197 released this 16 Jun 07:34
· 2 commits to main since this release

Evaluation code snippet:

import os

from dee.helper import sent_seg
from dee.tasks import DEETask, DEETaskSetting

def evaluate_checkpoint(
    task_dir,
    epoch,
    split,
    bert_model_dir,
    cpt_file_name="TriggerAwarePrunedCompleteGraph",
):
    """Evaluate one saved checkpoint on the dev or test split and dump results.

    The split and the epoch are each given once and every dependent value
    (which data to load, which features/dataset to evaluate, the dump file
    names) is derived from them, so switching runs cannot leave them out of
    sync.

    Args:
        task_dir: experiment directory containing
            ``{cpt_file_name}.task_setting.json``.
        epoch: checkpoint epoch passed to ``DEETask.resume_cpt_at``; also
            embedded in the dump file names.
        split: ``"dev"`` or ``"test"``.
        bert_model_dir: assigned to ``dee_setting.bert_model``. It is for
            tokenization use, so ``vocab.txt`` must be included in this dir.
        cpt_file_name: prefix of the task setting file name.

    Returns:
        Whatever ``DEETask.eval`` returns (the original script discarded it).

    Raises:
        ValueError: if ``split`` is neither ``"dev"`` nor ``"test"``.
    """
    # Fail before the expensive setting/data/checkpoint loading instead of
    # with an AttributeError on `<split>_features` afterwards.
    if split not in ("dev", "test"):
        raise ValueError(f"split must be 'dev' or 'test', got {split!r}")

    # load settings
    dee_setting = DEETaskSetting.from_pretrained(
        os.path.join(task_dir, f"{cpt_file_name}.task_setting.json")
    )
    dee_setting.local_rank = -1
    dee_setting.filtered_data_types = "o2o,o2m,m2m,unk"
    dee_setting.bert_model = bert_model_dir

    # build task, loading only the split that is going to be evaluated
    dee_task = DEETask(
        dee_setting,
        load_train=False,
        load_dev=(split == "dev"),
        load_test=(split == "test"),
        load_inference=False,
        parallel_decorate=False,
    )

    # load PTPCG parameters
    dee_task.resume_cpt_at(epoch)

    return dee_task.eval(
        getattr(dee_task, f"{split}_features"),
        getattr(dee_task, f"{split}_dataset"),
        use_gold_span=False,
        heuristic_type=None,
        dump_decode_pkl_name=f"rep.{split}.pred_span.{epoch}.pkl",
        dump_eval_json_name=f"rep.{split}.pred_span.{epoch}.json",
    )


if __name__ == "__main__":
    # bert_model_dir is for tokenization use, `vocab.txt` must be included in this dir
    # change this to `bert-base-chinese` to use the huggingface online cache
    bert_model_dir = "/data4/tzhu/pretrained_model/bert-base-chinese"

    # Earlier run, kept for reference: test split of another experiment at
    # epoch 57 (dumps rep.test.pred_span.57.pkl / .json).
    # evaluate_checkpoint(
    #     "Exps/sct-Tp1CG-with_left_trigger-OtherType-comp_ents-bs64_8",
    #     epoch=57,
    #     split="test",
    #     bert_model_dir=bert_model_dir,
    # )

    # Active run: dev split at epoch 99
    # (dumps rep.dev.pred_span.99.pkl / .json).
    evaluate_checkpoint(
        "Exps/sct-Tp1CG-luge_without_trigger",
        epoch=99,
        split="dev",
        bert_model_dir=bert_model_dir,
    )