You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
35 lines
1.3 KiB
35 lines
1.3 KiB
import torch |
|
|
|
from fireredasr.models.module.conformer_encoder import ConformerEncoder |
|
from fireredasr.models.module.transformer_decoder import TransformerDecoder |
|
|
|
|
|
class FireRedAsrAed(torch.nn.Module):
    """Encoder-decoder ASR model: a ConformerEncoder feeding a TransformerDecoder.

    (The "Aed" suffix presumably stands for attention-based encoder-decoder.)
    Both sub-networks are constructed in ``__init__`` from attributes of a
    single ``args`` object.
    """

    @classmethod
    def from_args(cls, args):
        """Alternate constructor; equivalent to ``cls(args)``.

        Because it goes through ``cls``, a subclass calling this gets an
        instance of the subclass.
        """
        model = cls(args)
        return model

    def __init__(self, args):
        """Build the encoder and decoder from the fields of ``args``.

        ``args`` is read by attribute access. Fields consumed:

        * ``sos_id``, ``eos_id``: stored on the model and also passed to
          the decoder.
        * encoder only: ``idim``, ``n_layers_enc``, ``dropout_rate``,
          ``kernel_size``.
        * decoder only: ``pad_id``, ``odim``, ``n_layers_dec``.
        * shared by both: ``n_head``, ``d_model``, ``residual_dropout``,
          ``pe_maxlen``.
        """
        super().__init__()

        self.sos_id = args.sos_id
        self.eos_id = args.eos_id

        # Passed positionally: the tuple order must match
        # ConformerEncoder's parameter order exactly.
        encoder_params = (
            args.idim, args.n_layers_enc, args.n_head, args.d_model,
            args.residual_dropout, args.dropout_rate,
            args.kernel_size, args.pe_maxlen,
        )
        self.encoder = ConformerEncoder(*encoder_params)

        # Same caveat: positional order must match TransformerDecoder.
        decoder_params = (
            args.sos_id, args.eos_id, args.pad_id, args.odim,
            args.n_layers_dec, args.n_head, args.d_model,
            args.residual_dropout, args.pe_maxlen,
        )
        self.decoder = TransformerDecoder(*decoder_params)

    def transcribe(self, padded_input, input_lengths,
                   beam_size=1, nbest=1, decode_max_len=0,
                   softmax_smoothing=1.0, length_penalty=0.0,
                   eos_penalty=1.0):
        """Encode a padded batch, then beam-search decode it.

        Args:
            padded_input: batch handed straight to ``self.encoder``
                (presumably padded acoustic features -- confirm shape).
            input_lengths: per-item lengths handed to the encoder with it.
            beam_size, nbest, decode_max_len, softmax_smoothing,
            length_penalty, eos_penalty: forwarded unchanged, in this
                order, to ``self.decoder.batch_beam_search``.

        Returns:
            Whatever ``batch_beam_search`` returns (presumably the n-best
            hypotheses per utterance -- verify against the decoder).

        NOTE: autograd is not disabled here; callers doing pure inference
        may want to wrap the call in ``torch.no_grad()``.
        """
        # The encoder yields three values; the middle one is unused here.
        encoded, _, mask = self.encoder(padded_input, input_lengths)
        search_options = (beam_size, nbest, decode_max_len,
                          softmax_smoothing, length_penalty, eos_penalty)
        return self.decoder.batch_beam_search(encoded, mask, *search_options)
|
|
|