Source code for kani.ext.multimodal_core.audio

"""Core MessageParts for Kani multimodal"""

import base64
import hashlib
import io
import warnings
import wave
from typing import IO, TYPE_CHECKING

import numpy as np
from kani.utils import saveload
from kani.utils.typing import PathLike
from pydantic import Field, model_serializer, model_validator
from pydub import AudioSegment

from .base import BaseMultimodalPart
from .utils import download_media

if TYPE_CHECKING:
    import torch


class AudioPart(BaseMultimodalPart):
    """
    A part representing audio data.

    Audio data is stored in memory as raw signed 16-bit little-endian mono PCM in :attr:`raw`, at a variable
    :attr:`sample_rate`. When serialized, audio data is represented as a data URI.

    To get audio data in a suitable format for downstream applications, use :meth:`as_b64`, :meth:`as_bytes`,
    :meth:`as_ndarray`, or :meth:`as_tensor`.
    """

    raw: bytes = Field(repr=False)
    """The raw binary data in signed 16-bit little-endian mono PCM format."""

    sample_rate: int
    """The sample rate of the binary data."""

    # ==== constructors ====
    @classmethod
    def from_b64(cls, data: str, sr: int, **kwargs):
        """
        Create an AudioPart from Base64-encoded signed 16-bit little-endian mono PCM data.

        :param data: The Base64-encoded PCM payload.
        :param sr: The sample rate of the PCM data.

        Additional keyword arguments are passed to the constructor.
        """
        return cls(raw=base64.b64decode(data), sample_rate=sr, **kwargs)

    @classmethod
    def from_file(
        cls,
        fp: PathLike | IO,
        *,
        format: str | None = None,
        codec: str | None = None,
        converter_parameters: str | None = None,
        sr: int | None = None,
        sample_width: int | None = None,
        channels: int | None = None,
        **kwargs,
    ):
        """
        Create an AudioPart from a local file.

        :param fp: The path to the file or an open file to read.
        :param format: The format (e.g. 'mp3') of the audio file. Will attempt to automatically determine based on the
            given filename if this is not set.
        :param codec: An explicit audio codec to use to decode the audio file, if conversion is needed. (See FFMPEG's
            ``-acodec`` option for valid inputs).
        :param converter_parameters: Any additional CLI arguments to pass to the audio converter, if conversion is
            needed.
        :param sr: The sample rate of the audio (raw PCM audio only).
        :param sample_width: The sample width, in bytes, of the audio (raw PCM audio only).
        :param channels: The number of channels of the audio (raw PCM audio only).

        Additional keyword arguments are passed to the constructor.
        """
        segment = AudioSegment.from_file(
            fp,
            format=format,
            codec=codec,
            parameters=converter_parameters,
            frame_rate=sr,
            sample_width=sample_width,
            channels=channels,
        )
        # normalize to the in-memory representation: mono, 2 bytes (16 bits) per sample
        mono = segment.set_channels(1).set_sample_width(2)
        return cls(raw=mono.raw_data, sample_rate=mono.frame_rate, **kwargs)

    @classmethod
    def from_wav_b64_uri(cls, data: str, **kwargs):
        """
        Create an AudioPart from a ``data:audio/wav;base64,...`` URI (the inverse of :meth:`as_wav_b64_uri`).

        :param data: The data URI containing Base64-encoded WAV data.
        :raises ValueError: If the URI does not begin with ``data:audio/wav;base64,``.

        Additional keyword arguments are passed to :meth:`from_file`.
        """
        prefix = "data:audio/wav;base64,"
        if not data.startswith(prefix):
            raise ValueError("Data URI must begin with `data:audio/wav;base64,`")
        wav_bytes = base64.b64decode(data.removeprefix(prefix))
        return cls.from_file(io.BytesIO(wav_bytes), format="wav", **kwargs)

    @classmethod
    async def from_url(cls, url: str, **kwargs):
        """
        Download audio from the Internet and create an AudioPart.

        .. attention::
            Note that this classmethod is *asynchronous*, as it downloads data from the web!

        Keyword arguments are passed to :meth:`from_file`.
        """
        f = io.BytesIO()
        await download_media(url, f, allowed_mime=("audio/*",))
        # rewind defensively: if the download left the cursor at the end of the buffer, a decoder that reads
        # from the current position would otherwise see no data
        f.seek(0)
        return cls.from_file(f, **kwargs)

    # ==== representations ====
    # --- raw ---
    def _as_bytes(self, sr: int | None = None) -> bytes:
        """
        Return the PCM data at the given sample rate, warning if no rate is given.

        This is a private method so the warning's stacklevel is correct: every public representation method must
        call it directly (exactly one frame between the user's code and this method).
        """
        if sr is None:
            warnings.warn(
                "AudioPart.as_bytes() was called with no explicit sample rate given. Returning at the raw sample rate"
                f" of {self.sample_rate}! Pass `sr=...` or use `.resample()` to use a different rate.",
                stacklevel=3,
            )
            return self.raw
        if sr == self.sample_rate:
            return self.raw
        # sample to the specified sr and return
        segment = AudioSegment(self.raw, sample_width=2, frame_rate=self.sample_rate, channels=1)
        return segment.set_frame_rate(sr).raw_data

    def as_bytes(self, sr: int | None = None) -> bytes:
        """Return the audio data as signed 16-bit little-endian mono PCM at the given sample rate."""
        return self._as_bytes(sr)

    def as_b64(self, sr: int | None = None) -> str:
        """Return the audio data as Base64-encoded signed 16-bit little-endian mono PCM at the given sample rate."""
        return base64.b64encode(self._as_bytes(sr)).decode()

    def as_ndarray(self, sr: int | None = None) -> np.ndarray:
        """Return the audio data as a 1-dimensional NumPy array of floats at the given sample rate."""
        # equivalence verify
        # $ ffmpeg -i test.mp3 -ac 1 -ar 24000 test.wav
        # $ ffmpeg -i test.mp3 -f s16le -acodec pcm_s16le -ac 1 -ar 24000 test.pcm
        # import soundfile, numpy as np
        # from pathlib import Path
        # audio_path_wav = Path("test.wav")
        # audio_path_pcm = Path("test.pcm")
        # audio_wav, sr = soundfile.read(audio_path_wav)
        # audio_bytes = audio_path_pcm.read_bytes()
        # audio_ints = np.frombuffer(audio_bytes, dtype=np.int16)
        # audio_wav2 = audio_ints / 32768
        # (audio_wav == audio_wav2).all()
        audio_ints = np.frombuffer(self._as_bytes(sr), dtype=np.int16)
        # scale int16 samples by 2**15 into floats
        return audio_ints / 32768

    def as_tensor(self, sr: int | None = None) -> "torch.Tensor":
        """
        Return the audio data as a 2-dimensional [channel, time] PyTorch Tensor of floats at the given sample rate.

        Note that since this library only uses mono audio, that the first dimension will always be 1.

        :raises ImportError: If PyTorch is not installed.
        """
        # equivalence verify
        # $ ffmpeg -i test.mp3 -ac 1 -ar 24000 test.wav
        # $ ffmpeg -i test.mp3 -f s16le -acodec pcm_s16le -ac 1 -ar 24000 test.pcm
        # import torchaudio, torch
        # from pathlib import Path
        # audio_path_pcm = Path("test.pcm")
        # audio_wav, sr = torchaudio.load("test.wav")
        # audio_bytes = audio_path_pcm.read_bytes()
        # audio_ints = torch.frombuffer(audio_bytes, dtype=torch.int16)
        # audio_wav2 = audio_ints.div(32768).reshape(1, -1)
        # (audio_wav == audio_wav2).all()
        try:
            import torch
        except ImportError:
            raise ImportError(
                "PyTorch is not installed in your environment. Please install `torch` to use `.as_tensor`."
            ) from None
        pcm = self._as_bytes(sr)
        if not pcm:
            # a zero-length buffer is handled here rather than handed to torch.frombuffer; return an empty
            # clip with the same [1, time] shape and float dtype as the division below would produce
            return torch.empty((1, 0), dtype=torch.get_default_dtype())
        # copy into a writable buffer: torch.frombuffer warns when given immutable `bytes`
        audio_ints = torch.frombuffer(bytearray(pcm), dtype=torch.int16)
        return audio_ints.div(32768).reshape(1, -1)

    # --- WAV ---
    def as_wav_bytes(self) -> bytes:
        """Return the audio data as WAV data (including header) at the stored :attr:`sample_rate`."""
        out_bytes = io.BytesIO()
        with wave.open(out_bytes, "wb") as wave_data:
            wave_data.setnchannels(1)
            wave_data.setsampwidth(2)
            wave_data.setframerate(self.sample_rate)
            # 2 bytes per mono frame
            wave_data.setnframes(len(self.raw) // 2)
            wave_data.writeframesraw(self.raw)
        # getvalue() returns the whole buffer regardless of the cursor position
        return out_bytes.getvalue()

    def as_wav_b64_uri(self) -> str:
        """Return the WAV audio data encoded in a web-suitable base64 string."""
        wav_b64 = base64.b64encode(self.as_wav_bytes()).decode()
        return f"data:audio/wav;base64,{wav_b64}"

    # ==== helpers ====
    def resample(self, sample_rate: int) -> "AudioPart":
        """
        Return a new AudioPart with the given sample rate.

        If the audio is already at ``sample_rate``, this instance itself is returned (no copy is made).
        """
        if sample_rate == self.sample_rate:
            return self
        return AudioPart(raw=self._as_bytes(sr=sample_rate), sample_rate=sample_rate)

    def sha256(self) -> bytes:
        """Return the SHA-256 hash of the raw audio."""
        return hashlib.sha256(self.raw).digest()

    @property
    def duration(self) -> float:
        """The duration of this audio clip, in seconds."""
        # 16b mono -> 2 bytes per sample * sample rate
        return len(self.raw) / (self.sample_rate * 2)

    @property
    def sr(self):
        """An alias to :attr:`sample_rate`."""
        return self.sample_rate

    @sr.setter
    def sr(self, value):
        self.sample_rate = value

    def __repr__(self):
        # show the clip duration instead of dumping the raw PCM bytes
        audio_repr = f"[audio: {self.duration:.3f}s]"
        return f'{self.__repr_name__()}({self.__repr_str__(", ")}, raw={audio_repr})'

    def __rich_repr__(self):
        audio_repr = f"[audio: {self.duration:.3f}s]"
        yield "raw", audio_repr

    # ==== serdes ====
    @model_serializer()
    def _serialize_audiopart(self, info) -> dict[str, str]:
        """
        When we serialize to JSON, save the data as:

        - a URI when not in zipfile mode
        - a WAV file when in zipfile mode
        """
        if ctx := saveload.get_ctx(info):
            fp = ctx.save_bytes(self.as_wav_bytes(), suffix=".wav")
            return {"_archive_path": fp, **self._get_typekey_dict()}
        else:
            return {"wav_data": self.as_wav_b64_uri(), **self._get_typekey_dict()}

    # noinspection PyNestedDecorators
    @model_validator(mode="wrap")
    @classmethod
    def _validate_audiopart(cls, v, nxt, info):
        """
        If the value is one of the serialized forms we saved, load it that way; otherwise defer to the default
        validation.

        :raises ValueError: If the value references an archive path but no save/load context is available.
        """
        # only dicts can be one of our serialized forms; anything else is handed to the default handler
        # instead of failing an assertion (asserts are stripped under ``python -O``)
        if isinstance(v, dict):
            if "_archive_path" in v:
                ctx = saveload.get_ctx(info)
                if ctx is None:
                    raise ValueError(
                        "Cannot load an AudioPart from an archive path without an active save/load context."
                    )
                wav_data = ctx.load_bytes(v["_archive_path"])
                return cls.from_file(io.BytesIO(wav_data), format="wav")
            if "wav_data" in v:
                return cls.from_wav_b64_uri(v["wav_data"])
        return nxt(v)