9 import concurrent.futures
11 from vosk import Model, KaldiRecognizer
13 from pprint import pprint
# Enable logging if needed
# logger = logging.getLogger('websockets')
# logger.setLevel(logging.INFO)
# logger.addHandler(logging.StreamHandler())
# Server settings, read from the environment with built-in defaults.
vosk_interface = os.environ.get('VOSK_SERVER_INTERFACE', '0.0.0.0')
vosk_port = int(os.environ.get('VOSK_SERVER_PORT', 2700))
vosk_model_path = os.environ.get('VOSK_MODEL_PATH', 'model')
vosk_sample_rate = float(os.environ.get('VOSK_SAMPLE_RATE', 8000))

# An optional first command-line argument overrides VOSK_MODEL_PATH.
# Guarded so that starting the server without arguments does not raise
# IndexError and the environment/default value stays in effect.
if len(sys.argv) > 1:
    vosk_model_path = sys.argv[1]

# Gpu part, uncomment if vosk-api has gpu support
#
# from vosk import GpuInit, GpuInstantiate
# pool = concurrent.futures.ThreadPoolExecutor(initializer=thread_init)
#
# NOTE(review): thread_init is not defined in the visible code; it has to be
# provided before the GPU variant of the pool can be enabled.

# The model is loaded once at import time and shared by every connection.
model = Model(vosk_model_path)
# Worker threads that run the recognizer calls off the event loop.
pool = concurrent.futures.ThreadPoolExecutor((os.cpu_count() or 1))
loop = asyncio.get_event_loop()
def _is_eof(message):
    """Return True when *message* is a text end-of-stream marker."""
    if not isinstance(message, str):
        # Audio chunks are not text; never try to parse them.
        return False
    if message == '{"eof" : 1}':
        # Fast path for the exact marker spelling compared against originally.
        return True
    try:
        payload = json.loads(message)
    except ValueError:
        return False
    return isinstance(payload, dict) and bool(payload.get('eof'))


def process_chunk(rec, message):
    """Feed one incoming message to the recognizer and pick the reply.

    Args:
        rec: Recognizer object providing ``AcceptWaveform``, ``Result``,
            ``PartialResult`` and ``FinalResult``.
        message: Either a chunk of audio data or a text control message.
            A JSON object with a truthy ``"eof"`` member (for example
            ``{"eof" : 1}``, regardless of whitespace) ends the stream.

    Returns:
        A ``(response, stop)`` tuple. ``response`` is whatever the chosen
        recognizer method returned; ``stop`` is ``True`` only for the
        end-of-stream message.
    """
    if _is_eof(message):
        return rec.FinalResult(), True
    if rec.AcceptWaveform(message):
        return rec.Result(), False
    return rec.PartialResult(), False
async def recognize(websocket, path):
    """Handle one websocket connection.

    Text messages containing ``config`` update the per-connection settings
    (``phrase_list`` and ``sample_rate``). Every other message is handed to
    ``process_chunk`` on the worker pool and the reply is sent back to the
    client. The loop ends once ``process_chunk`` reports the end of stream.

    Args:
        websocket: Connection object providing awaitable ``recv`` and
            ``send`` methods.
        path: Request path supplied by the websocket server; unused.
    """
    rec = None
    phrase_list = None
    sample_rate = vosk_sample_rate

    while True:
        message = await websocket.recv()

        # Load configuration if provided
        if isinstance(message, str) and 'config' in message:
            jobj = json.loads(message)['config']
            if 'phrase_list' in jobj:
                phrase_list = jobj['phrase_list']
            if 'sample_rate' in jobj:
                sample_rate = float(jobj['sample_rate'])
            # A configuration message carries no audio; wait for the next one.
            continue

        # Create the recognizer on the first non-config message so that any
        # configuration received beforehand is taken into account. The
        # phrase list is passed only when the client supplied one, since not
        # every model supports it.
        if rec is None:
            if phrase_list:
                rec = KaldiRecognizer(model, sample_rate, json.dumps(phrase_list, ensure_ascii=False))
            else:
                rec = KaldiRecognizer(model, sample_rate)

        # Recognition runs in the thread pool to keep the event loop free.
        response, stop = await loop.run_in_executor(pool, process_chunk, rec, message)
        await websocket.send(response)
        if stop:
            break
# Build the websocket server for the configured interface and port, with
# `recognize` as the per-connection handler, then let the event loop run the
# start-up to completion.
start_server = websockets.serve(recognize, host=vosk_interface, port=vosk_port)
loop.run_until_complete(start_server)