Preface
In speech recognition, the model outputs only a plain text result without adding punctuation marks according to grammar. This tutorial addresses this issue by adding punctuation marks to speech recognition text based on grammatical context, enabling the speech recognition system to output the final result with punctuation.
Note: This tutorial only introduces usage. For training your own model, please refer to the article “Training a Chinese Punctuation Model Based on PaddlePaddle” (https://blog.csdn.net/qq_33200967/article/details/126858763).
Usage
The usage mainly consists of 4 steps:
- First, download the model and extract it to the models/ directory. The download link is as follows:
https://download.csdn.net/download/qq_33200967/75664996
- The PaddleNLP tool is required, so install it in advance using the following command:
python -m pip install paddlenlp -i https://mirrors.aliyun.com/pypi/simple/
- PPASR has provided a tool for automatic punctuation addition since version 0.1.3. You can install ppasr to use this tool:
python -m pip install ppasr==2.4.8 -i https://mirrors.aliyun.com/pypi/simple/
- Adding punctuation to text automatically is very simple. Note that the import path and name have changed after version V2.
from ppasr.infer_utils.pun_predictor import PunctuationPredictor
pun_predictor = PunctuationPredictor(model_dir='models/pun_models')
result = pun_predictor('近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书')
print(result)
Output:
[2022-01-13 15:27:11,194] [ INFO] - Found C:\Users\test\.paddlenlp\models\ernie-1.0\vocab.txt
近几年,不但我用书给女儿儿压岁,也劝说亲朋,不要给女儿压岁钱,而改送压岁书。
Source Code Address. The complete source code of this tool is as follows:
import json
import os
import re
import numpy as np
import paddle.inference as paddle_infer
from paddlenlp.transformers import ErnieTokenizer
from ppasr.utils.logger import setup_logger
# Module-level logger, created with the helper imported from ppasr.utils.logger.
logger = setup_logger(__name__)
# Names exported by `from <module> import *`: only the predictor class.
__all__ = ['PunctuationPredictor']
class PunctuationPredictor:
    """Add punctuation marks to unpunctuated text with a Paddle Inference model.

    The predictor loads ``model.pdmodel`` / ``model.pdiparams`` and the
    punctuation vocabulary ``vocab.txt`` from ``model_dir``, tokenizes the
    cleaned text character by character with an ``ErnieTokenizer`` and appends
    the predicted punctuation mark (if any) after each token.

    Calling the instance with a string returns the punctuated string; when
    anything goes wrong the error is logged and the input is returned as-is.
    """

    # Sentence pushed through the model once at construction time as a warm-up.
    _WARMUP_TEXT = '近几年不但我用书给女儿儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书'

    def __init__(self, model_dir, use_gpu=True, gpu_mem=500, num_threads=4):
        """Load the model, the punctuation vocabulary and the tokenizer.

        Args:
            model_dir: Directory containing ``model.pdmodel``,
                ``model.pdiparams`` and ``vocab.txt``; an optional
                ``info.json`` may name the pretrained tokenizer through its
                ``pretrained_token`` key.
            use_gpu: Run inference on GPU device 0 when true, on CPU otherwise.
            gpu_mem: Value passed to ``Config.enable_use_gpu`` as the initial
                GPU memory pool size.
            num_threads: CPU math library thread count (used when
                ``use_gpu`` is false).

        Raises:
            Exception: If the model files or the vocabulary file are missing.
        """
        # Create configuration
        model_path = os.path.join(model_dir, 'model.pdmodel')
        params_path = os.path.join(model_dir, 'model.pdiparams')
        if not os.path.exists(model_path) or not os.path.exists(params_path):
            raise Exception("Punctuation model files do not exist. Please check if {} and {} exist!".format(model_path, params_path))
        self.config = paddle_infer.Config(model_path, params_path)

        # Determine the pre-trained model type; 'ernie-1.0' is the fallback
        # when the model directory ships no info.json.
        pretrained_token = 'ernie-1.0'
        info_path = os.path.join(model_dir, 'info.json')
        if os.path.exists(info_path):
            with open(info_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            pretrained_token = data['pretrained_token']

        if use_gpu:
            self.config.enable_use_gpu(gpu_mem, 0)
        else:
            self.config.disable_gpu()
            # NOTE(review): the listing's indentation was lost; the thread
            # count is assumed to belong to the CPU branch -- confirm.
            self.config.set_cpu_math_library_num_threads(num_threads)
        # Enable memory optimization
        self.config.enable_memory_optim()
        self.config.disable_glog_info()

        # Create predictor based on configuration
        self.predictor = paddle_infer.create_predictor(self.config)
        # Get input handles
        self.input_ids_handle = self.predictor.get_input_handle('input_ids')
        self.token_type_ids_handle = self.predictor.get_input_handle('token_type_ids')
        # Get output names
        self.output_names = self.predictor.get_output_names()

        # Punctuation vocabulary: one entry per line, indexed by predicted label.
        vocab_file = os.path.join(model_dir, 'vocab.txt')
        if not os.path.exists(vocab_file):
            raise Exception("Dictionary file does not exist. Please check if {} exists!".format(vocab_file))
        with open(vocab_file, 'r', encoding='utf-8') as f:
            self._punc_list = [line.strip() for line in f]

        self.tokenizer = ErnieTokenizer.from_pretrained(pretrained_token)
        # Warm-up the model
        self(self._WARMUP_TEXT)
        logger.info('Punctuation model loaded successfully.')

    def _clean_text(self, text):
        """Lower-case ``text`` and keep only ASCII letters, digits and CJK characters.

        Characters that appear in the punctuation vocabulary (every entry
        except the first) are removed as well.
        """
        text = text.lower()
        text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
        # Escape the vocabulary characters: an unescaped '-', '^', ']' or '\'
        # would change the meaning of the character class (e.g. ',-。' is a
        # range that swallows all ASCII letters and digits), and an empty
        # class '[]' is not a valid pattern at all.
        punc_chars = ''.join(self._punc_list[1:])
        if punc_chars:
            text = re.sub(f'[{re.escape(punc_chars)}]', '', text)
        return text

    # Preprocess the input text
    def preprocess(self, text: str):
        """Clean and tokenize ``text``.

        Returns:
            ``(input_ids, token_type_ids, seq_len)`` as produced by the
            tokenizer, or ``None`` when nothing is left after cleaning.
        """
        clean_text = self._clean_text(text)
        if not clean_text:
            return None
        # Every character is fed to the tokenizer as its own pre-split word.
        tokenized_input = self.tokenizer(list(clean_text), return_length=True, is_split_into_words=True)
        input_ids = tokenized_input['input_ids']
        seg_ids = tokenized_input['token_type_ids']
        seq_len = tokenized_input['seq_len']
        return input_ids, seg_ids, seq_len

    def infer(self, input_ids: list, seg_ids: list):
        """Run the predictor on one sequence and return its first output.

        Args:
            input_ids: Token ids of a single sequence.
            seg_ids: Token type ids of the same sequence.

        Returns:
            The data of the first model output, copied to the CPU.
        """
        # Set input data (a batch of exactly one sequence)
        self.input_ids_handle.reshape([1, len(input_ids)])
        self.token_type_ids_handle.reshape([1, len(seg_ids)])
        self.input_ids_handle.copy_from_cpu(np.array([input_ids]).astype('int64'))
        self.token_type_ids_handle.copy_from_cpu(np.array([seg_ids]).astype('int64'))
        # Run the predictor
        self.predictor.run()
        # Get output
        output_handle = self.predictor.get_output_handle(self.output_names[0])
        return output_handle.copy_to_cpu()

    # Postprocess the inference result
    def postprocess(self, input_ids, seq_len, preds):
        """Interleave tokens with their predicted punctuation marks.

        The first and last positions of the sequence are dropped (presumably
        the tokenizer's special start/end tokens -- confirm). A label of ``0``
        means "no punctuation"; any other label indexes the vocabulary.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[1:seq_len - 1])
        labels = preds[1:seq_len - 1].tolist()
        assert len(tokens) == len(labels)
        pieces = []
        for token, label in zip(tokens, labels):
            pieces.append(token)
            if label != 0:
                pieces.append(self._punc_list[label])
        return ''.join(pieces)

    def __call__(self, text: str) -> str:
        """Return ``text`` with punctuation added.

        Text that is empty after cleaning is returned unchanged. Any error
        raised while predicting is logged and the original ``text`` is
        returned instead of propagating the exception.
        """
        try:
            features = self.preprocess(text)
            if features is None:
                # Nothing to punctuate; not an error condition.
                return text
            input_ids, seg_ids, seq_len = features
            preds = self.infer(input_ids=input_ids, seg_ids=seg_ids)
            # Drop the batch dimension when the model returns one.
            if len(preds.shape) == 2:
                preds = preds[0]
            return self.postprocess(input_ids, seq_len, preds)
        except Exception as e:
            logger.error(e)
            return text