You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.4 KiB
105 lines
3.4 KiB
#!/usr/bin/env python3 |
|
|
|
import argparse |
|
import glob |
|
import os |
|
import sys |
|
|
|
from fireredasr.models.fireredasr import FireRedAsr |
|
|
|
|
|
def _build_parser():
    """Assemble the command-line parser; options are registered in table order."""
    option_table = (
        # Model selection (both mandatory)
        ("--asr_type", {"type": str, "required": True, "choices": ["aed", "llm"]}),
        ("--model_dir", {"type": str, "required": True}),
        # Input / Output
        ("--wav_path", {"type": str}),
        ("--wav_paths", {"type": str, "nargs": "*"}),
        ("--wav_dir", {"type": str}),
        ("--wav_scp", {"type": str}),
        ("--output", {"type": str}),
        # Decode Options
        ("--use_gpu", {"type": int, "default": 1}),
        ("--batch_size", {"type": int, "default": 1}),
        ("--beam_size", {"type": int, "default": 1}),
        ("--decode_max_len", {"type": int, "default": 0}),
        # FireRedASR-AED
        ("--nbest", {"type": int, "default": 1}),
        ("--softmax_smoothing", {"type": float, "default": 1.0}),
        ("--aed_length_penalty", {"type": float, "default": 0.0}),
        ("--eos_penalty", {"type": float, "default": 1.0}),
        # FireRedASR-LLM
        ("--decode_min_len", {"type": int, "default": 0}),
        ("--repetition_penalty", {"type": float, "default": 1.0}),
        ("--llm_length_penalty", {"type": float, "default": 0.0}),
        ("--temperature", {"type": float, "default": 1.0}),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli


# Module-level parser used by the script entry point.
parser = _build_parser()
|
|
|
|
|
def main(args):
    """Transcribe the selected wav files in batches and report each result.

    Every result returned by the model is printed to stdout. When
    ``args.output`` is set, a ``<uttid>\\t<text>`` line per result is also
    written to that file.

    Args:
        args: Parsed command-line namespace. Reads ``asr_type``,
            ``model_dir``, ``output``, ``batch_size`` and the decode options
            forwarded to ``model.transcribe``; the wav selection fields are
            consumed by ``get_wav_info``.
    """
    wavs = get_wav_info(args)

    # Explicit UTF-8 so non-ASCII transcripts are written regardless of the
    # platform's default locale encoding.
    fout = open(args.output, "w", encoding="utf-8") if args.output else None
    try:
        model = FireRedAsr.from_pretrained(args.asr_type, args.model_dir)

        # The decode options do not change between batches; build them once.
        decode_options = {
            "use_gpu": args.use_gpu,
            "beam_size": args.beam_size,
            "nbest": args.nbest,
            "decode_max_len": args.decode_max_len,
            "softmax_smoothing": args.softmax_smoothing,
            "aed_length_penalty": args.aed_length_penalty,
            "eos_penalty": args.eos_penalty,
            "decode_min_len": args.decode_min_len,
            "repetition_penalty": args.repetition_penalty,
            "llm_length_penalty": args.llm_length_penalty,
            "temperature": args.temperature,
        }

        batch_uttid = []
        batch_wav_path = []
        last_index = len(wavs) - 1
        for i, (uttid, wav_path) in enumerate(wavs):
            batch_uttid.append(uttid)
            batch_wav_path.append(wav_path)
            # Keep accumulating until the batch is full or the input is exhausted.
            if len(batch_wav_path) < args.batch_size and i != last_index:
                continue

            # Pass a copy so every call receives its own dict, as before.
            results = model.transcribe(
                batch_uttid,
                batch_wav_path,
                dict(decode_options),
            )

            for result in results:
                print(result)
                if fout is not None:
                    fout.write(f"{result['uttid']}\t{result['text']}\n")

            batch_uttid = []
            batch_wav_path = []
    finally:
        # Close (and thereby flush) the output even if transcription fails.
        if fout is not None:
            fout.close()
|
|
|
|
|
def _uttid_from_path(path):
    """Return the file name of *path* without a trailing ``.wav`` extension."""
    name = os.path.basename(path)
    # Strip only the extension; str.replace would also delete ".wav" occurring
    # in the middle of the file name.
    if name.endswith(".wav"):
        name = name[: -len(".wav")]
    return name


def _read_wav_scp(scp_path):
    """Parse a ``<uttid> <wav_path>`` listing into a list of (uttid, wav_path)."""
    entries = []
    with open(scp_path, encoding="utf-8") as scp:
        for lineno, raw in enumerate(scp, start=1):
            line = raw.strip()
            if not line:
                # Blank lines carry no entry; skip them instead of emitting [].
                continue
            # Split on the first whitespace run only, so paths may contain spaces.
            fields = line.split(None, 1)
            if len(fields) != 2:
                raise ValueError(
                    f"{scp_path}:{lineno}: expected '<uttid> <wav_path>', got {line!r}"
                )
            entries.append((fields[0], fields[1]))
    return entries


def get_wav_info(args):
    """Collect the (uttid, wav_path) pairs selected on the command line.

    The first source that is set wins, checked in this order: ``wav_path``,
    ``wav_paths``, ``wav_scp``, ``wav_dir``. For every source except
    ``wav_scp`` the uttid is the wav file name without its ``.wav`` extension.

    Args:
        args: Namespace providing ``wav_path``, ``wav_paths``, ``wav_scp``
            and ``wav_dir``.

    Returns:
        wavs: list of (uttid, wav_path)

    Raises:
        ValueError: If no wav source is given, or a ``wav_scp`` line does not
            contain both an uttid and a path.
    """
    if args.wav_path:
        wavs = [(_uttid_from_path(args.wav_path), args.wav_path)]
    elif args.wav_paths:
        wavs = [(_uttid_from_path(p), p) for p in sorted(args.wav_paths)]
    elif args.wav_scp:
        wavs = _read_wav_scp(args.wav_scp)
    elif args.wav_dir:
        found = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
        wavs = [(_uttid_from_path(p), p) for p in sorted(found)]
    else:
        raise ValueError("Please provide valid wav info")
    print(f"#wavs={len(wavs)}")
    return wavs
|
|
|
|
|
# Script entry point: parse the command line, echo the parsed options, run.
if __name__ == "__main__":
    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)
|
|
|