Google STT + Langchain

Willow·2024년 4월 29일
0

SPEECH PROCESSING

목록 보기
12/13

구글 해커톤을 하면서 오랜만에 써본 구글 STT API. 구글 공식 사이트에 적힌 내용을 그대로 따라갔다.

  1. 로컬 파일 (비동기)
import os

from google.cloud import speech

#NOTE: credential path 는 ~/.bashrc에 추가해준다

def transcribe_audio(file_path):
    # 클라이언트 초기화
    client = speech.SpeechClient()

    with open(file_path, "rb") as audio_file:
        content = audio_file.read()
        audio = speech.RecognitionAudio(content=content)

	# 오디오 config 설정은 파일에 맞게 지정해준다. 여기선 16Khz, 한국어, PCM, mono
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="ko-KR",
        audio_channel_count=1,
    )

    # STT 시작
    response = client.recognize(config=config, audio=audio)

    for result in response.results:
        print("결과: {}".format(result.alternatives[0].transcript))


audio_file_path = "./path_to_file.wav"
transcribe_audio(audio_file_path)

+ 로컬로 할 때 파일 길이가 60s 를 넘어가면 에러가 난다. 긴 오디오 파일은 인식하고 싶다면 Google Cloud 내 버킷에 오디오를 저장해서 사용해야 하며, 긴 오디오에 대응되는 다른 함수를 써야 한다.

import os

from google.cloud import speech

def transcribe_audio(file_path):
    client = speech.SpeechClient()
    audio = speech.RecognitionAudio(uri=file_path)
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="ko-KR",
        audio_channel_count=1,
    )

	# 긴 오디오 파일에 대응되는 함수를 사용
    operation = client.long_running_recognize(config=config, audio=audio)
    response = operation.result(timeout=90)

    # 오디오 전체 결과를 얻으려면 result를 다 봐야한다. 
    for result in response.results:
        print("결과: {}".format(result.alternatives[0].transcript))

# Google Cloud Bucket에 파일을 저장해주자!
gcs_uri = "gs://project_name/bucket_directory/path_to_file.wav"
transcribe_audio(gcs_uri)
  1. 마이크 (스트리밍)
import os
import re
import sys

import pyaudio
from google.cloud import speech
from six.moves import queue

# 오디오 설정들
RATE = 16000
CHUNK = int(RATE / 10)  # 100ms

# 마이크로 받는 음성
class MicrophoneStream(object):
    def __init__(self, rate, chunk):
        self._rate = rate
        self._chunk = chunk
        self._buff = queue.Queue()
        self.closed = True

    def __enter__(self):
        self._audio_interface = pyaudio.PyAudio()
        self._audio_stream = self._audio_interface.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=self._rate,
            input=True,
            frames_per_buffer=self._chunk,
            stream_callback=self._fill_buffer,
        )

        self.closed = False

        return self

    def __exit__(self, type, value, traceback):
        self._audio_stream.stop_stream()
        self._audio_stream.close()
        self.closed = True
        self._buff.put(None)
        self._audio_interface.terminate()

    def _fill_buffer(self, in_data, frame_count, time_info, status_flags):
        self._buff.put(in_data)
        return None, pyaudio.paContinue

    def generator(self):
        while not self.closed:
            # Use a blocking get() to ensure there's at least one chunk of
            # data, and stop iteration if the chunk is None, indicating the
            # end of the audio stream.
            chunk = self._buff.get()
            if chunk is None:
                return
            data = [chunk]

            # Now consume whatever other data's still buffered.
            while True:
                try:
                    chunk = self._buff.get(block=False)
                    if chunk is None:
                        return
                    data.append(chunk)
                except queue.Empty:
                    break

            yield b"".join(data)


def listen_print_loop(responses):
    num_chars_printed = 0
    for response in responses:
        if not response.results:
            continue

        
        # 첫번째 result만 필요
        result = response.results[0]
        if not result.alternatives:
            continue

        # 결과 출력
        transcript = result.alternatives[0].transcript

        # 실시간으로 결과 변경
        overwrite_chars = " " * (num_chars_printed - len(transcript))

        if not result.is_final:
            sys.stdout.write(transcript + overwrite_chars + "\r")
            sys.stdout.flush()

            num_chars_printed = len(transcript)

        else:
            print(transcript + overwrite_chars)
            break


# 스트리밍 STT
def main():
    client = speech.SpeechClient()
    config = speech.RecognitionConfig(
        encoding=speech.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=RATE,
        language_code="ko-KR",
    )
    streaming_config = speech.StreamingRecognitionConfig(
        config=config, interim_results=True
    )

    with MicrophoneStream(RATE, CHUNK) as stream:
        audio_generator = stream.generator()
        requests = (
            speech.StreamingRecognizeRequest(audio_content=content)
            for content in audio_generator
        )

        responses = client.streaming_recognize(streaming_config, requests)

        # 출력
        listen_print_loop(responses)
  1. Langchain으로 래핑해서 쓸 때

다른 API나 라이브러리들도 그렇지만, 세세하게 커스텀할 땐 결국 langchain 보단 직접 호출하는 게 더 편한 것 같다. 지금까지 써본 바로는, 이것저것 다 간단하게 구현하거나 각 API들 성능을 테스트해볼 때 langchain 이 유용해 보인다.

from google.cloud import storage
from google.cloud.speech_v2 import (
    AutoDetectDecodingConfig,
    RecognitionConfig,
    RecognitionFeatures,
    SpeechClient,
)
from google.cloud.speech_v2.types import cloud_speech
from langchain_community.document_loaders import GoogleSpeechToTextLoader

config = RecognitionConfig(
    auto_decoding_config=AutoDetectDecodingConfig(),
    language_codes=["ko-KR"],
    model="long",  # chirp
    features=RecognitionFeatures(
        enable_automatic_punctuation=True,
        enable_spoken_punctuation=True,
    ),
)


project_id = "project_id_num"
file_path = "gs://project_name/bucket-directory/path_to_file.wav"
location = "global"

loader = GoogleSpeechToTextLoader(
    project_id=project_id,
    file_path=file_path,
    config=config,
    location=location,
)
docs = loader.load()
print(docs[0].page_content)
profile
Speech Processing/AI/Linguistics/CS/etc.

0개의 댓글