Minor API fixes for OH3
[voicecontrol.git] / vosk-server
#!/usr/bin/env python3
"""WebSocket front-end for the Vosk speech recognizer.

Settings are read from environment variables when the module is loaded:

* ``VOSK_SERVER_INTERFACE`` -- address to listen on (default ``0.0.0.0``)
* ``VOSK_SERVER_PORT`` -- TCP port to listen on (default ``2700``)
* ``VOSK_MODEL_PATH`` -- path handed to ``vosk.Model`` (default ``model``);
  the first command-line argument, when given, takes precedence
* ``VOSK_SAMPLE_RATE`` -- recognizer sample rate used when a client does not
  configure one (default ``8000``)
"""

import json
import os
import sys
import asyncio
import pathlib
import websockets
import concurrent.futures
import logging
from vosk import Model, KaldiRecognizer

from pprint import pprint

# Enable logging if needed
#
# logger = logging.getLogger('websockets')
# logger.setLevel(logging.INFO)
# logger.addHandler(logging.StreamHandler())

vosk_interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
vosk_port = int(os.environ.get('VOSK_SERVER_PORT', 2700))
vosk_model_path = os.environ.get('VOSK_MODEL_PATH', 'model')
vosk_sample_rate = float(os.environ.get('VOSK_SAMPLE_RATE', 8000))

# A positional command-line argument overrides VOSK_MODEL_PATH.
if len(sys.argv) > 1:
    vosk_model_path = sys.argv[1]

# GPU support: uncomment if vosk-api was built with GPU support.
#
# from vosk import GpuInit, GpuInstantiate
# GpuInit()
# def thread_init():
#     GpuInstantiate()
# pool = concurrent.futures.ThreadPoolExecutor(initializer=thread_init)

# The model is loaded once and shared by every connection.
model = Model(vosk_model_path)

# Worker threads for the recognizer calls, one per CPU
# (os.cpu_count() may return None, hence the fallback to 1).
pool = concurrent.futures.ThreadPoolExecutor(os.cpu_count() or 1)

# Create and install the event loop explicitly.  Calling
# asyncio.get_event_loop() while no loop is running is deprecated: it warns
# on Python 3.12+ and raises RuntimeError on 3.14+.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

def _is_eof(message):
    """Return True if *message* is the client's end-of-stream marker.

    The canonical marker is the text message ``{"eof" : 1}``.  Any other text
    message that decodes to a JSON object with ``"eof": 1`` is accepted too.
    This covers clients that serialize with a standard JSON library, which
    emits ``{"eof": 1}`` without the space before the colon.  Binary messages
    are never treated as the marker.
    """
    if not isinstance(message, str):
        return False
    if message == '{"eof" : 1}':  # cheap exact match for the canonical form
        return True
    try:
        decoded = json.loads(message)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return False
    return isinstance(decoded, dict) and decoded.get('eof') == 1


def process_chunk(rec, message):
    """Feed one client message to *rec* and return ``(reply, stop)``.

    ``recognize`` runs this in the worker thread pool via ``run_in_executor``.

    Args:
        rec: a ``vosk.KaldiRecognizer``, or anything providing
            ``AcceptWaveform``, ``Result``, ``PartialResult`` and
            ``FinalResult``.
        message: the raw WebSocket message.  Presumably this is a chunk of
            audio bytes or the text end-of-stream marker.  Other text
            messages are passed to ``AcceptWaveform`` unchanged.

    Returns:
        ``(reply, stop)``.  For the end-of-stream marker this is
        ``(rec.FinalResult(), True)``.  Otherwise the message is fed to
        ``rec.AcceptWaveform``, and *reply* is ``rec.Result()`` when that
        call returns a true value or ``rec.PartialResult()`` when it does
        not, with ``stop`` False.
    """
    if _is_eof(message):
        return rec.FinalResult(), True
    if rec.AcceptWaveform(message):
        return rec.Result(), False
    return rec.PartialResult(), False

async def recognize(websocket, path=None):
    """Handle one client connection: audio in, recognition results out.

    Messages are processed in order of arrival:

    * A text message containing ``config`` is decoded as JSON.  Its
      ``config`` object may set ``phrase_list`` and/or ``sample_rate``, and
      no reply is sent.  These settings only take effect if they arrive
      before the first other message, because the recognizer is created
      once, on that first message.
    * Every other message is handed to ``process_chunk`` on the worker
      thread pool, and the reply string is sent back.  The loop ends after
      ``process_chunk`` signals end-of-stream.

    If the client disconnects, the handler returns quietly.

    Args:
        websocket: connection object supplied by ``websockets.serve``.
        path: request path.  Only older ``websockets`` releases pass it, and
            it is unused here.  It defaults to ``None`` because newer
            releases call the handler with the connection alone.
    """
    # Use the loop actually running this coroutine rather than a module global.
    event_loop = asyncio.get_running_loop()
    rec = None
    phrase_list = None
    sample_rate = vosk_sample_rate

    try:
        while True:
            message = await websocket.recv()

            # Load configuration if provided
            if isinstance(message, str) and 'config' in message:
                jobj = json.loads(message)['config']
                if 'phrase_list' in jobj:
                    phrase_list = jobj['phrase_list']
                if 'sample_rate' in jobj:
                    sample_rate = float(jobj['sample_rate'])
                continue

            # Create the recognizer lazily, so a preceding config message can
            # still choose the sample rate and phrase list.  A non-empty
            # phrase list is passed to KaldiRecognizer as a JSON string.
            if rec is None:
                if phrase_list:
                    rec = KaldiRecognizer(
                        model, sample_rate,
                        json.dumps(phrase_list, ensure_ascii=False))
                else:
                    rec = KaldiRecognizer(model, sample_rate)

            # Run the recognizer in the worker pool so the event loop stays
            # free to serve other connections meanwhile.
            response, stop = await event_loop.run_in_executor(
                pool, process_chunk, rec, message)
            await websocket.send(response)
            if stop:
                break
    except websockets.ConnectionClosed:
        # The client went away, possibly before sending end-of-stream.
        # Nothing needs flushing, so end the handler normally instead of
        # propagating the exception to websockets as a handler failure.
        return

async def _serve_forever():
    """Listen on ``vosk_interface``:``vosk_port`` until the process is stopped."""
    # serve() is entered while the loop is running.  The asyncio server that
    # became websockets' default in version 14 looks up the running loop when
    # serve() is called, and the legacy server supports this usage as well.
    async with websockets.serve(recognize, vosk_interface, vosk_port):
        # This future is never resolved, which keeps the server open
        # indefinitely.
        await asyncio.get_running_loop().create_future()


loop.run_until_complete(_serve_forever())