-
Notifications
You must be signed in to change notification settings - Fork 4
/
data_utils_2.py
62 lines (48 loc) · 1.6 KB
/
data_utils_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Default character whitelist for encode(): digits, lowercase ASCII letters
# and the space character. encode() lowercases its input and drops every
# character not in this string before splitting into words.
EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz '
import numpy as np
import pickle
def get_metadata(path='./metadata.pkl'):
    """Load the vocabulary lookups stored in a pickled metadata file.

    Args:
        path: location of the pickled metadata. Defaults to
            './metadata.pkl', resolved against the current working
            directory, which is the location this function always used.

    Returns:
        A tuple ``(idx2w, w2idx, limit)`` read from the 'idx2w', 'w2idx'
        and 'limit' entries. A missing entry yields ``None`` because the
        values are fetched with ``.get``. Assumes the pickle holds a dict
        or other mapping; confirm against whatever script writes it.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only load metadata files produced by your own preprocessing.
    with open(path, 'rb') as f:
        metadata = pickle.load(f)
    return metadata.get('idx2w'), metadata.get('w2idx'), metadata.get('limit')
def decode(sequence, lookup, separator=' '):
    """Turn a sequence of indices back into a string of words.

    Index 0 is reserved for padding, so it is skipped. Every other falsy
    element is skipped too. Each remaining index is mapped through
    ``lookup``, presumably the idx2w table from get_metadata. The resulting
    words are joined with ``separator``.

    Args:
        sequence: iterable of indices.
        lookup: anything indexable by those indices (list or dict).
        separator: string placed between decoded words.

    Returns:
        The joined string; '' when every element is padding or the
        sequence is empty. ``lookup`` is not touched in that case.
    """
    words = []
    for idx in sequence:
        if not idx:
            # padding (index 0): contributes nothing to the output
            continue
        words.append(lookup[idx])
    return separator.join(words)
def encode(sentence, lookup, maxlen, whitelist=EN_WHITELIST, separator=''):
    """Convert a raw sentence into a zero-padded column of word indices.

    Pipeline:
      1. lowercase the sentence;
      2. drop characters not in ``whitelist``;
      3. split into words;
      4. keep only the last ``maxlen`` words;
      5. map words to indices with ``pad_seq`` (out-of-vocabulary words
         become ``lookup['unk']``) and zero-pad to ``maxlen``.

    Args:
        sentence: input text.
        lookup: word -> index mapping, presumably w2idx from get_metadata.
            Must contain 'unk' if any word is out of vocabulary.
        maxlen: number of positions in the output (expected >= 0).
        whitelist: characters kept after lowercasing.
        separator: unused; kept so existing callers keep working.

    Returns:
        numpy array of shape ``(maxlen, 1)``.
    """
    allowed = set(whitelist)  # O(1) membership test per character
    cleaned = ''.join(ch for ch in sentence.lower() if ch in allowed)
    # split() with no argument collapses runs of spaces. Such runs appear
    # whenever a dropped character sat between two spaces ("a - b" -> "a  b").
    # split(' ') would turn each extra space into an empty token that pad_seq
    # maps to 'unk', and an empty sentence would encode as a lone 'unk'.
    words = cleaned.split()
    # Keep the most recent maxlen words. An explicit start index avoids the
    # words[-0:] pitfall, which returns the whole list when maxlen == 0.
    words = words[max(len(words) - maxlen, 0):]
    idx_x = np.array(pad_seq(words, lookup, maxlen))
    return idx_x.reshape([maxlen, 1])
def pad_seq(seq, lookup, maxlen):
    """Replace each word with its index, then right-pad with zeros to ``maxlen``.

    A word absent from ``lookup`` is replaced by ``lookup['unk']``. That key
    is only read when such a word actually occurs. Sequences longer than
    ``maxlen`` are NOT truncated; they simply receive no padding.

    Args:
        seq: sized sequence of words (``len`` is taken on it).
        lookup: word -> index mapping.
        maxlen: target length.

    Returns:
        List of looked-up values followed by zeros, of length
        ``max(len(seq), maxlen)``.
    """
    ids = [lookup[w] if w in lookup else lookup['unk'] for w in seq]
    n_pad = maxlen - len(seq)  # negative -> [0] * n_pad is just []
    return ids + [0] * n_pad