Preface

In speech recognition, the model outputs a plain text transcript with no punctuation. This tutorial addresses that limitation by inserting punctuation marks into the recognized text according to its grammatical context, so that the speech recognition system can output a final, punctuated result.

Note: This tutorial covers usage only. To train your own model, see the article “Training a Chinese Punctuation Model Based on PaddlePaddle” (https://blog.csdn.net/qq_33200967/article/details/126858763).

Usage

The usage consists of 4 steps:

  1. First, download the model and extract it to the models/ directory. The download link is as follows:
https://download.csdn.net/download/qq_33200967/75664996
  2. The PaddleNLP toolkit is required, so install it in advance using the following command:
python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/
  3. PPASR has provided a tool for automatic punctuation addition since version 0.1.3. Install ppasr to use this tool:
python -m pip install ppasr==2.4.8 -i https://mirrors.aliyun.com/pypi/simple/
  4. Adding punctuation to text automatically is then very simple. Note that the import path and class name have changed since version V2:
from ppasr.infer_utils.pun_predictor import PunctuationPredictor

pun_predictor = PunctuationPredictor(model_dir='models/pun_models')
result = pun_predictor('近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书')
print(result)

Output:

[2022-01-13 15:27:11,194] [    INFO] - Found C:\Users\test\.paddlenlp\models\ernie-1.0\vocab.txt
近几年,不但我用书给女儿儿压岁,也劝说亲朋,不要给女儿压岁钱,而改送压岁书。
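
In a real pipeline, the predictor is applied as a post-processing step to the raw transcript returned by the recognizer. Here is a minimal sketch, reusing the pun_predictor created above; recognize_audio is a hypothetical stand-in for whatever inference call your ASR system provides:

def recognize_audio(wav_path):
    # Hypothetical placeholder: replace with your ASR system's inference call.
    # Here it just returns a fixed unpunctuated transcript for demonstration.
    return '近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书'

raw_text = recognize_audio('dataset/test.wav')
print(pun_predictor(raw_text))  # Punctuation is added after recognition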

Source Code

The complete source code of this tool is as follows:

import json
import os
import re

import numpy as np
import paddle.inference as paddle_infer
from paddlenlp.transformers import ErnieTokenizer
from ppasr.utils.logger import setup_logger

logger = setup_logger(__name__)

__all__ = ['PunctuationPredictor']


class PunctuationPredictor:
    def __init__(self, model_dir, use_gpu=True, gpu_mem=500, num_threads=4):
        # Create configuration
        model_path = os.path.join(model_dir, 'model.pdmodel')
        params_path = os.path.join(model_dir, 'model.pdiparams')
        if not os.path.exists(model_path) or not os.path.exists(params_path):
            raise Exception("Punctuation model files do not exist. Please check if {} and {} exist!".format(model_path, params_path))
        self.config = paddle_infer.Config(model_path, params_path)
        # Determine the pre-trained model type
        pretrained_token = 'ernie-1.0'
        if os.path.exists(os.path.join(model_dir, 'info.json')):
            with open(os.path.join(model_dir, 'info.json'), 'r', encoding='utf-8') as f:
                data = json.load(f)
                pretrained_token = data['pretrained_token']

        if use_gpu:
            self.config.enable_use_gpu(gpu_mem, 0)
        else:
            self.config.disable_gpu()
            self.config.set_cpu_math_library_num_threads(num_threads)
        # Enable memory optimization
        self.config.enable_memory_optim()
        self.config.disable_glog_info()

        # Create predictor based on configuration
        self.predictor = paddle_infer.create_predictor(self.config)

        # Get input handles
        self.input_ids_handle = self.predictor.get_input_handle('input_ids')
        self.token_type_ids_handle = self.predictor.get_input_handle('token_type_ids')

        # Get output names
        self.output_names = self.predictor.get_output_names()

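        # Load the punctuation label list from vocab.txt; index 0 is reserved
        # for "no punctuation" and is skipped when punctuating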
        self._punc_list = []
        vocab_file = os.path.join(model_dir, 'vocab.txt')
        if not os.path.exists(vocab_file):
            raise Exception("Dictionary file does not exist. Please check if {} exists!".format(vocab_file))
        with open(vocab_file, 'r', encoding='utf-8') as f:
            for line in f:
                self._punc_list.append(line.strip())

        self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)

        # Warm-up the model
        self('近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书')
        logger.info('Punctuation model loaded successfully.')

    def _clean_text(self, text):
        text = text.lower()
        text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
        text = re.sub(f'[{"".join(self._punc_list[1:])}]', '', text)
        return text

    # Preprocess the input text
    def preprocess(self, text: str):
        clean_text = self._clean_text(text)
        if len(clean_text) == 0:
            return None
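        # Tokenize character by character (is_split_into_words=True); the ERNIE
        # tokenizer adds [CLS] and [SEP], which postprocess later strips off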
        tokenized_input = self.tokenizer(list(clean_text), return_length=True, is_split_into_words=True)
        input_ids = tokenized_input['input_ids']
        seg_ids = tokenized_input['token_type_ids']
        seq_len = tokenized_input['seq_len']
        return input_ids, seg_ids, seq_len

    def infer(self, input_ids: list, seg_ids: list):
        # Set input data
        self.input_ids_handle.reshape([1, len(input_ids)])
        self.token_type_ids_handle.reshape([1, len(seg_ids)])
        self.input_ids_handle.copy_from_cpu(np.array([input_ids]).astype('int64'))
        self.token_type_ids_handle.copy_from_cpu(np.array([seg_ids]).astype('int64'))

        # Run the predictor
        self.predictor.run()

        # Get output
        output_handle = self.predictor.get_output_handle(self.output_names[0])
        output_data = output_handle.copy_to_cpu()
        return output_data

    # Postprocess the inference result
    def postprocess(self, input_ids, seq_len, preds):
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[1:seq_len - 1])
        labels = preds[1:seq_len - 1].tolist()
        assert len(tokens) == len(labels)

        text = ''
        for t, l in zip(tokens, labels):
            text += t
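            # Label 0 means "no punctuation"; any other label indexes into the
            # punctuation list loaded from vocab.txt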
            if l != 0:
                text += self._punc_list[l]
        return text

    def __call__(self, text: str) -> str:
        # Run the full pipeline on a single piece of text
        try:
            input_ids, seg_ids, seq_len = self.preprocess(text)
            preds = self.infer(input_ids=input_ids, seg_ids=seg_ids)
            if len(preds.shape) == 2:
                preds = preds[0]
            text = self.postprocess(input_ids, seq_len, preds)
        except Exception as e:
            logger.error(e)
        return text
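
ERNIE 1.0 accepts at most 512 tokens per input, so very long transcripts should be split before prediction. Below is a minimal sketch of a hypothetical chunked_punctuate helper, assuming the PunctuationPredictor above; chunk boundaries may fall mid-phrase, so punctuation near the cuts can be slightly off:

def chunked_punctuate(pun_predictor, text, max_chars=256):
    # Hypothetical helper: split long text into fixed-size character chunks so
    # each piece stays within the model's input limit, punctuate each chunk,
    # then concatenate the results.
    chunks = [text[i:i + max_chars] for i in range(0, len(text), max_chars)]
    return ''.join(pun_predictor(chunk) for chunk in chunks)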

References

  1. https://github.com/yeyupiaoling/PPASR
  2. https://github.com/PaddlePaddle/PaddleSpeech