-
Notifications
You must be signed in to change notification settings - Fork 4
/
xer.py
executable file
·76 lines (61 loc) · 2.51 KB
/
xer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
@author
______ _ _
| ____| (_) /\ | |
| |__ __ _ _ __ _ ___ / \ | | __ _ ___ _ __ ___ __ _ _ __ _ _
| __/ _` | '__| / __| / /\ \ | |/ _` / __| '_ ` _ \ / _` | '__| | | |
| | | (_| | | | \__ \ / ____ \| | (_| \__ \ | | | | | (_| | | | |_| |
|_| \__,_|_| |_|___/ /_/ \_\_|\__,_|___/_| |_| |_|\__,_|_| \__, |
__/ |
|___/
Email: [email protected]
Date: Mar 15, 2022
"""
# pip install git+https://github.com/pzelasko/kaldialign.git
from kaldialign import edit_distance
def cer(ref, hyp):
    """
    Computes the Character Error Rate (CER) of a hypothesis against a
    reference.

    Every space character is removed from both strings (and any whitespace
    left at either end is stripped) before the character-level edit distance
    is computed with ``kaldialign.edit_distance``.

    Arguments:
        ref (string): the ground truth string
        hyp (string): the hypothesis string

    Returns:
        dict with the keys:
            'insertions', 'deletions', 'substitutions': the 'ins', 'del'
                and 'sub' counts reported by ``edit_distance``
            'distance': the 'total' count reported by ``edit_distance``
            'ref_length' (float): number of characters left in the reference
            'Error Rate' (float): 100 * distance / ref_length. For an empty
                reference this is 0.0 when there are no errors and
                float('inf') otherwise.
    """
    ref_chars = ref.replace(' ', '').strip()
    hyp_chars = hyp.replace(' ', '').strip()

    info = edit_distance(ref_chars, hyp_chars)
    distance = info['total']
    ref_length = float(len(ref_chars))

    # An empty reference would make the ratio divide by zero; report a
    # well-defined rate instead of raising ZeroDivisionError.
    if ref_length:
        error_rate = (distance / ref_length) * 100
    elif distance:
        error_rate = float('inf')
    else:
        error_rate = 0.0

    return {
        'insertions': info['ins'],
        'deletions': info['del'],
        'substitutions': info['sub'],
        'distance': distance,
        'ref_length': ref_length,
        'Error Rate': error_rate,
    }
def wer(ref, hyp):
    """
    Computes the Word Error Rate (WER) of a hypothesis against a reference.

    Both sentences are tokenized on whitespace and the word-level edit
    distance is computed with ``kaldialign.edit_distance``.

    Arguments:
        ref (string): a space-separated ground truth string
        hyp (string): a space-separated hypothesis

    Returns:
        dict with the keys:
            'insertions', 'deletions', 'substitutions': the 'ins', 'del'
                and 'sub' counts reported by ``edit_distance``
            'distance': the 'total' count reported by ``edit_distance``
            'ref_length' (float): number of words in the reference
            'Error Rate' (float): 100 * distance / ref_length. For an empty
                reference this is 0.0 when there are no errors and
                float('inf') otherwise.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()

    # edit_distance aligns sequences of hashable symbols, so the word lists
    # are passed as-is; no word -> single-character encoding is required.
    info = edit_distance(ref_words, hyp_words)
    distance = info['total']
    ref_length = float(len(ref_words))

    # An empty reference would make the ratio divide by zero; report a
    # well-defined rate instead of raising ZeroDivisionError.
    if ref_length:
        error_rate = (distance / ref_length) * 100
    elif distance:
        error_rate = float('inf')
    else:
        error_rate = 0.0

    return {
        'insertions': info['ins'],
        'deletions': info['del'],
        'substitutions': info['sub'],
        'distance': distance,
        'ref_length': ref_length,
        'Error Rate': error_rate,
    }