Preface¶
This article introduces how to quickly train and run inference with the MASR speech recognition framework, focusing on the simplest way to use it; for more advanced features, refer to the documentation in the source code repository. Training and inference each require only three lines of code.
Source Code Address: https://github.com/yeyupiaoling/MASR
Installation Environment¶
The examples below use an Anaconda virtual environment with Python 3.11.
- First, install the GPU version of PyTorch 2.5.1. If it has already been installed, skip this step.
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
- Install the MASR library using pip with the following command:
python -m pip install masr -U -i https://pypi.tuna.tsinghua.edu.cn/simple
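Before moving on, you can optionally confirm that the GPU build of PyTorch is usable. This quick check uses plain PyTorch and is not part of MASR:
import torch

# Optional sanity check: the second line should print True if CUDA is usable.
print(torch.__version__)
print(torch.cuda.is_available())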
Prepare Dataset¶
Execute the following code to automatically download the data and create the data lists. The default download can be slow, so you can copy the download URL, fetch it with a download manager such as Thunder, and then pass the downloaded file's path via the filepath argument to finish creating the data lists quickly (see the usage sketch after the script).
import argparse
import os
import functools
from utility import download, unpack
from utility import add_arguments, print_arguments

DATA_URL = 'https://openslr.trmal.net/resources/33/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg("target_dir", default="dataset/audio/", type=str, help="Directory to store audio files")
add_arg("annotation_text", default="dataset/annotation/", type=str, help="Directory to store audio annotation files")
add_arg("filepath", default=None, type=str, help="Pre-downloaded dataset compressed file")
args = parser.parse_args()


def create_annotation_text(data_dir, annotation_path):
    print('Create Aishell annotation text ...')
    if not os.path.exists(annotation_path):
        os.makedirs(annotation_path)
    f_train = open(os.path.join(annotation_path, 'aishell.txt'), 'w', encoding='utf-8')
    if not os.path.exists(os.path.join(annotation_path, 'test.txt')):
        f_test = open(os.path.join(annotation_path, 'test.txt'), 'w', encoding='utf-8')
    else:
        f_test = open(os.path.join(annotation_path, 'test.txt'), 'a', encoding='utf-8')
    transcript_path = os.path.join(data_dir, 'transcript', 'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    for line in open(transcript_path, 'r', encoding='utf-8'):
        line = line.strip()
        if line == '': continue
        audio_id, text = line.split(' ', 1)
        # Remove spaces
        text = ''.join(text.split())
        transcript_dict[audio_id] = text
    data_types = ['train', 'dev']
    for type in data_types:
        audio_dir = os.path.join(data_dir, 'wav', type)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname).replace('\\', '/')
                audio_id = fname[:-4]
                # Skip if no transcription for the audio
                if audio_id not in transcript_dict:
                    continue
                text = transcript_dict[audio_id]
                f_train.write(audio_path.replace('../', '') + '\t' + text + '\n')
    audio_dir = os.path.join(data_dir, 'wav', 'test')
    for subfolder, _, filelist in sorted(os.walk(audio_dir)):
        for fname in filelist:
            audio_path = os.path.join(subfolder, fname).replace('\\', '/')
            audio_id = fname[:-4]
            # Skip if no transcription for the audio
            if audio_id not in transcript_dict:
                continue
            text = transcript_dict[audio_id]
            f_test.write(audio_path.replace('../', '') + '\t' + text + '\n')
    f_test.close()
    f_train.close()


def prepare_dataset(url, md5sum, target_dir, annotation_path):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if not os.path.exists(data_dir):
        if args.filepath is None:
            filepath = download(url, md5sum, target_dir)
        else:
            filepath = args.filepath
        unpack(filepath, target_dir)
        # Unpack all audio tar files
        audio_dir = os.path.join(data_dir, 'wav')
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for ftar in filelist:
                unpack(os.path.join(subfolder, ftar), subfolder, True)
        os.remove(filepath)
    else:
        print("Skip downloading and unpacking. Aishell data already exists in %s." % target_dir)
    create_annotation_text(data_dir, annotation_path)


def main():
    print_arguments(args)
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)
    prepare_dataset(url=DATA_URL,
                    md5sum=MD5_DATA,
                    target_dir=args.target_dir,
                    annotation_path=args.annotation_text)


if __name__ == '__main__':
    main()
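Assuming the script above is saved as aishell.py (the filename is only an example) and that the add_arguments helper registers each option as a --flag, it can be run directly, or pointed at a manually downloaded archive:
python aishell.py
python aishell.py --filepath=./data_aishell.tgz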
Training¶
Training with the MASR framework is very simple; the core code takes only three lines, shown below. The configs parameter specifies which built-in configuration file to use.
from masr.trainer import MASRTrainer
trainer = MASRTrainer(configs="conformer", use_gpu=True)
trainer.train(save_model_path="models/")
The output will be similar to the following:
2025-03-08 11:04:57.884 | INFO | masr.optimizer:build_optimizer:16 - Successfully created optimizer: Adam, parameters: {'lr': 0.001, 'weight_decay': 1e-06}
2025-03-08 11:04:57.884 | INFO | masr.optimizer:build_lr_scheduler:31 - Successfully created learning rate scheduler: WarmupLR, parameters: {'warmup_steps': 25000, 'min_lr': 1e-05}
2025-03-08 11:04:57.885 | INFO | masr.trainer:train:541 - Vocabulary size: 5561
2025-03-08 11:04:57.885 | INFO | masr.trainer:train:542 - Training data: 13382
2025-03-08 11:04:57.885 | INFO | masr.trainer:train:543 - Evaluation data: 27
2025-03-08 11:04:58.642 | INFO | masr.trainer:__train_epoch:414 - Train epoch: [1/200], batch: [0/836], loss: 51.60880, learning_rate: 0.00000008, reader_cost: 0.1062, batch_cost: 0.6486, ips: 21.1991 speech/sec, eta: 1 day, 11:03:13
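If training is interrupted, it can likely be resumed from a saved checkpoint. Both the resume_model parameter name and the checkpoint path below are assumptions (inferred from the export example later in this article); verify them against the train() signature in the source code:
from masr.trainer import MASRTrainer

trainer = MASRTrainer(configs="conformer", use_gpu=True)
# Assumption: train() accepts a resume_model argument and checkpoints are
# saved under models/ConformerModel_fbank/last_model/ -- check the source.
trainer.train(save_model_path="models/",
              resume_model="models/ConformerModel_fbank/last_model/")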
Model Export¶
After training is complete, the model needs to be exported for inference. Exporting is also very simple and takes only three lines of code:
from masr.trainer import MASRTrainer
# Get the trainer
trainer = MASRTrainer(configs="conformer", use_gpu=True)
# Export the prediction model
trainer.export(save_model_path='models/',
               resume_model='models/ConformerModel_fbank/best_model/')
Inference¶
Inference is also quite simple and can be completed with just a few lines of code:
from masr.predict import MASRPredictor
predictor = MASRPredictor(model_dir="models/ConformerModel_fbank/inference_model/", use_gpu=True)
audio_path = "dataset/test.wav"
result = predictor.predict(audio_data=audio_path)
print(f"Recognition Result: {result}")
The output is as follows:
2025-03-08 11:21:52.100 | INFO | masr.infer_utils.inference_predictor:__init__:38 - Model loaded: models/ConformerModel_fbank/inference_model/inference.pth
2025-03-08 11:21:52.147 | INFO | masr.predict:__init__:117 - Stream VAD model loaded successfully
2025-03-08 11:21:52.147 | INFO | masr.predict:__init__:119 - Starting to warm up the predictor...
2025-03-08 11:22:01.366 | INFO | masr.predict:reset_predictor:471 - Resetting predictor
2025-03-08 11:22:01.366 | INFO | masr.predict:__init__:128 - Predictor ready!
Recognition Result: {'text': '近几年不但我用书给女儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书', 'sentences': [{'text': '近几年不但我用书给女儿压岁也劝说亲朋不要给女儿压岁钱而改送压岁书', 'start': 0, 'end': 8.39}]}
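The returned value is a plain dictionary, so the individual segments can be iterated directly. A minimal sketch based on the structure shown in the sample output above (the start and end values appear to be timestamps in seconds):
from masr.predict import MASRPredictor

predictor = MASRPredictor(model_dir="models/ConformerModel_fbank/inference_model/", use_gpu=True)
result = predictor.predict(audio_data="dataset/test.wav")

# Full transcription.
print(result['text'])
# Per-segment results with their timestamps.
for sentence in result['sentences']:
    print(f"[{sentence['start']} - {sentence['end']}] {sentence['text']}")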
Conclusion¶
This framework supports multiple speech recognition models, including deepspeech2, conformer, squeezeformer, and efficient_conformer. Each model supports both streaming and non-streaming recognition, as well as multiple decoders, including ctc_greedy_search, ctc_prefix_beam_search, attention_rescoring, and ctc_beam_search. More features await your exploration; for example, a different model can be selected simply by changing the configs argument, as sketched below.
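The "squeezeformer" configuration name below is an assumption based on the model list above; check the configs directory of the repository for the exact names available:
from masr.trainer import MASRTrainer

# Assumption: a "squeezeformer" configuration ships with the framework; the
# valid names correspond to the configuration files in the repository.
trainer = MASRTrainer(configs="squeezeformer", use_gpu=True)
trainer.train(save_model_path="models/")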