# 前言 在语音对话识别中,一般使用VAD检测用户时候结束说话,但是这个结束时间长度设置多少合适,这很难抉择,太短了,用户说话慢就容易打断,太长了用户等待时间久。还有常见的情况,用户在说话的时候,中途停顿了一下思考,如果只是使用VAD检测,有可能就会认为说话结束,但是用户还没有说话,这句话也不完整。这种情况可以配合文本端点检测,在使用VAD检测的时候,配合文本端点检测,从而保证用户表达完整。
原理
使用大语言模型检测一句话是否完整,其实原理很简单,就是利用了大语言模型的文本生成功能,让模型推理下一个token是否为结束符即可,或者下一个token是结束符的概率是多少。
如下两个例子,输入一个缺少结束符<|im_end|>
的文本,让模型预测下一个字符,如果是结束字符,就代表这句话已经说完。模型之所以能够这样输出,是因为我们微调模型的时候输入的格式都是<|im_start|><|user|>你叫什么名字<|im_end|><|im_start|><|assistant|>我叫夜雨飘零<|im_end|>
,所以模型可以按照这个格式推理输出。
未结束的一段话,模型识别情况
# 例如输入的文本如下:
<|im_start|><|user|>你叫什么名
# 模型输出
<|im_start|><|user|>你叫什么名字
完整的一句,模型识别的情况。
# 例如输入的文本如下:
<|im_start|><|user|>你叫什么名字
# 模型输出
<|im_start|><|user|>你叫什么名字<|im_end|>
准备数据
这里可以推荐一个数据集alpaca_gpt4_zh,但这个数据还需要处理一下,训练的数据格式与平时微调大语言模型的数据格式一样,不过为了训练更快,和更适合使用情况,主要做以下一下调整:
instruction
字段不要太长,一般用户输入也会长篇大论。output
字段可以设置为空,这一步可以使得输入数据大大缩短,提高训练速度。
例子数据如下:
[
{
"instruction": "我们如何在日常生活中减少用水",
"input": "",
"output": ""
},
{
"instruction": "编辑文章使其更吸引读者",
"input": "",
"output": ""
},
]
微调模型
微调模型博主使用工具是LLaMA-Factory,具体怎么使用,开发者们可以自行查看官方文档,这里就不赘述了,只介绍需要的数据格式和几个训练注意事项。
如图所示,这里选择了一个非常小的Qwen模型,因为要追求推理速度。然后是选择提示模板,这个很重要,不同的模型的模板不一样。 最后就是选择自己处理后的数据。
训练结束之后,记得要选择检查点路径,并导出模型。
量化模型
为了CPU下追求更快的速度,可以导出为ONNX模型,同时也量化模型,注意Windows使用量化模型推理结果可能不准确。
转换代码如下,这里通过转换为量化和未量化的ONNX模型。
import argparse
from optimum.onnxruntime import ORTQuantizer, ORTModelForCausalLM
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="models/Qwen2.5-0.5B-Instruct")
parser.add_argument("--save_dir", type=str, default="models/Qwen2.5-0.5B-Instruct-onnx")
args = parser.parse_args()
# Load a model from transformers and export it to ONNX
ort_model = ORTModelForCausalLM.from_pretrained(args.model,
export=True,
task="text-generation-with-past")
tokenizer = AutoTokenizer.from_pretrained(args.model)
ort_model.save_pretrained(args.save_dir)
tokenizer.save_pretrained(args.save_dir)
dqconfig = AutoQuantizationConfig.avx2(is_static=False,
per_channel=False,
use_symmetric_activations=True)
print("量化模型...")
quantizer = ORTQuantizer.from_pretrained(ort_model)
model_quantized_path = quantizer.quantize(save_dir=args.save_dir + "-quantize",
quantization_config=dqconfig)
文本端点检测
使用代码如下,如原理介绍部分所说,让模型推理下一个token,并获取该token为结束符<|im_end|>
得分是多少,通过这个得分判断是否为文本结束,博主建议阈值为0.15,当然也可以使用其他值。
import numpy as np
import torch
from loguru import logger
from optimum.onnxruntime import ORTModelForCausalLM
from sklearn.utils.extmath import softmax
from transformers import AutoTokenizer
class TurnDetector:
def __init__(self, model_path):
self.model = ORTModelForCausalLM.from_pretrained(model_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.eou_index = self.tokenizer.encode("<|im_end|>")[-1]
logger.info(f"结束符的token为:{self.eou_index}")
def __call__(self, text):
messages = [{"role": "user", "content": text}]
text = self.tokenizer.apply_chat_template(messages,
add_generation_prompt=False,
add_special_tokens=False,
tokenize=False)
ix = text.rfind("<|im_end|>")
text = text[:ix]
model_inputs = self.tokenizer(text, return_tensors="pt")
position_ids = torch.arange(model_inputs['input_ids'].shape[1], dtype=torch.int64)[None, :]
model_inputs['position_ids'] = position_ids
outputs = self.model(**model_inputs).logits.numpy()
logits = outputs[0, -1, :]
probs = softmax(logits[np.newaxis, :])[0]
eou_probability = probs[self.eou_index]
return eou_probability
if __name__ == '__main__':
turn_detector = TurnDetector("models/Qwen2.5-0.5B-Instruct-finetune-onnx")
s = "你叫什"
result = turn_detector(s)
print(f'[{s}] 识别结果为:{result:.3f}')
s = "你叫什么名字"
result = turn_detector(s)
print(f'[{s}] 识别结果为:{result:.3f}')
输出结果:
2025-01-17 20:16:57.416 | INFO | __main__:__init__:14 - 结束符的token为:151645
[你叫什] 识别结果为:0.000
[你叫什么名字] 识别结果为:0.571