|
|
#!/usr/bin/env python3 |
|
|
|
|
|
import argparse |
|
|
import re |
|
|
from collections import OrderedDict |
|
|
|
|
|
|
|
|
# Command-line interface.  --ref and --hyp are the two transcript files that
# main() reads; the remaining options are integer switches (default 0) that
# are only ever tested for truthiness by the code that consumes them.
parser = argparse.ArgumentParser()
parser.add_argument("--ref", type=str, required=True)
parser.add_argument("--hyp", type=str, required=True)
parser.add_argument("--print_sentence_wer", type=int, default=0)
parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an")
# Raw string: "\|" is not a valid string escape, so a plain literal triggers
# an "invalid escape sequence" warning.  The help text itself is unchanged.
parser.add_argument("--rm_special", type=int, default=0, help=r"remove <\|.*?\|>")
|
|
|
|
|
|
|
|
def main(args):
    """Score the hypothesis file against the reference file and print totals.

    Args:
        args: parsed command-line namespace; the attributes read here are
            ``ref``, ``hyp``, ``do_tn``, ``rm_special`` and
            ``print_sentence_wer``.
    """
    references = read_uttid2tokens(args.ref, args.do_tn, args.rm_special)
    hypotheses = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special)

    # The per-utterance mapping is returned as well but is not used here.
    _per_utt, overall, en_dig = compute_uttid2wer_info(
        references, hypotheses, args.print_sentence_wer
    )

    overall.print()
    en_dig.print()
|
|
|
|
|
|
|
|
def read_uttid2tokens(filename, do_tn=False, rm_special=False):
    """Load a transcript file and tokenise every utterance.

    Args:
        filename: path handed on to ``read_uttid2text``.
        do_tn: forwarded unchanged to ``read_uttid2text``.
        rm_special: forwarded unchanged to ``read_uttid2text``.

    Returns:
        OrderedDict mapping each utterance id to the token list produced by
        ``text2tokens``, in the order ``read_uttid2text`` returned the ids.
    """
    print(f">>> Read uttid to tokens: {filename}", flush=True)
    texts = read_uttid2text(filename, do_tn, rm_special)
    return OrderedDict(
        (utt_key, text2tokens(sentence)) for utt_key, sentence in texts.items()
    )
|
|
|
|
|
|
|
|
# Matches one special token of the form <|...|>.  The quantifier is
# non-greedy so that adjacent tokens are matched one at a time.
_SPECIAL_TOKEN_RE = re.compile(r"<\|.*?\|>")


def read_uttid2text(filename, do_tn=False, rm_special=False):
    """Read a transcript file into an ordered ``uttid -> text`` mapping.

    Each line is split on whitespace: the first field is the utterance id and
    the remaining fields, re-joined with single spaces, are its text.  A line
    holding only an id maps to the empty string; blank lines are reported and
    skipped.

    Args:
        filename: path of a UTF-8 text file.
        do_tn: when truthy, the text is passed through
            ``cn2an.transform(text, "an2cn")`` (presumably digit-to-Chinese
            numeral normalisation -- see the cn2an documentation).
        rm_special: when truthy, ``<|...|>`` tokens are removed from the text
            and the surviving pieces are joined with single spaces.

    Returns:
        OrderedDict of utterance id to text, in file order.

    Raises:
        AssertionError: if an utterance id occurs more than once.
    """
    uttid2text = OrderedDict()
    with open(filename, "r", encoding="utf8") as fin:
        for lineno, line in enumerate(fin):
            cols = line.split()
            if not cols:
                print("[WARN] empty line, continue", lineno, flush=True)
                continue

            uttid = cols[0]
            if uttid in uttid2text:
                # Raised explicitly (rather than via ``assert``) so that the
                # check also runs under ``python -O``; the exception type is
                # the one callers already see.
                raise AssertionError(f"repeated uttid: {line}")

            if len(cols) == 1:
                uttid2text[uttid] = ""
                continue

            txt = " ".join(cols[1:])
            if rm_special:
                pieces = _SPECIAL_TOKEN_RE.split(txt)
                txt = " ".join(piece for piece in pieces if piece.strip())
            if do_tn:
                # Third-party module, imported lazily so that it is required
                # only when text normalisation is actually requested.
                import cn2an

                txt = cn2an.transform(txt, "an2cn")
            uttid2text[uttid] = txt
    return uttid2text
|
|
|
|
|
|
|
|
# ASCII symbols that are replaced by a space before tokenisation.  The
# apostrophe and hyphen are deliberately absent so that words such as "DON'T"
# or "CO-OP" stay in one piece (count_english_ditgit accepts both characters
# inside an English token).
# NOTE(review): other ASCII symbols (# $ % & * < > @ ^ _ ~) are left in the
# text as well -- confirm that this matches the intended scoring rules.
_ASCII_PUNCT = ",.?!\":" + "()[]{}/;`|=+"

# Typographic punctuation outside the CJK and full-width blocks, written as
# escapes so the set does not depend on how this file's text is normalised.
_MISC_PUNCT = (
    "\u00b7"                    # middle dot
    "\u2013\u2014"              # en dash, em dash
    "\u2018\u2019\u201b"        # single quotation marks
    "\u201c\u201d\u201e\u201f"  # double quotation marks
    "\u2026\u2027"              # ellipsis, hyphenation point
    "\u2985\u2986"              # white parentheses
    "\ufe4f\ufe51\ufe54"        # small-form variants
)

# One or more punctuation characters.  The explicit ranges cover CJK symbols
# (ideographic space/comma/full stop, brackets, wave dash, ...) and the
# full-width / half-width forms of ASCII punctuation.
_PUNCT_RE = re.compile(
    "["
    + re.escape(_ASCII_PUNCT + _MISC_PUNCT)
    + "\u3000-\u3003\u3008-\u3011\u3014-\u301f\u3030\u303e\u303f"
    + "\uff01-\uff0f\uff1a-\uff20\uff3b-\uff40\uff5b-\uff65"
    + "]+"
)

# A single CJK unified ideograph; the capturing group makes ``split`` keep
# every ideograph as its own list element.
_CJK_CHAR_RE = re.compile(r"([\u4e00-\u9fff])")


def text2tokens(text):
    """Split a transcript into scoring tokens.

    ``<unk>`` markers are deleted, every run of punctuation becomes a space,
    and the text is upper-cased.  Each CJK ideograph (U+4E00-U+9FFF) is then
    one token, and everything between ideographs is split on whitespace.

    The previous implementation built its punctuation set from a string
    literal that was closed by a stray double quote, which left only
    ``, . ? !`` and the ideographic full stop in the set; the set is now
    spelled out explicitly.

    Args:
        text: transcript of one utterance.

    Returns:
        List of token strings; empty for empty input.
    """
    if text == "":
        return []

    text = text.replace("<unk>", "")
    text = _PUNCT_RE.sub(" ", text)

    tokens = []
    for part in _CJK_CHAR_RE.split(text.strip().upper()):
        if not part.strip():
            continue
        if _CJK_CHAR_RE.fullmatch(part) is not None:
            tokens.append(part)
        else:
            tokens.extend(part.split())
    return tokens
|
|
|
|
|
|
|
|
def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False):
    """Align every reference with its hypothesis and accumulate statistics.

    Args:
        refs: mapping of utterance id to reference token list.
        hyps: mapping of utterance id to hypothesis token list.
        print_sentence_wer: when truthy, print one line per scored utterance.

    Returns:
        Tuple ``(uttid2wer_info, wer_stat, en_dig_stat)``: an OrderedDict of
        per-utterance ``WerInfo`` objects, the ``WerStats`` accumulator and
        the ``EnDigStats`` accumulator.  References without a hypothesis are
        reported and left out of all three.
    """
    print(">>> Compute uttid to wer info", flush=True)

    per_utt = OrderedDict()
    totals = WerStats()
    en_dig_totals = EnDigStats()

    for uttid, ref in refs.items():
        if uttid in hyps:
            hyp = hyps[uttid]
        else:
            print(f"[WARN] No hyp for {uttid}", flush=True)
            continue

        # A hypothesis that is at least 8 tokens longer than its reference is
        # only reported; it is still scored like any other utterance.
        if len(hyp) >= len(ref) + 8:
            print(f"[BidLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}")

        info = compute_one_wer_info(ref, hyp)
        per_utt[uttid] = info
        en_dig_counts = count_english_ditgit(ref, hyp, info)
        totals.add(info)
        en_dig_totals.add(*en_dig_counts)
        if print_sentence_wer:
            print(f"{uttid} {info}")

    return per_utt, totals, en_dig_totals
|
|
|
|
|
|
|
|
# Penalties charged by the edit-distance recursion; substitution, deletion
# and insertion all cost the same.
COST_SUB = COST_DEL = COST_INS = 3

# Tags recorded for each alignment step: correct, substitution, deletion,
# insertion, and an end marker.
(ALIGN_CRT, ALIGN_SUB, ALIGN_DEL, ALIGN_INS, ALIGN_END) = range(5)
|
|
|
|
|
|
|
|
def compute_one_wer_info(ref, hyp):
    """Minimum edit distance between two token lists, with backtrace.

    Args:
        ref: reference tokens (List[str]).
        hyp: hypothesis tokens (List[str]).

    Returns:
        WerInfo built from the reference length, the error count, the numbers
        of correct / substituted / deleted / inserted tokens, and the
        alignment path as ``(ref_index, hyp_index, tag)`` triples with
        1-based indices, ordered from the start of the utterance.
    """
    n_ref, n_hyp = len(ref), len(hyp)

    # cost[r][c] / move[r][c]: cheapest cost of aligning the first r reference
    # tokens with the first c hypothesis tokens, and the last step taken.
    cost = [[0] * (n_hyp + 1) for _ in range(n_ref + 1)]
    move = [[ALIGN_CRT] * (n_hyp + 1) for _ in range(n_ref + 1)]

    # Borders: an empty reference needs only insertions, an empty hypothesis
    # only deletions.
    for col in range(1, n_hyp + 1):
        cost[0][col] = cost[0][col - 1] + COST_INS
        move[0][col] = ALIGN_INS
    for row in range(1, n_ref + 1):
        cost[row][0] = cost[row - 1][0] + COST_DEL
        move[row][0] = ALIGN_DEL

    # Fill the table.  The comparisons are strict, so on a tie the diagonal
    # step wins over a deletion, and a deletion wins over an insertion.
    for row in range(1, n_ref + 1):
        for col in range(1, n_hyp + 1):
            diagonal = cost[row - 1][col - 1]
            if hyp[col - 1] == ref[row - 1]:
                best, tag = diagonal, ALIGN_CRT
            else:
                best, tag = diagonal + COST_SUB, ALIGN_SUB

            via_del = cost[row - 1][col] + COST_DEL
            if via_del < best:
                best, tag = via_del, ALIGN_DEL

            via_ins = cost[row][col - 1] + COST_INS
            if via_ins < best:
                best, tag = via_ins, ALIGN_INS

            cost[row][col] = best
            move[row][col] = tag

    # Walk back from the bottom-right corner, counting each kind of step.
    step = {
        ALIGN_CRT: (1, 1),
        ALIGN_SUB: (1, 1),
        ALIGN_DEL: (1, 0),
        ALIGN_INS: (0, 1),
    }
    tally = {ALIGN_CRT: 0, ALIGN_SUB: 0, ALIGN_DEL: 0, ALIGN_INS: 0}
    path = []
    row, col = n_ref, n_hyp
    while row > 0 or col > 0:
        tag = move[row][col]
        path.append((row, col, tag))
        d_row, d_col = step[tag]
        row -= d_row
        col -= d_col
        tally[tag] += 1
    path.reverse()

    n_sub, n_del, n_ins = tally[ALIGN_SUB], tally[ALIGN_DEL], tally[ALIGN_INS]
    return WerInfo(
        n_ref, n_sub + n_del + n_ins, tally[ALIGN_CRT], n_sub, n_del, n_ins, path
    )
|
|
|
|
|
|
|
|
|
|
|
class WerInfo:
    """Edit statistics of a single utterance.

    Attributes:
        r: reference length.
        e: error count.
        c, s, d, i: correct, substituted, deleted and inserted token counts.
        ali: alignment path as passed in by the caller.
        wer: ``100 * (s + d + i) / max(r, 1)``.
    """

    def __init__(self, ref, err, crt, sub, dele, ins, ali):
        self.r, self.e, self.c = ref, err, crt
        self.s, self.d, self.i = sub, dele, ins
        self.ali = ali
        # Guard against an empty reference so the division never fails.
        denominator = max(self.r, 1)
        self.wer = 100.0 * (self.s + self.d + self.i) / denominator

    def __repr__(self):
        return f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}"
|
|
|
|
|
|
|
|
class WerStats:
    """Collects per-utterance results and prints corpus-level error rates."""

    def __init__(self):
        self.infos = []

    def add(self, wer_info):
        """Store one per-utterance result (read later via ``r/s/d/i/e``)."""
        self.infos.append(wer_info)

    def print(self):
        """Print total counts, WER with its breakdown, and sentence error rate."""
        collected = self.infos

        ref_total = sum(item.r for item in collected)
        if ref_total <= 0:
            print(f"REF len is {ref_total}, check")
            ref_total = 1  # keeps the divisions below well defined

        n_sub = sum(item.s for item in collected)
        n_del = sum(item.d for item in collected)
        n_ins = sum(item.i for item in collected)

        sub_rate = 100.0 * n_sub / ref_total
        del_rate = 100.0 * n_del / ref_total
        ins_rate = 100.0 * n_ins / ref_total
        wer = 100.0 * (n_sub + n_del + n_ins) / ref_total

        # Sentence error rate: share of utterances with at least one error.
        n_sent = max(len(collected), 1)
        n_bad = sum(1 for item in collected if item.e > 0)
        ser = 100.0 * n_bad / n_sent

        rule = "-" * 80
        print(rule)
        print(f"ref{ref_total:6d} sub{n_sub:6d} del{n_del:6d} ins{n_ins:6d}")
        print(f"WER{wer:6.2f} sub{sub_rate:6.2f} del{del_rate:6.2f} ins{ins_rate:6.2f}")
        print(f"SER{ser:6.2f} = {n_bad} / {n_sent}")
        print(rule)
|
|
|
|
|
|
|
|
class EnDigStats:
    """Running totals of English-word and digit tokens and how many were correct."""

    def __init__(self):
        self.n_en_word = self.n_en_correct = 0
        self.n_dig_word = self.n_dig_correct = 0

    def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct):
        """Add the four counts of one utterance to the running totals."""
        self.n_en_word = self.n_en_word + n_en_word
        self.n_en_correct = self.n_en_correct + n_en_correct
        self.n_dig_word = self.n_dig_word + n_dig_word
        self.n_dig_correct = self.n_dig_correct + n_dig_correct

    def print(self):
        """Print the English and digit totals followed by a rule line."""
        english_line = f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n"
        digit_line = f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}"
        print(english_line + digit_line)
        print("-" * 80)
|
|
|
|
|
|
|
|
|
|
|
# Token classifiers.  All three are applied with ``match``, i.e. they only
# look at the beginning of the token.  Raw strings keep the backslashes for
# the regex engine instead of relying on invalid string escapes.
_EN_TOKEN_RE = re.compile(r"[a-zA-Z\.\-']+")
_DIGIT_TOKEN_RE = re.compile(r"[0-9]+")
_CJK_TOKEN_RE = re.compile(r"([\u4e00-\u9fff])")


def count_english_ditgit(ref, hyp, wer_info):
    """Count English and digit reference tokens and how many were recognised.

    A reference token counts as English when it starts with a letter, ``.``,
    ``-`` or ``'``, and as a digit token when it starts with ``0-9``.  It is
    counted as correct when the alignment in ``wer_info.ali`` holds an
    ``ALIGN_CRT`` entry for its 1-based position.  Tokens that are neither
    English, digit nor CJK are printed with a ``[WiredChar]:`` prefix.

    Args:
        ref: reference tokens.
        hyp: hypothesis tokens (unused; kept for interface compatibility).
        wer_info: object whose ``ali`` attribute is a sequence of
            ``(ref_index, hyp_index, tag)`` entries.

    Returns:
        Tuple ``(n_en_word, n_en_correct, n_dig_word, n_dig_correct)``.
    """
    # Collect the correctly aligned reference positions once, so every token
    # needs a set lookup instead of a scan over the whole alignment.
    correct_positions = {
        entry[0] for entry in wer_info.ali if entry[2] == ALIGN_CRT
    }

    n_en_word = n_en_correct = 0
    n_dig_word = n_dig_correct = 0
    for position, token in enumerate(ref, start=1):
        is_english = _EN_TOKEN_RE.match(token) is not None
        is_digit = _DIGIT_TOKEN_RE.match(token) is not None
        is_correct = position in correct_positions

        if is_english:
            n_en_word += 1
            if is_correct:
                n_en_correct += 1
        if is_digit:
            n_dig_word += 1
            if is_correct:
                n_dig_correct += 1
        if not (is_english or is_digit or _CJK_TOKEN_RE.match(token)):
            print("[WiredChar]:", token)

    return n_en_word, n_en_correct, n_dig_word, n_dig_correct
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, echo the parsed options,
    # then run the scoring.
    args = parser.parse_args()
    print(args, flush=True)
    main(args)
|
|
|
|