import asyncio
import os
import threading
import time
import uuid

import torch
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from einops import rearrange
from fastapi import BackgroundTasks, FastAPI, Form, HTTPException
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
app = FastAPI()

# StaticFiles checks for its directory when it is constructed and raises if it
# is missing, so create it first; otherwise a fresh checkout cannot even start.
os.makedirs("static", exist_ok=True)
# Mount static directory for audio files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize model (default to small). These two globals are replaced by the
# /generate handler when a request asks for a different checkpoint.
model_version = 'facebook/musicgen-small'
print("Loading model", model_version)
model = MusicGen.get_pretrained(model_version, device='cuda')

# Global flag for cancellation: set by POST /cancel, reset by POST /generate.
cancel_generation = False
class MusicGenRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    Non-positive ``duration`` or ``num_samples`` values are rejected during
    request validation instead of being handed to the model, where they would
    only surface later as an opaque server error.
    """

    # One text prompt per clip to generate.
    descriptions: list[str]
    # Requested clip length in seconds (passed to set_generation_params).
    duration: float = Field(..., gt=0)
    # Number of clips; the /generate handler requires it to equal
    # len(descriptions).
    num_samples: int = Field(..., ge=1)
    # Checkpoint name handed to MusicGen.get_pretrained, e.g.
    # "facebook/musicgen-small".
    model_version: str
def cleanup_files(file_paths: list[str], delay_seconds: float = 3600) -> None:
    """Delete generated files after a grace period.

    The call blocks the current thread for ``delay_seconds`` and then removes
    every path on a best-effort basis: a file that is already gone is skipped
    silently, and any other failure is reported and does not stop the
    remaining deletions.

    Args:
        file_paths: Paths of the files to delete.
        delay_seconds: How long to wait before deleting. Defaults to one hour;
            negative values are treated as zero.
    """
    time.sleep(max(0.0, delay_seconds))
    for path in file_paths:
        try:
            # Remove directly instead of checking os.path.exists first: the
            # file may disappear between the check and the removal.
            os.remove(path)
        except FileNotFoundError:
            continue
        except OSError as exc:
            print(f"Could not delete {path}: {exc}")
@app.get("/", response_class=HTMLResponse)
async def get_form():
    """Serve the single-page MusicGen web UI.

    The returned document is self-contained (inline CSS and JavaScript). Its
    script posts the form as JSON to ``/generate``, posts to ``/cancel`` when
    the Cancel button is pressed, animates an estimated (not measured)
    progress bar while waiting, and keeps the last 100 prompts in the
    browser's ``localStorage``.

    Returns:
        The HTML page as a string.
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <title>MusicGen Web Interface</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 20px; }
        .audio-container { margin: 20px 0; }
        #descriptionInputs { margin-bottom: 20px; }
        .description-field { margin-bottom: 10px; }
        #progressBar { width: 100%; height: 20px; }
        #history { max-height: 200px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; }
        .sample-controls { display: flex; align-items: center; gap: 10px; }
    </style>
</head>
<body>
    <h1>MusicGen Audio Generator</h1>
    <form id="musicForm">
        <label>Model:</label>
        <select name="model_version">
            <option value="facebook/musicgen-small">Small</option>
            <option value="facebook/musicgen-medium">Medium</option>
            <option value="facebook/musicgen-large">Large</option>
        </select><br>
        <label>Number of Samples:</label>
        <div class="sample-controls">
            <button type="button" onclick="changeSampleCount(-1)">-</button>
            <span id="num_samples_display">1</span>
            <input type="hidden" id="num_samples" name="num_samples" value="1">
            <button type="button" onclick="changeSampleCount(1)">+</button>
        </div><br>
        <div id="descriptionInputs">
            <div class="description-field">
                <label>Description 1:</label><br>
                <textarea name="descriptions" rows="3" cols="50"></textarea>
            </div>
        </div>
        <label>Duration (seconds):</label>
        <input type="number" name="duration" value="30" min="1" max="300"><br>
        <button type="submit">Generate Music</button>
        <button type="button" id="cancelButton" style="display: none;" onclick="cancelGeneration()">Cancel</button>
    </form>
    <div id="progress" style="display: none;">
        <label>Generation Progress:</label>
        <progress id="progressBar" value="0" max="100"></progress>
        <span id="progressText">0%</span>
    </div>
    <div id="results"></div>
    <h2>History (Last 100 Descriptions)</h2>
    <div id="history"></div>
    <script>
        // Escape user-supplied text before it is inserted through innerHTML.
        function escapeHtml(text) {
            const scratch = document.createElement("div");
            scratch.textContent = text;
            return scratch.innerHTML;
        }
        function updateDescriptionFields() {
            const numSamples = parseInt(document.getElementById("num_samples").value);
            document.getElementById("num_samples_display").textContent = numSamples;
            const container = document.getElementById("descriptionInputs");
            // Remember what was already typed so changing the count does not wipe it.
            const previous = Array.from(container.querySelectorAll("textarea")).map(t => t.value);
            container.innerHTML = "";
            for (let i = 0; i < numSamples; i++) {
                const field = document.createElement("div");
                field.className = "description-field";
                field.innerHTML = `
                    <label>Description ${i + 1}:</label><br>
                    <textarea name="descriptions" rows="3" cols="50"></textarea>`;
                field.querySelector("textarea").value = previous[i] || "";
                container.appendChild(field);
            }
        }
        function changeSampleCount(delta) {
            const input = document.getElementById("num_samples");
            let numSamples = parseInt(input.value) + delta;
            numSamples = Math.max(1, Math.min(10, numSamples)); // Clamp between 1 and 10
            input.value = numSamples;
            updateDescriptionFields();
        }
        function loadHistory() {
            const history = JSON.parse(localStorage.getItem("descriptionHistory") || "[]");
            const historyDiv = document.getElementById("history");
            historyDiv.innerHTML = history.map(d => `<div>${escapeHtml(d)}</div>`).join("");
        }
        function updateHistory(descriptions) {
            let history = JSON.parse(localStorage.getItem("descriptionHistory") || "[]");
            descriptions.forEach(desc => {
                if (desc) {
                    history = history.filter(d => d !== desc); // Remove if exists
                    history.unshift(desc); // Add to top
                }
            });
            if (history.length > 100) history = history.slice(0, 100); // Keep last 100
            localStorage.setItem("descriptionHistory", JSON.stringify(history));
            loadHistory();
        }
        let isGenerating = false;
        // Return the form to its idle state (also stops the progress simulation loop).
        function finishGeneration() {
            isGenerating = false;
            document.getElementById("progress").style.display = "none";
            document.getElementById("cancelButton").style.display = "none";
        }
        async function cancelGeneration() {
            if (isGenerating) {
                await fetch("/cancel", { method: "POST" });
                finishGeneration();
                document.getElementById("results").innerHTML = "Generation cancelled.";
            }
        }
        async function simulateProgress(duration, num_samples) {
            const totalSteps = (duration / 30) * 1503 * num_samples; // Estimate total steps
            let currentStep = 0;
            let speed = 3.0; // 3 in small, 1.5 in medium, 0.75 in large?
            const progressBar = document.getElementById("progressBar");
            const progressText = document.getElementById("progressText");
            while (currentStep < totalSteps && isGenerating) {
                currentStep += speed;
                const percentage = Math.min((currentStep / totalSteps) * 100, 100);
                progressBar.value = percentage;
                progressText.textContent = `${Math.round(percentage)}%`;
                await new Promise(resolve => setTimeout(resolve, 100)); // Update every 100ms
            }
        }
        document.getElementById("musicForm").onsubmit = async (e) => {
            e.preventDefault();
            if (isGenerating) {
                alert("Generation already in progress!");
                return;
            }
            const formData = new FormData(e.target);
            const descriptions = Array.from(formData.getAll("descriptions")).map(d => d.trim()).filter(d => d);
            const duration = parseFloat(formData.get("duration"));
            const num_samples = parseInt(formData.get("num_samples"));
            const model_version = formData.get("model_version");
            if (descriptions.length !== num_samples) {
                alert("Please provide a description for each sample.");
                return;
            }
            isGenerating = true;
            document.getElementById("progress").style.display = "block";
            document.getElementById("cancelButton").style.display = "inline-block";
            document.getElementById("results").innerHTML = "Generating...";
            updateHistory(descriptions);
            // Start progress simulation (runs until isGenerating goes false)
            simulateProgress(duration, num_samples);
            const resultsDiv = document.getElementById("results");
            try {
                const response = await fetch("/generate", {
                    method: "POST",
                    headers: { "Content-Type": "application/json" },
                    body: JSON.stringify({ descriptions, duration, num_samples, model_version })
                });
                const result = await response.json();
                if (response.status === 499) {
                    // The server reports a cancelled generation with status 499.
                    resultsDiv.textContent = "Generation cancelled.";
                    return;
                }
                if (!response.ok) {
                    // Error bodies carry "detail" instead of "audio_files".
                    const detail = typeof result.detail === "string" ? result.detail : JSON.stringify(result.detail);
                    throw new Error(detail || response.statusText);
                }
                resultsDiv.innerHTML = result.audio_files.map((file, idx) => `
                    <div class="audio-container">
                        <h3>Sample ${idx + 1}: ${escapeHtml(descriptions[idx])}</h3>
                        <audio controls src="${file}"></audio>
                        <a href="${file}" download>Download</a>
                    </div>`).join("");
            } catch (error) {
                resultsDiv.textContent = "Error: " + error.message;
            } finally {
                finishGeneration();
            }
        };
        loadHistory(); // Load history on page load
    </script>
</body>
</html>
"""
# Generation runs in worker threads, so access to the shared global model must
# be serialized: two requests may not reconfigure or drive it at the same time.
_generation_lock = threading.Lock()


class _GenerationCancelled(Exception):
    """Raised in the worker thread to abort a generation after /cancel."""


def _abort_if_cancelled(generated_tokens: int, tokens_to_generate: int) -> None:
    """Progress callback that keeps console progress and honors /cancel."""
    # A custom callback takes the place of audiocraft's own console progress
    # line, so print an equivalent one here.
    print(f"{generated_tokens: 6d} / {tokens_to_generate: 6d}", end="\r")
    if cancel_generation:
        raise _GenerationCancelled()


def _generate_and_save(request: "MusicGenRequest") -> list[str]:
    """Run the blocking part of /generate and return the audio URLs."""
    global model, model_version, cancel_generation
    with _generation_lock:
        # Load new model if version changed
        if request.model_version != model_version:
            print(f"Switching to model {request.model_version}")
            model = MusicGen.get_pretrained(request.model_version, device='cuda')
            model_version = request.model_version
        # Reset the flag only once this request owns the model, so a pending
        # cancel aimed at the previous generation is not lost.
        cancel_generation = False
        model.set_generation_params(duration=request.duration)
        model.set_custom_progress_callback(_abort_if_cancelled)
        wav = model.generate(request.descriptions, progress=True)
        # A cancel may arrive after the last progress callback.
        if cancel_generation:
            raise _GenerationCancelled()
        os.makedirs("static", exist_ok=True)
        audio_files = []
        for one_wav in wav:
            file_path = f"static/{uuid.uuid4()}"
            # audio_write appends the ".wav" suffix itself.
            audio_write(file_path, one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
            audio_files.append(f"/{file_path}.wav")
        return audio_files


@app.post("/generate")
async def generate_music(request: MusicGenRequest, background_tasks: BackgroundTasks):
    """Generate one audio clip per description and return their URLs.

    The model call is blocking, so it runs in a worker thread; this keeps the
    event loop free to serve ``/cancel`` (and everything else) while a
    generation is in progress.

    Args:
        request: Prompts, clip duration, sample count and model checkpoint.
        background_tasks: Used to schedule deletion of the generated files.

    Returns:
        ``{"audio_files": [...]}`` with one ``/static/<uuid>.wav`` URL per
        description, in request order.

    Raises:
        HTTPException: 400 when the number of descriptions differs from
            ``num_samples``, 499 when the generation was cancelled, and 500
            for any other failure.
    """
    # Validate outside the try block: the catch-all below must not turn this
    # client error into a 500.
    if len(request.descriptions) != request.num_samples:
        raise HTTPException(status_code=400, detail="Number of descriptions must match num_samples")
    try:
        audio_files = await asyncio.to_thread(_generate_and_save, request)
    except _GenerationCancelled:
        raise HTTPException(status_code=499, detail="Generation cancelled") from None
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Schedule cleanup; URLs are "/static/<id>.wav", files are "static/<id>.wav".
    background_tasks.add_task(cleanup_files, [url.removeprefix("/") for url in audio_files])
    return {"audio_files": audio_files}
@app.post("/cancel")
async def cancel():
    """Request cancellation of the current generation.

    Only raises the module-level ``cancel_generation`` flag, which the
    ``/generate`` handler consults; nothing is interrupted from here.
    """
    global cancel_generation
    cancel_generation = True
    acknowledgement = {"message": "Cancellation requested"}
    return acknowledgement
# NOTE: a stray "Top" page-navigation token followed here; it is not part of the program.