import logging
import os
import random
import re

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

from fireredasr.models.fireredasr_aed import FireRedAsrAed
from fireredasr.models.module.adapter import Adapter
from fireredasr.tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
from fireredasr.utils.param import count_model_parameters


class FireRedAsrLlm(nn.Module):
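    """LLM-based ASR model: FireRedASR AED speech encoder + Adapter projector + causal LLM."""
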
    @classmethod
    def load_encoder(cls, model_path):
        assert os.path.exists(model_path)
        package = torch.load(model_path, map_location=lambda storage, loc: storage)
        model = FireRedAsrAed.from_args(package["args"])
        if "model_state_dict" in package:
            model.load_state_dict(package["model_state_dict"], strict=False)
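        # Keep only the AED encoder (and its output dim); the AED attention decoder is not used here.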
        encoder = model.encoder
        encoder_dim = encoder.odim
        return encoder, encoder_dim

    @classmethod
    def from_args(cls, args):
        logging.info(args)
        logging.info("Build FireRedAsrLlm")
        # Build Speech Encoder
        encoder, encoder_dim = cls.load_encoder(args.encoder_path)
        count_model_parameters(encoder)
        if args.freeze_encoder:
            logging.info("Freeze encoder")
            for name, param in encoder.named_parameters():
                param.requires_grad = False
            encoder.eval()

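        # Pick the Hugging Face attention implementation and compute dtype for the LLM
        # (FlashAttention kernels generally require fp16/bf16, so use_fp16 is expected with use_flash_attn).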
        attn_implementation = "flash_attention_2" if args.use_flash_attn else "eager"
        torch_dtype = torch.float16 if args.use_fp16 else torch.float32

        # Build LLM
        llm = AutoModelForCausalLM.from_pretrained(
            args.llm_dir,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
        count_model_parameters(llm)

        # LLM Freeze or LoRA
        llm_dim = llm.config.hidden_size
        if args.freeze_llm:
            logging.info("Freeze LLM")
            for name, param in llm.named_parameters():
                param.requires_grad = False
            llm.eval()
        else:
            if args.use_lora:
                from peft import LoraConfig, get_peft_model
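                # Attach LoRA adapters to the attention and MLP projections
                # (these module names match Llama/Qwen-style decoder blocks).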
                lora_config = LoraConfig(
                    r=64,
                    lora_alpha=16,
                    target_modules=[
                        "q_proj",
                        "k_proj",
                        "v_proj",
                        "o_proj",
                        "up_proj",
                        "gate_proj",
                        "down_proj",
                    ],
                    lora_dropout=0.05,
                    task_type="CAUSAL_LM",
                )
                llm = get_peft_model(llm, lora_config)
                llm.print_trainable_parameters()

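        # Map the tokenizer's special tokens (Qwen/ChatML-style chat markers and the
        # speech placeholder) onto the LLM config so generate() uses the right IDs.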
        tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
        assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
        llm.config.pad_token_id = tokenizer.pad_token_id
        llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )

        # Build projector
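        # The Adapter maps encoder outputs (encoder_dim) into the LLM embedding space
        # (llm_dim), with a frame downsample rate of encoder_downsample_rate.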
        encoder_projector = Adapter(
            encoder_dim, llm_dim, args.encoder_downsample_rate)
        count_model_parameters(encoder_projector)

        return cls(encoder, llm, encoder_projector,
                   args.freeze_encoder, args.freeze_llm)

    def __init__(self, encoder, llm, encoder_projector,
                 freeze_encoder, freeze_llm):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
        # args
        self.freeze_encoder = freeze_encoder
        self.freeze_llm = freeze_llm
        self.llm_config = llm.config

    def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
                   beam_size=1, decode_max_len=0, decode_min_len=0,
                   repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
        encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
        speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
        inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)

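        # Splice the projected speech embeddings into the prompt embeddings at the
        # DEFAULT_SPEECH_TOKEN placeholder positions.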
        inputs_embeds, attention_mask, _ = \
            self._merge_input_ids_with_speech_features(
                speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask,
                speech_lens=speech_lens
            )

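        # With no explicit decode_max_len, cap generation at the number of speech frames.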
        max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
        max_new_tokens = max(1, max_new_tokens)

        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            num_beams=beam_size,
            do_sample=False,
            min_length=decode_min_len,
            top_p=1.0,
            repetition_penalty=repetition_penalty,
            length_penalty=llm_length_penalty,
            temperature=temperature,
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )

        return generated_ids

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
        speech_lens=None
    ):
        """
        Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
        """
        speech_lens = None
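        # NOTE: speech_lens is overridden above, so the `speech_lens is not None`
        # branches below are effectively disabled.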
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )
        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in the merged speech-text sequence.
        # `special_speech_token_mask` marks the speech placeholder tokens; each one expands
        # to `speech_len` speech frames, shifting subsequent text tokens by `speech_len - 1`.
        # `torch.cumsum` accumulates these shifts; the final `- 1` adjusts for zero-based
        # indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )  # (N, U)
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the speech encoder or the language model has been offloaded to CPU,
        # we need to manually move the corresponding tensors to the correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_speech_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_speech_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey", "<speech>", "how", "are"],
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speech features.
        #    Anything that is not `text_positions` needs filling (#29835).
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        if speech_lens is not None:
            speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
                f"The inputs provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
                f" the number of speech features given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        if speech_lens is not None:
            speech_to_overwrite &= speech_pad_position
        final_attention_mask |= speech_to_overwrite

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels  #, position_ids