~/scripts/ai/facebook/audiocraft/demos
3 itemsDownload ./*

..
ana-custom.py
anamusicapi3.py
grokmusicgenapi.txt


demos/anamusicapi3.py
12 KB• 8•  1 week ago•  DownloadRawClose
1 week ago•  8

{}
from fastapi import FastAPI, Form, HTTPException, BackgroundTasks
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from einops import rearrange
import torch
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import os
import uuid
import asyncio
import threading

app = FastAPI()

# Mount static directory for audio files. StaticFiles verifies at construction
# time that the directory exists (RuntimeError otherwise), so create it first;
# generated WAVs are written here.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize model (default to small); /generate may swap in another version.
model_version = 'facebook/musicgen-small'
print("Loading model", model_version)
model = MusicGen.get_pretrained(model_version, device='cuda')

# Global flag for cancellation: set by /cancel, reset and checked by /generate.
cancel_generation = False

class MusicGenRequest(BaseModel):
    """JSON body accepted by ``POST /generate``.

    The bundled web form sends one description per requested sample.
    """
    descriptions: list[str]  # one text prompt per sample; passed to model.generate
    duration: float  # seconds of audio per sample; passed to set_generation_params
    num_samples: int  # must equal len(descriptions); enforced in the endpoint, not here
    model_version: str  # pretrained MusicGen name, e.g. 'facebook/musicgen-small'

def cleanup_files(file_paths: list[str], delay: float = 3600.0) -> None:
    """Delete *file_paths* after *delay* seconds (default: one hour).

    With a positive delay, the deletion is scheduled on a daemon
    ``threading.Timer`` and this function returns immediately. Sync background
    tasks run in Starlette's shared threadpool, so sleeping here would hold one
    of its workers for the whole delay. Pending deletions are lost if the
    process exits first.

    With ``delay <= 0``, the files are removed synchronously. Paths that no
    longer exist are ignored.
    """
    if delay > 0:
        # Copy the list so later mutation by the caller can't change what gets deleted.
        timer = threading.Timer(
            delay, cleanup_files, args=(list(file_paths),), kwargs={"delay": 0}
        )
        timer.daemon = True
        timer.start()
        return
    for path in file_paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone (e.g. removed elsewhere); nothing to do

@app.get("/", response_class=HTMLResponse)
async def get_form():
    """Serve the single-page MusicGen web UI.

    The page posts JSON to ``/generate`` and ``/cancel`` and keeps the last 100
    descriptions in the browser's ``localStorage``. It animates a purely
    client-side progress estimate, since the server reports no real progress.
    User-supplied text and server error messages are HTML-escaped before
    being inserted into the page, and non-2xx responses are shown as errors.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>MusicGen Web Interface</title>
        <style>
            body { font-family: Arial, sans-serif; margin: 20px; }
            .audio-container { margin: 20px 0; }
            #descriptionInputs { margin-bottom: 20px; }
            .description-field { margin-bottom: 10px; }
            #progressBar { width: 100%; height: 20px; }
            #history { max-height: 200px; overflow-y: auto; border: 1px solid #ccc; padding: 10px; }
            .sample-controls { display: flex; align-items: center; gap: 10px; }
        </style>
    </head>
    <body>
        <h1>MusicGen Audio Generator</h1>
        <form id="musicForm">
            <label>Model:</label>
            <select name="model_version">
                <option value="facebook/musicgen-small">Small</option>
                <option value="facebook/musicgen-medium">Medium</option>
                <option value="facebook/musicgen-large">Large</option>
            </select><br>
            <label>Number of Samples:</label>
            <div class="sample-controls">
                <button type="button" onclick="changeSampleCount(-1)">-</button>
                <span id="num_samples_display">1</span>
                <input type="hidden" id="num_samples" name="num_samples" value="1">
                <button type="button" onclick="changeSampleCount(1)">+</button>
            </div><br>
            <div id="descriptionInputs">
                <div class="description-field">
                    <label>Description 1:</label><br>
                    <textarea name="descriptions" rows="3" cols="50"></textarea>
                </div>
            </div>
            <label>Duration (seconds):</label>
            <input type="number" name="duration" value="30" min="1" max="300"><br>
            <button type="submit">Generate Music</button>
            <button type="button" id="cancelButton" style="display: none;" onclick="cancelGeneration()">Cancel</button>
        </form>
        <div id="progress" style="display: none;">
            <label>Generation Progress:</label>
            <progress id="progressBar" value="0" max="100"></progress>
            <span id="progressText">0%</span>
        </div>
        <div id="results"></div>
        <h2>History (Last 100 Descriptions)</h2>
        <div id="history"></div>
        <script>
            // Escape arbitrary text for safe insertion via innerHTML.
            function escapeHtml(text) {
                const div = document.createElement("div");
                div.textContent = String(text);
                return div.innerHTML;
            }

            function updateDescriptionFields() {
                const numSamples = parseInt(document.getElementById("num_samples").value);
                document.getElementById("num_samples_display").textContent = numSamples;
                const container = document.getElementById("descriptionInputs");
                // Keep what was already typed so changing the sample count doesn't wipe it.
                const previous = Array.from(container.querySelectorAll("textarea")).map(t => t.value);
                let html = "";
                for (let i = 0; i < numSamples; i++) {
                    html += `
                        <div class="description-field">
                            <label>Description ${i + 1}:</label><br>
                            <textarea name="descriptions" rows="3" cols="50"></textarea>
                        </div>`;
                }
                container.innerHTML = html;
                container.querySelectorAll("textarea").forEach((t, i) => {
                    if (i < previous.length) t.value = previous[i];
                });
            }

            function changeSampleCount(delta) {
                const input = document.getElementById("num_samples");
                let numSamples = parseInt(input.value) + delta;
                numSamples = Math.max(1, Math.min(10, numSamples)); // Clamp between 1 and 10
                input.value = numSamples;
                updateDescriptionFields();
            }

            function loadHistory() {
                const history = JSON.parse(localStorage.getItem("descriptionHistory") || "[]");
                const historyDiv = document.getElementById("history");
                historyDiv.innerHTML = history.map(d => `<div>${escapeHtml(d)}</div>`).join("");
            }

            function updateHistory(descriptions) {
                let history = JSON.parse(localStorage.getItem("descriptionHistory") || "[]");
                descriptions.forEach(desc => {
                    if (desc) {
                        history = history.filter(d => d !== desc); // Remove if exists
                        history.unshift(desc); // Add to top
                    }
                });
                if (history.length > 100) history = history.slice(0, 100); // Keep last 100
                localStorage.setItem("descriptionHistory", JSON.stringify(history));
                loadHistory();
            }

            let isGenerating = false;
            // Bumped on every submit and cancel; a response whose token no longer
            // matches belongs to a cancelled request and must not touch the UI.
            let generationToken = 0;

            function hideProgress() {
                isGenerating = false;
                document.getElementById("progress").style.display = "none";
                document.getElementById("cancelButton").style.display = "none";
            }

            async function cancelGeneration() {
                if (isGenerating) {
                    generationToken++;
                    hideProgress();
                    document.getElementById("results").innerHTML = "Generation cancelled.";
                    try {
                        await fetch("/cancel", { method: "POST" });
                    } catch (error) {
                        // Best effort: the UI has already moved on.
                    }
                }
            }

            async function simulateProgress(duration, num_samples, token) {
                const totalSteps = (duration / 30) * 1503 * num_samples; // Estimate total steps
                let currentStep = 0;
                let speed = 3.0; // 3 in small, 1.5 in medium, 0.75 in large?
                const progressBar = document.getElementById("progressBar");
                const progressText = document.getElementById("progressText");
                while (currentStep < totalSteps && isGenerating && token === generationToken) {
                    //currentStep += totalSteps / 500; // Increment in small steps
                    currentStep += speed;
                    const percentage = Math.min((currentStep / totalSteps) * 100, 100);
                    progressBar.value = percentage;
                    progressText.textContent = `${Math.round(percentage)}%`;
                    await new Promise(resolve => setTimeout(resolve, 100)); // Update every 100ms
                }
            }

            document.getElementById("musicForm").onsubmit = async (e) => {
                e.preventDefault();
                if (isGenerating) {
                    alert("Generation already in progress!");
                    return;
                }
                const formData = new FormData(e.target);
                const descriptions = Array.from(formData.getAll("descriptions")).map(d => d.trim()).filter(d => d);
                const duration = parseFloat(formData.get("duration"));
                const num_samples = parseInt(formData.get("num_samples"));
                const model_version = formData.get("model_version");
                if (descriptions.length !== num_samples) {
                    alert("Please provide a description for each sample.");
                    return;
                }
                isGenerating = true;
                const token = ++generationToken;
                document.getElementById("progress").style.display = "block";
                document.getElementById("cancelButton").style.display = "inline-block";
                document.getElementById("results").innerHTML = "Generating...";
                updateHistory(descriptions);
                
                // Start progress simulation (stops when this request ends or is cancelled)
                simulateProgress(duration, num_samples, token);
                
                try {
                    const response = await fetch("/generate", {
                        method: "POST",
                        headers: { "Content-Type": "application/json" },
                        body: JSON.stringify({ descriptions, duration, num_samples, model_version })
                    });
                    // Error bodies may not be JSON; fall back to an empty object.
                    const result = await response.json().catch(() => ({}));
                    if (token !== generationToken) return; // Cancelled meanwhile: keep the cancel message
                    hideProgress();
                    if (!response.ok) {
                        const detail = typeof result.detail === "string"
                            ? result.detail
                            : (result.detail ? JSON.stringify(result.detail) : `HTTP ${response.status}`);
                        throw new Error(detail);
                    }
                    const resultsDiv = document.getElementById("results");
                    resultsDiv.innerHTML = (result.audio_files || []).map((file, idx) => `
                            <div class="audio-container">
                                <h3>Sample ${idx + 1}: ${escapeHtml(descriptions[idx])}</h3>
                                <audio controls src="${file}"></audio>
                                <a href="${file}" download>Download</a>
                            </div>`).join("");
                } catch (error) {
                    if (token !== generationToken) return;
                    hideProgress();
                    document.getElementById("results").innerHTML = "Error: " + escapeHtml(error.message);
                }
            };

            loadHistory(); // Load history on page load
        </script>
    </body>
    </html>
    """

# Serializes use of the shared `model` (and the GPU). Generation runs in a
# worker thread so the event loop stays free to serve /cancel and static files;
# without this lock, concurrent requests would race on model swaps and params.
_generation_lock = threading.Lock()


def _generate_and_save(request: MusicGenRequest) -> list[str]:
    """Generate audio for *request* and write one WAV per sample.

    Blocking; ``generate_music`` runs it in a worker thread.

    Returns:
        URL paths (``/static/<uuid>.wav``) of the written files.

    Raises:
        HTTPException: 499 if the cancellation flag was set while
            ``model.generate`` was running (the generated audio is discarded).
    """
    global model, model_version, cancel_generation
    with _generation_lock:
        # Load new model if version changed.
        # NOTE(review): model_version comes straight from the client and is passed
        # to get_pretrained unchecked; consider restricting it to an allow-list.
        if request.model_version != model_version:
            print(f"Switching to model {request.model_version}")
            model = MusicGen.get_pretrained(request.model_version, device='cuda')
            model_version = request.model_version

        # Reset inside the lock so a cancel aimed at the previous request is
        # observed by that request before it is cleared here.
        cancel_generation = False
        model.set_generation_params(duration=request.duration)
        wav = model.generate(request.descriptions, progress=True)

        # The flag is only checked after generate returns, so a cancel discards
        # the result rather than stopping the model early.
        if cancel_generation:
            raise HTTPException(status_code=499, detail="Generation cancelled")

        os.makedirs("static", exist_ok=True)
        audio_files = []
        for one_wav in wav:
            # file_path is a stem; the URL below assumes audio_write adds ".wav".
            file_path = f"static/{uuid.uuid4()}"
            audio_write(file_path, one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
            audio_files.append(f"/{file_path}.wav")
        return audio_files


@app.post("/generate")
async def generate_music(request: MusicGenRequest, background_tasks: BackgroundTasks):
    """Generate one audio clip per description and return their URLs.

    Returns ``{"audio_files": [...]}`` with ``/static/<uuid>.wav`` paths. The
    files are scheduled for deletion via ``cleanup_files``.

    Raises:
        HTTPException: 400 for invalid input, 499 if cancelled via ``/cancel``,
            500 for any other failure during generation.
    """
    if len(request.descriptions) != request.num_samples:
        raise HTTPException(status_code=400, detail="Number of descriptions must match num_samples")
    if not request.descriptions:
        raise HTTPException(status_code=400, detail="At least one description is required")
    if not request.duration > 0:  # also rejects NaN
        raise HTTPException(status_code=400, detail="Duration must be positive")

    try:
        # Run the blocking model call off the event loop so /cancel (and static
        # file serving) can be handled while generation is in progress.
        audio_files = await asyncio.to_thread(_generate_and_save, request)
    except HTTPException:
        # Keep deliberate status codes (e.g. 499) instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Schedule cleanup
    background_tasks.add_task(cleanup_files, [f"static/{os.path.basename(url)}" for url in audio_files])
    return {"audio_files": audio_files}

@app.post("/cancel")
async def cancel():
    """Request that the current generation's result be discarded.

    This only raises the module-level ``cancel_generation`` flag. The
    generation code checks that flag once ``model.generate`` returns.
    """
    global cancel_generation
    cancel_generation = True
    acknowledgement = {"message": "Cancellation requested"}
    return acknowledgement

Top
©twily.info 2013 - 2025
twily at twily dot info



2 335 240 visits
... ^ v