基于大语言模型实现文本端点检测

# 前言 在语音对话识别中,一般使用VAD检测用户时候结束说话,但是这个结束时间长度设置多少合适,这很难抉择,太短了,用户说话慢就容易打断,太长了用户等待时间久。还有常见的情况,用户在说话的时候,中途停顿了一下思考,如果只是使用VAD检测,有可能就会认为说话结束,但是用户还没有说话,这句话也不完整。这种情况可以配合文本端点检测,在使用VAD检测的时候,配合文本端点检测,从而保证用户表达完整。

原理

使用大语言模型检测一句话是否完整,其实原理很简单,就是利用了大语言模型的文本生成功能,让模型推理下一个token是否为结束符即可,或者下一个token是结束符的概率是多少。

如下两个例子,输入一个缺少结束符<|im_end|>的文本,让模型预测下一个字符,如果是结束字符,就代表这句话已经说完。模型之所以能够这样输出,是因为我们微调模型的时候输入的格式都是<|im_start|><|user|>你叫什么名字<|im_end|><|im_start|><|assistant|>我叫夜雨飘零<|im_end|>,所以模型可以按照这个格式推理输出。

未结束的一段话,模型识别情况

# 例如输入的文本如下:
<|im_start|><|user|>你叫什么名
# 模型输出
<|im_start|><|user|>你叫什么名字

完整的一句,模型识别的情况。

# 例如输入的文本如下:
<|im_start|><|user|>你叫什么名字
# 模型输出
<|im_start|><|user|>你叫什么名字<|im_end|>

准备数据

这里可以推荐一个数据集alpaca_gpt4_zh,但这个数据还需要处理一下,训练的数据格式与平时微调大语言模型的数据格式一样,不过为了训练更快,和更适合使用情况,主要做以下一下调整:

  1. instruction字段不要太长,一般用户输入也会长篇大论。
  2. output字段可以设置为空,这一步可以使得输入数据大大缩短,提高训练速度。

例子数据如下:

[
    {
        "instruction": "我们如何在日常生活中减少用水",
        "input": "",
        "output": ""
    },
    {
        "instruction": "编辑文章使其更吸引读者",
        "input": "",
        "output": ""
    },
]

微调模型

微调模型博主使用工具是LLaMA-Factory,具体怎么使用,开发者们可以自行查看官方文档,这里就不赘述了,只介绍需要的数据格式和几个训练注意事项。

如图所示,这里选择了一个非常小的Qwen模型,因为要追求推理速度。然后是选择提示模板,这个很重要,不同的模型的模板不一样。 最后就是选择自己处理后的数据。

训练结束之后,记得要选择检查点路径,并导出模型。

量化模型

为了CPU下追求更快的速度,可以导出为ONNX模型,同时也量化模型,注意Windows使用量化模型推理结果可能不准确

转换代码如下,这里通过转换为量化和未量化的ONNX模型。

import argparse

from optimum.onnxruntime import ORTQuantizer, ORTModelForCausalLM
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from transformers import AutoTokenizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="models/Qwen2.5-0.5B-Instruct")
    parser.add_argument("--save_dir", type=str, default="models/Qwen2.5-0.5B-Instruct-onnx")
    args = parser.parse_args()

    # Load a model from transformers and export it to ONNX
    ort_model = ORTModelForCausalLM.from_pretrained(args.model,
                                                    export=True,
                                                    task="text-generation-with-past")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    ort_model.save_pretrained(args.save_dir)
    tokenizer.save_pretrained(args.save_dir)

    dqconfig = AutoQuantizationConfig.avx2(is_static=False,
                                           per_channel=False,
                                           use_symmetric_activations=True)

    print("量化模型...")
    quantizer = ORTQuantizer.from_pretrained(ort_model)
    model_quantized_path = quantizer.quantize(save_dir=args.save_dir + "-quantize",
                                              quantization_config=dqconfig)

文本端点检测

使用代码如下,如原理介绍部分所说,让模型推理下一个token,并获取该token为结束符<|im_end|>得分是多少,通过这个得分判断是否为文本结束,博主建议阈值为0.15,当然也可以使用其他值。

import numpy as np
import torch
from loguru import logger
from optimum.onnxruntime import ORTModelForCausalLM
from sklearn.utils.extmath import softmax
from transformers import AutoTokenizer


class TurnDetector:
    def __init__(self, model_path):
        self.model = ORTModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.eou_index = self.tokenizer.encode("<|im_end|>")[-1]
        logger.info(f"结束符的token为:{self.eou_index}")

    def __call__(self, text):
        messages = [{"role": "user", "content": text}]
        text = self.tokenizer.apply_chat_template(messages,
                                                  add_generation_prompt=False,
                                                  add_special_tokens=False,
                                                  tokenize=False)
        ix = text.rfind("<|im_end|>")
        text = text[:ix]

        model_inputs = self.tokenizer(text, return_tensors="pt")
        position_ids = torch.arange(model_inputs['input_ids'].shape[1], dtype=torch.int64)[None, :]
        model_inputs['position_ids'] = position_ids

        outputs = self.model(**model_inputs).logits.numpy()

        logits = outputs[0, -1, :]
        probs = softmax(logits[np.newaxis, :])[0]
        eou_probability = probs[self.eou_index]

        return eou_probability


if __name__ == '__main__':
    turn_detector = TurnDetector("models/Qwen2.5-0.5B-Instruct-finetune-onnx")

    s = "你叫什"
    result = turn_detector(s)
    print(f'[{s}] 识别结果为:{result:.3f}')

    s = "你叫什么名字"
    result = turn_detector(s)
    print(f'[{s}] 识别结果为:{result:.3f}')

输出结果:

2025-01-17 20:16:57.416 | INFO     | __main__:__init__:14 - 结束符的token为:151645
[你叫什] 识别结果为:0.000
[你叫什么名字] 识别结果为:0.571