Vita/examples/multiprocessing_presets/main.py
2025-02-13 14:21:57 -05:00

183 lines
6.1 KiB
Python

# This file is part of the Vita distribution (https://github.com/DBraun/Vita).
# Copyright (c) 2025 David Braun.
import argparse
from collections import namedtuple
import logging
import multiprocessing
import os
from pathlib import Path
import time
import traceback
# extra libraries to install with pip
import vita
import numpy as np
from scipy.io import wavfile
from tqdm import tqdm
# Unit of work placed on the shared queue: the path of a single preset file.
Item = namedtuple("Item", ["preset_path"])
class Worker:
    """Render presets pulled from a shared queue to WAV files.

    ``main`` creates one ``Worker`` per pool process and schedules its ``run``
    method; ``run`` keeps taking items off the queue until it is empty.

    Args:
        queue: Queue of items exposing a ``preset_path`` attribute (``main``
            passes a ``multiprocessing.Manager().Queue()``).
        bpm: Tempo set on the synth in ``startup``.
        note_duration: Seconds the note is held; passed to ``synth.render``.
        render_duration: Length in seconds of each rendered clip; passed to
            ``synth.render`` as its render duration.
        pitch_low: Lowest MIDI pitch rendered (inclusive).
        pitch_high: Highest MIDI pitch rendered (inclusive).
        velocity: Note velocity passed to ``synth.render``.
        output_dir: Directory the WAV files are written into. It is not
            created here (``main`` creates it).
    """

    # Sample rate written into every WAV header.
    # NOTE(review): assumes vita.Synth renders at 44.1 kHz — confirm.
    SAMPLE_RATE = 44_100

    def __init__(
        self,
        queue: multiprocessing.Queue,
        bpm: float = 120.0,
        note_duration: float = 2.0,
        render_duration: float = 5.0,
        pitch_low: int = 60,
        pitch_high: int = 72,
        velocity: int = 100,
        output_dir="output",
    ):
        self.queue = queue
        self.bpm = bpm
        self.note_duration = note_duration
        self.render_duration = render_duration
        self.pitch_low, self.pitch_high = pitch_low, pitch_high
        self.velocity = velocity
        self.output_dir = Path(output_dir)

    def startup(self):
        """Create this worker's synth and set its tempo.

        Called at the start of ``run``, so the synth is constructed inside the
        pool process that executes ``run``.
        """
        synth = vita.Synth()
        synth.set_bpm(self.bpm)
        self.synth = synth

    def process_item(self, item: "Item"):
        """Render one preset at every pitch from ``pitch_low`` to ``pitch_high``.

        Each clip is written to ``<output_dir>/<pitch>_<preset file name>.wav``;
        the preset's own extension is kept (e.g. ``60_Bass.vital.wav``).
        """
        preset_path = item.preset_path
        self.synth.load_preset(preset_path)
        basename = os.path.basename(preset_path)
        for pitch in range(self.pitch_low, self.pitch_high + 1):
            # Arguments: pitch, velocity, note (hold) duration, render duration.
            # The fourth argument is the clip length, so it takes
            # render_duration, not note_duration.
            audio = self.synth.render(
                pitch, self.velocity, self.note_duration, self.render_duration
            )
            output_path = self.output_dir / f"{pitch}_{basename}.wav"
            # scipy's wavfile.write wants (n_samples, n_channels); the transpose
            # assumes render returns (n_channels, n_samples) — TODO confirm.
            wavfile.write(str(output_path), self.SAMPLE_RATE, audio.transpose())

    def run(self):
        """Process queued items until the queue is empty.

        Returns:
            ``None`` if startup and every item succeeded. Otherwise, a string
            holding the formatted traceback(s). A failure in ``startup`` (or
            while talking to the queue) stops the worker. A failure on a
            single preset is recorded, and the worker moves on to the next
            item instead of abandoning its share of the queue.
        """
        # Local import: the ``queue`` name is used as a constructor parameter
        # in this class, so bring in only the exception we need.
        # Manager queue proxies re-raise the server-side queue.Empty.
        from queue import Empty

        errors = []
        try:
            self.startup()
            while True:
                try:
                    item = self.queue.get_nowait()
                except Empty:
                    break
                try:
                    self.process_item(item)
                except Exception:
                    errors.append(
                        f"Failed to render {item!r}:\n{traceback.format_exc()}"
                    )
        except Exception:
            errors.append(traceback.format_exc())
        return "\n".join(errors) or None
def _wait_with_progress(async_results, input_queue, num_items, poll_interval=0.1):
    """Block until every async result is ready, drawing a tqdm progress bar.

    Progress is measured as the number of items taken off ``input_queue``, so
    the bar can run slightly ahead of the renders that have actually finished.
    """
    with tqdm(total=num_items) as pbar:
        while True:
            # Sample readiness before updating, so the last update happens
            # after every worker has stopped pulling from the queue.
            all_done = all(r.ready() for r in async_results)
            pbar.update(num_items - input_queue.qsize() - pbar.n)
            if all_done:
                break
            time.sleep(poll_interval)


def main(
    preset_dir,
    bpm: float = 120.0,
    note_duration: float = 2.0,
    render_duration: float = 4.0,
    pitch_low: int = 60,
    pitch_high: int = 60,
    num_workers=None,
    output_dir="output",
    logging_level="INFO",
    velocity: int = 100,
):
    """Render every ``.vital`` preset under ``preset_dir`` to WAV files in parallel.

    Args:
        preset_dir: Directory searched recursively for ``*.vital`` presets.
        bpm: Tempo given to each worker's synth.
        note_duration: Seconds each note is held.
        render_duration: Length in seconds of each rendered clip.
        pitch_low: Lowest MIDI pitch rendered per preset (inclusive).
        pitch_high: Highest MIDI pitch rendered per preset (inclusive).
        num_workers: Number of pool processes. Defaults to the CPU count and
            is capped at the number of presets found.
        output_dir: Directory the WAV files are written to (created if needed).
        logging_level: Level name for the ``"vita"`` logger, e.g. ``"INFO"``.
        velocity: Note velocity used for every render.
    """
    logging.basicConfig()
    logger = logging.getLogger("vita")
    logger.setLevel(logging_level.upper())

    # Sorted so the work order is reproducible from run to run.
    preset_paths = sorted(Path(preset_dir).rglob("*.vital"))
    num_items = len(preset_paths)
    if num_items == 0:
        logger.warning(f"No .vital presets found under {preset_dir}")
        return

    # Every process builds its own synth, so never start more than there are presets.
    num_processes = min(num_workers or multiprocessing.cpu_count(), num_items)

    logger.info(f"Note duration: {note_duration}")
    logger.info(f"Render duration: {render_duration}")
    logger.info(f"Using num workers: {num_processes}")
    logger.info(f"Pitch low: {pitch_low}")
    logger.info(f"Pitch high: {pitch_high}")
    logger.info(f"Output directory: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # The Manager hosts the shared queue; leaving the ``with`` shuts down its
    # server process once the pool is finished.
    with multiprocessing.Manager() as manager:
        input_queue = manager.Queue()
        for preset_path in preset_paths:
            input_queue.put(Item(str(preset_path)))

        with multiprocessing.Pool(processes=num_processes) as pool:
            # One long-running task per process; each drains the shared queue.
            async_results = [
                pool.apply_async(
                    Worker(
                        input_queue,
                        bpm=bpm,
                        note_duration=note_duration,
                        render_duration=render_duration,
                        pitch_low=pitch_low,
                        pitch_high=pitch_high,
                        velocity=velocity,
                        output_dir=output_dir,
                    ).run
                )
                for _ in range(num_processes)
            ]
            _wait_with_progress(async_results, input_queue, num_items)
            # Collect results before Pool.__exit__ terminates the processes.
            worker_errors = [r.get() for r in async_results]

    # Worker.run returns None on success, or a formatted traceback string.
    for i, error in enumerate(worker_errors):
        if error is not None:
            logger.error(f"Exception in worker {i}:\n{error}")
    logger.info("All done!")
if __name__ == "__main__":
    # We're using multiprocessing.Pool, so our code MUST be inside __main__.
    # See https://docs.python.org/3/library/multiprocessing.html
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--preset-dir", required=True, help="Directory path of Vital presets.")
    parser.add_argument("--bpm", default=120.0, type=float, help="Beats per minute for the synth.")
    parser.add_argument("--note-duration", default=1.0, type=float, help="Note duration in seconds.")
    parser.add_argument("--pitch-low", default=60, type=int, help="Lowest MIDI pitch to be used (inclusive).")
    parser.add_argument("--pitch-high", default=60, type=int, help="Highest MIDI pitch to be used (inclusive).")
    parser.add_argument("--render-duration", default=1.0, type=float, help="Render duration in seconds.")
    parser.add_argument("--num-workers", default=None, type=int, help="Number of workers to use.")
    parser.add_argument("--output-dir", default=os.path.join(os.path.dirname(__file__), "output"), help="Output directory.")
    parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "NOTSET"], help="Logger level.")
    # fmt: on
    args = parser.parse_args()

    # MIDI note numbers are 0-127. An inverted range would make every worker
    # silently render nothing, so reject it up front.
    if not 0 <= args.pitch_low <= args.pitch_high <= 127:
        parser.error(
            "expected 0 <= --pitch-low <= --pitch-high <= 127, "
            f"got {args.pitch_low} and {args.pitch_high}"
        )

    # Keyword arguments keep this call correct if main's signature grows.
    main(
        args.preset_dir,
        bpm=args.bpm,
        note_duration=args.note_duration,
        render_duration=args.render_duration,
        pitch_low=args.pitch_low,
        pitch_high=args.pitch_high,
        num_workers=args.num_workers,
        output_dir=args.output_dir,
        logging_level=args.log_level,
    )