You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

67 lines
2.4 KiB

import logging
import re
import sentencepiece as spm
from fireredasr.data.token_dict import TokenDict
class ChineseCharEnglishSpmTokenizer:
"""
- One Chinese char is a token.
- Split English word into SPM and one piece is a token.
- Ignore ' ' between Chinese char
- Replace ' ' between English word with "" by spm_model
- Need to put SPM piece into dict file
- If not set spm_model, will use English char and <space>
"""
SPM_SPACE = ""
def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
self.dict = TokenDict(dict_path, unk=unk)
self.space = space
if spm_model:
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
else:
self.sp = None
print("[WRAN] Not set spm_model, will use English char")
print("[WARN] Please check how to deal with ' '(space)")
if self.space not in self.dict:
print("Please add <space> to your dict, or it will be <unk>")
def tokenize(self, text, replace_punc=True):
#if text == "":
# logging.info(f"empty text")
text = text.upper()
tokens = []
if replace_punc:
text = re.sub("[,。?!,\.?!]", " ", text)
pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
parts = pattern.split(text.strip())
parts = [p for p in parts if len(p.strip()) > 0]
for part in parts:
if pattern.fullmatch(part) is not None:
tokens.append(part)
else:
if self.sp:
for piece in self.sp.EncodeAsPieces(part.strip()):
tokens.append(piece)
else:
for char in part.strip():
tokens.append(char if char != " " else self.space)
tokens_id = []
for token in tokens:
tokens_id.append(self.dict.get(token, self.dict.unk))
return tokens, tokens_id
def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
"""inputs is ids or tokens, do not need self.sp"""
if len(inputs) > 0 and type(inputs[0]) == int:
tokens = [self.dict[id] for id in inputs]
else:
tokens = inputs
s = f"{join_symbol}".join(tokens)
if replace_spm_space:
s = s.replace(self.SPM_SPACE, ' ').strip()
return s