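"""Inference wrapper for FireRedASR.

Provides FireRedAsr, a thin front end that loads a pretrained "aed"
(attention encoder-decoder) or "llm" (speech encoder + Qwen2 LLM) model
together with its feature extractor and tokenizer, and exposes a batched
transcribe() method that also reports the real-time factor (RTF).
"""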
import os
import time
import torch
from fireredasr.data.asr_feat import ASRFeatExtractor
from fireredasr.models.fireredasr_aed import FireRedAsrAed
from fireredasr.models.fireredasr_llm import FireRedAsrLlm
from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper


class FireRedAsr:
    @classmethod
    def from_pretrained(cls, asr_type, model_dir):
        assert asr_type in ["aed", "llm"]
        # Global CMVN stats consumed by the feature extractor
        cmvn_path = os.path.join(model_dir, "cmvn.ark")
        feat_extractor = ASRFeatExtractor(cmvn_path)
        if asr_type == "aed":
            # Attention encoder-decoder model with a
            # Chinese-char / English-SPM tokenizer
            model_path = os.path.join(model_dir, "model.pth.tar")
            dict_path = os.path.join(model_dir, "dict.txt")
            spm_model = os.path.join(model_dir, "train_bpe1000.model")
            model = load_fireredasr_aed_model(model_path)
            tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model)
        elif asr_type == "llm":
            # Speech encoder feeding a Qwen2 LLM decoder
            model_path = os.path.join(model_dir, "model.pth.tar")
            encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar")
            llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct")
            model, tokenizer = load_firered_llm_model_and_tokenizer(
                model_path, encoder_path, llm_dir)
        model.eval()
        return cls(asr_type, feat_extractor, model, tokenizer)

    def __init__(self, asr_type, feat_extractor, model, tokenizer):
        self.asr_type = asr_type
        self.feat_extractor = feat_extractor
        self.model = model
        self.tokenizer = tokenizer

    @torch.no_grad()
    def transcribe(self, batch_uttid, batch_wav_path, args={}):
        feats, lengths, durs = self.feat_extractor(batch_wav_path)
        total_dur = sum(durs)
        # Move model and inputs onto the same device before decoding
        if args.get("use_gpu", False):
            feats, lengths = feats.cuda(), lengths.cuda()
            self.model.cuda()
        else:
            self.model.cpu()

        if self.asr_type == "aed":
            start_time = time.time()
            hyps = self.model.transcribe(
                feats, lengths,
                args.get("beam_size", 1),
                args.get("nbest", 1),
                args.get("decode_max_len", 0),
                args.get("softmax_smoothing", 1.0),
                args.get("aed_length_penalty", 0.0),
                args.get("eos_penalty", 1.0)
            )
            elapsed = time.time() - start_time
            # Real-time factor: decoding time divided by total audio duration
            rtf = elapsed / total_dur if total_dur > 0 else 0

            results = []
            for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps):
                hyp = hyp[0]  # only return 1-best
                hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
                text = self.tokenizer.detokenize(hyp_ids)
                results.append({"uttid": uttid, "text": text, "wav": wav,
                                "rtf": f"{rtf:.4f}"})
            return results
        elif self.asr_type == "llm":
            # Build empty decode-time prompts, one per utterance in the batch
            input_ids, attention_mask, _, _ = \
                LlmTokenizerWrapper.preprocess_texts(
                    origin_texts=[""] * feats.size(0), tokenizer=self.tokenizer,
                    max_len=128, decode=True)
            if args.get("use_gpu", False):
                input_ids = input_ids.cuda()
                attention_mask = attention_mask.cuda()

            start_time = time.time()
            generated_ids = self.model.transcribe(
                feats, lengths, input_ids, attention_mask,
                args.get("beam_size", 1),
                args.get("decode_max_len", 0),
                args.get("decode_min_len", 0),
                args.get("repetition_penalty", 1.0),
                args.get("llm_length_penalty", 0.0),
                args.get("temperature", 1.0)
            )
            elapsed = time.time() - start_time
            # Real-time factor: decoding time divided by total audio duration
            rtf = elapsed / total_dur if total_dur > 0 else 0

            texts = self.tokenizer.batch_decode(generated_ids,
                                                skip_special_tokens=True)
            results = []
            for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts):
                results.append({"uttid": uttid, "text": text, "wav": wav,
                                "rtf": f"{rtf:.4f}"})
            return results


def load_fireredasr_aed_model(model_path):
    # Load the checkpoint on CPU; it is moved to GPU later if requested
    package = torch.load(model_path, map_location=lambda storage, loc: storage)
    print("model args:", package["args"])
    model = FireRedAsrAed.from_args(package["args"])
    model.load_state_dict(package["model_state_dict"], strict=True)
    return model


def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir):
    package = torch.load(model_path, map_location=lambda storage, loc: storage)
    # Point the saved args at the local encoder checkpoint and LLM directory
    package["args"].encoder_path = encoder_path
    package["args"].llm_dir = llm_dir
    print("model args:", package["args"])
    model = FireRedAsrLlm.from_args(package["args"])
    # Non-strict load: allow missing/unexpected keys in the checkpoint
    model.load_state_dict(package["model_state_dict"], strict=False)
    tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir)
    return model, tokenizer
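

# Example usage (a minimal sketch; the model directory and wav path below are
# hypothetical placeholders, and the dict keys shown are the optional decoding
# options read by transcribe() above):
#
#   model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")
#   results = model.transcribe(
#       ["utt1"], ["examples/wav/utt1.wav"],
#       {"use_gpu": True, "beam_size": 3, "nbest": 1})
#   print(results[0]["text"], results[0]["rtf"])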