-
Notifications
You must be signed in to change notification settings - Fork 2
Data_Processing
- MRC + retrieval 모두에서 크게 필요없다고 생각되는 개행문자("\n", "\\n"), 정답에 포함되어 있지 않는 특수문자, "#"등을 전처리 하고자 하였다.
- wiki data를 그대로 가져온것이라고 생각 되었고 preprocessing이 필요하다고 판단
def preprocess(text):
text = re.sub(r'\n', ' ', text)
text = re.sub(r"\\n", " ", text)
text = re.sub(r'#', ' ', text)
text = re.sub(r"\s+", " ", text)
text = re.sub(r"[^a-zA-Z0-9가-힣ㄱ-ㅎㅏ-ㅣぁ-ゔァ-ヴー々〆〤一-龥<>()\s\.\?!》《≪≫\'<>〈〉:‘’%,『』「」<>・\"-“”∧]", "", text)
return text
def run_preprocess(data_dict):
context = data_dict["context"]
start_ids = data_dict["answers"]["answer_start"][0]
before = data_dict["context"][:start_ids]
after = data_dict["context"][start_ids:]
process_before = preprocess(before)
process_after = preprocess(after)
process_data = process_before + process_after
ids_move = len(before) - len(process_before)
data_dict["context"] = process_data
data_dict["answers"]["answer_start"][0] = start_ids - ids_move
return data_dict
def check(data_list):
for data in data_list:
start_ids = data["answers"]["answer_start"][0]
end_ids = start_ids + len(data["answers"]["text"][0])
if data["answers"]["text"][0] != data["context"][start_ids : end_ids]:
print("wrong")
return
print("good")
train_data = load_from_disk("train data path")["train"]
new_train_data = []
for data in train_data:
new_data = run_preprocess(data)
new_train_data.append(new_data)
check(new_train_data)
preprocess
- 개행문자 "\n", "\\n" 를 " "으로 대체
- "#"을 " "으로 대체
- 개행문자, "#"를 제외하면서 " "이 두번이상 반복될 시 " " 한번으로 대체
- 영어, 한글, 한문, 일본어 + 정답에 존재하는 특수문자를 제외한 모든 문자들을 " "으로 대체
run preprocess
- 전처리 + 전처리후 answer의 start index 찾기
- answer의 start index 기준으로 before, after로 분리하여 preprocess
- before의 길이와 process_before의 길이를 비교하여 process data의 start index 정의
Before
after
-
실제 모델 파이프라인에서는 1개의 question에 대하여 retrieval에서 top k개의 context를 선택후 모두 concat 하여 MRC의 input으로 사용, 이 때 train에서 1개의 question에 대해서 1개의 context만을 사용하여 정답을 낼 경우 validation을 신용하기 힘들다.
-
1개의 question에 대하여 유사도가 높은 context를 concat하여 사용 할 경우 정답이 cls 토큰인 데이터가 augmentation되는 효과
-
Elastic Search를 이용한 retrieval이 필요
How to use Elastic Search : Elastic Search for Beginners
from elasticsearch import Elasticsearch
config = {'host':'localhost', 'port':9200}
es = Elasticsearch([config])
# test connection
es.ping()
def search_es(es_obj, index_name, question_text, n_results):
query = {
'query': {
'match': {
'document_text': question_text
}
}
}
res = es_obj.search(index=index_name, body=query, size=n_results)
return res
# train_qa : train data list
for step, question in enumerate(train_qa):
k = 5
res = search_es(es, "wiki_split", question["question"], k)
context_list = [(hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']]
add_text = train_qa[step]["context"]
count = 0
for context in context_list:
#같은것이 있을 경우 continue 하여 concat X
if question["context"] == context[0]:
continue
add_text += " " + context[0]
count += 1
if count == 4:
break
train_qa[step]["context"] = add_text
question과 유사한 context K개를 concat한 데이터 생성
-
단순하게 retrieval에서 많은 context를 뽑아 concat하여 MRC모델에게 넘겨주어 MRC에게 많은 부담을 주는 문제 해결
-
MRC의 부담을 줄여주기 위하여 retrieval에 들어가는 context을 문장단위로 split하여 사용
-
trade off였던 retrieval의 성능저하가 크지 않다. (Top 5, 10, 15 기준)
def passage_split(text):
split_length = 400
num = len(text) // split_length
count = 1
split_datas = kss.split_sentences(text)
data_list = []
data = ""
for split_data in split_datas:
if abs(len(data) - split_length) > abs(len(data) + len(split_data) - split_length) and count < num:
if len(data) == 0:
data += split_data
else:
data += (" " + split_data)
elif count < num:
data_list.append(data)
count += 1
data = ""
data += split_data
else:
data += split_data
data_list.append(data)
return data_list, len(data_list)
with open("/opt/ml/input/data/preprocess_wiki.json", "r") as f:
wiki = json.load(f)
new_wiki = dict()
for i in range(len(wiki)):
if len(wiki[str(i)]["text"]) < 800:
new_wiki[str(i)] = wiki[str(i)]
continue
data_list, count = passage_split(wiki[str(i)]["text"])
for j in range(count):
new_wiki[str(i) + f"_{j}"] = {"text" : data_list[j],
"corpus_source" : wiki[str(i)]["corpus_source"],
"url" : wiki[str(i)]["url"],
"domain" : wiki[str(i)]["domain"],
"title" : wiki[str(i)]["title"],
"author" : wiki[str(i)]["author"],
"html" : wiki[str(i)]["html"],
"document_id" : wiki[str(i)]["document_id"]}
passage_split
- context의 길이가 일정 이상(800)일 경우 split
- kss 라이브러리를 사용하여 문장단위로 split
- 받은 context의 길이에 따라 split되는 문장의 수를 달리한다. (최대한 같은 길이로 split)
-
Query Attention model에 사용할 question type이 정의 되어있는 데이터셋 구축
-
AI hub data를 이용하여 Question type을 예측하는 model 학습
-
Question type class : 'work_how':0, 'work_what':1, 'work_when':2, 'work_where':3, 'work_who':4, 'work_why':5
# model : question type classification model (train by AI hub data)
def question_labeling(model, train_iter, val_iter):
train_file = get_pickle("../../data/concat_train.pkl")["train"]
validation_file = get_pickle("../../data/concat_train.pkl")["validation"]
train_qa = [{"id" : train_file[i]["id"],
"question" : train_file[i]["question"],
"answers" : train_file[i]["answers"],
"context" : train_file[i]["context"]} for i in range(len(train_file))]
validation_qa = [{"id" : validation_file[i]["id"],
"question" : validation_file[i]["question"],
"answers" : validation_file[i]["answers"],
"context" : validation_file[i]["context"]} for i in range(len(validation_file))]
device = "cuda:0"
for step, (input_ids, attention_mask, labels) in tqdm(enumerate(train_iter), total=len(train_iter), position=0, leave=True):
score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0]
pred = torch.argmax(score, 1).detach().cpu().numpy()
train_qa[step]["question_type"] = pred
for step, (input_ids, attention_mask, labels) in tqdm(enumerate(val_iter), total=len(val_iter), position=0, leave=True):
score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0]
pred = torch.argmax(score, 1).detach().cpu().numpy()
validation_qa[step]["question_type"] = pred
train_df = pd.DataFrame(train_qa)
val_df = pd.DataFrame(validation_qa)
return train_df, val_df
def save_data(train_df, val_df):
train_f = Features({'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None),
'context': Value(dtype='string', id=None),
'id': Value(dtype='string', id=None),
'question': Value(dtype='string', id=None),
'question_type' : Value(dtype='int32', id=None)})
train_datasets = DatasetDict({'train': Dataset.from_pandas(train_df, features=train_f), 'validation': Dataset.from_pandas(val_df, features=train_f)})
file = open("../../data/question_type.pkl", "wb")
pickle.dump(train_datasets, file)
file.close()
question_labeling
AI hub로 사전에 학습된 모델을 이용하여 대회 train data에 대하여 question type을 추가