import logging
import os
import random
import re

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

from fireredasr.models.fireredasr_aed import FireRedAsrAed
from fireredasr.models.module.adapter import Adapter
from fireredasr.tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
from fireredasr.utils.param import count_model_parameters


class FireRedAsrLlm(nn.Module):
    @classmethod
    def load_encoder(cls, model_path):
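        """Load a pretrained FireRedASR-AED checkpoint and return its encoder
        together with the encoder output dimension.
        """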
        assert os.path.exists(model_path)
        package = torch.load(model_path, map_location=lambda storage, loc: storage)
        model = FireRedAsrAed.from_args(package["args"])
        if "model_state_dict" in package:
            model.load_state_dict(package["model_state_dict"], strict=False)
        encoder = model.encoder
        encoder_dim = encoder.odim
        return encoder, encoder_dim

    @classmethod
    def from_args(cls, args):
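        """Build the full model from `args`: load the pretrained AED speech
        encoder, load the causal LLM (optionally frozen or LoRA-adapted), and
        create the Adapter that projects encoder outputs into the LLM
        embedding space.
        """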
        logging.info(args)
        logging.info("Build FireRedAsrLlm")

        # Build Speech Encoder
        encoder, encoder_dim = cls.load_encoder(args.encoder_path)
        count_model_parameters(encoder)
        if args.freeze_encoder:
            logging.info("Freeze encoder")
            for name, param in encoder.named_parameters():
                param.requires_grad = False
            encoder.eval()

        attn_implementation = "flash_attention_2" if args.use_flash_attn else "eager"
        torch_dtype = torch.float16 if args.use_fp16 else torch.float32

        # Build LLM
        llm = AutoModelForCausalLM.from_pretrained(
            args.llm_dir,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )
        count_model_parameters(llm)

        # LLM Freeze or LoRA
        llm_dim = llm.config.hidden_size
        if args.freeze_llm:
            logging.info("Freeze LLM")
            for name, param in llm.named_parameters():
                param.requires_grad = False
            llm.eval()
        elif args.use_lora:
            from peft import LoraConfig, get_peft_model
            lora_config = LoraConfig(
                r=64,
                lora_alpha=16,
                target_modules=[
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "up_proj",
                    "gate_proj",
                    "down_proj",
                ],
                lora_dropout=0.05,
                task_type="CAUSAL_LM",
            )
            llm = get_peft_model(llm, lora_config)
            llm.print_trainable_parameters()
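
        # The LLM tokenizer is expected to provide Qwen/ChatML-style special
        # tokens ("<|endoftext|>" as pad, "<|im_start|>"/"<|im_end|>" as
        # BOS/EOS); their ids are copied into the LLM config for generation.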
        tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
        assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
        llm.config.pad_token_id = tokenizer.pad_token_id
        llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
            DEFAULT_SPEECH_TOKEN
        )

        # Build projector
        encoder_projector = Adapter(
            encoder_dim, llm_dim, args.encoder_downsample_rate)
        count_model_parameters(encoder_projector)

        return cls(encoder, llm, encoder_projector,
                   args.freeze_encoder, args.freeze_llm)

    def __init__(self, encoder, llm, encoder_projector,
                 freeze_encoder, freeze_llm):
        super().__init__()
        self.encoder = encoder
        self.llm = llm
        self.encoder_projector = encoder_projector
        # args
        self.freeze_encoder = freeze_encoder
        self.freeze_llm = freeze_llm
        self.llm_config = llm.config

    def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
                   beam_size=1, decode_max_len=0, decode_min_len=0,
                   repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
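        """Decode a batch of utterances.

        Encodes `padded_feat`, projects the encoder outputs into the LLM
        embedding space, splices them into the prompt embeddings at the
        speech placeholder token, and calls `llm.generate`. Returns the
        generated token ids.
        """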
        encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
        speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
        inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)
        inputs_embeds, attention_mask, _ = \
            self._merge_input_ids_with_speech_features(
                speech_features.to(inputs_embeds.dtype), inputs_embeds,
                padded_input_ids, attention_mask,
                speech_lens=speech_lens
            )
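        # If no decode_max_len is given, cap generation at the number of
        # projected speech frames, and always allow at least one new token.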
        max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
        max_new_tokens = max(1, max_new_tokens)
        generated_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=max_new_tokens,
            num_beams=beam_size,
            do_sample=False,
            min_length=decode_min_len,
            top_p=1.0,
            repetition_penalty=repetition_penalty,
            length_penalty=llm_length_penalty,
            temperature=temperature,
            bos_token_id=self.llm.config.bos_token_id,
            eos_token_id=self.llm.config.eos_token_id,
            pad_token_id=self.llm.config.pad_token_id,
        )
        return generated_ids

    def _merge_input_ids_with_speech_features(
        self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
        speech_lens=None
    ):
        """
        Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
        """
        speech_lens = None
        num_speechs, speech_len, embed_dim = speech_features.shape
        batch_size, sequence_length = input_ids.shape
        left_padding = not torch.sum(
            input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
        )

        # 1. Create a mask to know where special speech tokens are
        special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
        num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)

        # Compute the maximum embed dimension
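        # (each speech placeholder token expands into `speech_len` speech
        # embeddings, i.e. `speech_len - 1` extra positions per placeholder)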
        max_embed_dim = (
            num_special_speech_tokens.max() * (speech_len - 1)
        ) + sequence_length
        batch_indices, non_speech_indices = torch.where(
            input_ids != self.llm.config.default_speech_token_id
        )

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in the merged speech-text sequence.
        # `special_speech_token_mask` identifies speech tokens. Each speech token is
        # replaced by `speech_len` speech embeddings, shifting later tokens by
        # `speech_len - 1` positions.
        # `torch.cumsum` computes how each speech token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = (
            torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
        )  # (N, U)
        nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_speech_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size,
            max_embed_dim,
            embed_dim,
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )
        final_attention_mask = torch.zeros(
            batch_size,
            max_embed_dim,
            dtype=attention_mask.dtype,
            device=inputs_embeds.device,
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim),
                IGNORE_TOKEN_ID,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
        # In case the speech encoder or the language model has been offloaded to CPU,
        # we need to manually move the corresponding tensors to the correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_speech_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_speech_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey", "<speech>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
            batch_indices, non_speech_indices
        ]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
            batch_indices, non_speech_indices
        ]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[
                batch_indices, non_speech_indices
            ]

        # 5. Fill the embeddings corresponding to the speech features. Anything that is
        # not `text_positions` needs filling (#29835)
        speech_to_overwrite = torch.full(
            (batch_size, max_embed_dim),
            True,
            dtype=torch.bool,
            device=inputs_embeds.device,
        )
        speech_to_overwrite[batch_indices, text_to_overwrite] = False
        if speech_lens is not None:
            speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
        speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
            :, None
        ].to(target_device)

        if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model is wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
                f" the number of speech features given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
            )
        final_embedding[speech_to_overwrite] = (
            speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
        )
        if speech_lens is not None:
            speech_to_overwrite &= speech_pad_position
        final_attention_mask |= speech_to_overwrite

        # 6. Mask out the embedding at padding positions, as we later use the
        # past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(
            input_ids == self.llm.config.pad_token_id
        )
        indices_to_mask = new_token_positions[batch_indices, pad_indices]
        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels  # , position_ids
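

if __name__ == "__main__":
    # Minimal construction sketch (illustrative only): the paths and settings
    # below are placeholders, not values shipped with this repo, and must point
    # at a real FireRedASR-AED checkpoint and a Hugging Face LLM directory
    # compatible with LlmTokenizerWrapper.
    from argparse import Namespace

    demo_args = Namespace(
        encoder_path="path/to/asr_encoder.pth.tar",  # hypothetical checkpoint path
        llm_dir="path/to/llm_dir",                   # hypothetical HF model directory
        encoder_downsample_rate=2,                   # hypothetical downsample rate
        freeze_encoder=True,
        freeze_llm=True,
        use_lora=False,
        use_flash_attn=False,
        use_fp16=False,
    )
    model = FireRedAsrLlm.from_args(demo_args)
    print(model.llm_config)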