You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

35 lines
1.3 KiB

import torch
from fireredasr.models.module.conformer_encoder import ConformerEncoder
from fireredasr.models.module.transformer_decoder import TransformerDecoder
class FireRedAsrAed(torch.nn.Module):
    """Attention-based encoder-decoder (AED) ASR model.

    Wires a Conformer acoustic encoder into a Transformer decoder and
    exposes batch beam-search decoding via :meth:`transcribe`.
    """

    @classmethod
    def from_args(cls, args):
        """Alternate constructor: build the model from a parsed-args namespace."""
        return cls(args)

    def __init__(self, args):
        """Construct encoder/decoder from hyperparameters carried on ``args``.

        Args:
            args: Namespace providing token ids (``sos_id``, ``eos_id``,
                ``pad_id``), vocabulary/feature sizes (``odim``, ``idim``),
                and architecture hyperparameters (layer counts, heads,
                ``d_model``, dropouts, ``kernel_size``, ``pe_maxlen``).
        """
        super().__init__()
        # Sequence-delimiter token ids, kept for decoding.
        self.sos_id = args.sos_id
        self.eos_id = args.eos_id
        self.encoder = ConformerEncoder(
            args.idim, args.n_layers_enc, args.n_head, args.d_model,
            args.residual_dropout, args.dropout_rate,
            args.kernel_size, args.pe_maxlen)
        self.decoder = TransformerDecoder(
            args.sos_id, args.eos_id, args.pad_id, args.odim,
            args.n_layers_dec, args.n_head, args.d_model,
            args.residual_dropout, args.pe_maxlen)

    def transcribe(self, padded_input, input_lengths,
                   beam_size=1, nbest=1, decode_max_len=0,
                   softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
        """Encode padded features and beam-search decode the n-best hypotheses.

        Args:
            padded_input: Batched, padded acoustic features for the encoder.
            input_lengths: Valid (unpadded) length of each sequence in the batch.
            beam_size: Beam width for the decoder search.
            nbest: Number of hypotheses to return per utterance.
            decode_max_len: Cap on decoded length (0 lets the decoder decide).
            softmax_smoothing: Smoothing factor applied to decoder softmax scores.
            length_penalty: Score adjustment favoring longer/shorter hypotheses.
            eos_penalty: Score adjustment applied to the end-of-sequence token.

        Returns:
            Whatever ``decoder.batch_beam_search`` produces — the n-best
            hypothesis list for the batch.
        """
        # The encoder's middle return value is unused here; only the
        # outputs and the padding mask feed the decoder search.
        encoded, _, pad_mask = self.encoder(padded_input, input_lengths)
        return self.decoder.batch_beam_search(
            encoded, pad_mask,
            beam_size, nbest, decode_max_len,
            softmax_smoothing, length_penalty, eos_penalty)