diff --git a/studio/src/hooks/usePendingRepaintJobs.ts b/studio/src/hooks/usePendingRepaintJobs.ts index 1c7687d5..2346ba95 100644 --- a/studio/src/hooks/usePendingRepaintJobs.ts +++ b/studio/src/hooks/usePendingRepaintJobs.ts @@ -49,10 +49,10 @@ interface JobResultResponse { error: string | null; } -async function fetchActiveJobs(): Promise { +async function fetchActiveJobs(signal?: AbortSignal): Promise { const [queuedRes, runningRes] = await Promise.all([ - fetch('/api/jobs?status=queued&limit=50'), - fetch('/api/jobs?status=running&limit=50'), + fetch('/api/jobs?status=queued&limit=50', { signal }), + fetch('/api/jobs?status=running&limit=50', { signal }), ]); const [queued, running] = await Promise.all([ queuedRes.ok ? (queuedRes.json() as Promise) : Promise.resolve({ jobs: [], total: 0 }), @@ -110,10 +110,17 @@ export function usePendingRepaintJobs( useEffect(() => { let destroyed = false; + let polling = false; + let abortController: AbortController | null = null; async function poll(): Promise { + if (polling || destroyed) return; // Prevent overlapping polls + polling = true; + abortController = new AbortController(); + const timeoutId = setTimeout(() => abortController?.abort(), 8000); + try { - const jobs = await fetchActiveJobs(); + const jobs = await fetchActiveJobs(abortController.signal); if (destroyed) return; setActiveJobs( @@ -130,14 +137,20 @@ export function usePendingRepaintJobs( if (jobs.length === 0) stopPolling(); } catch { - // Network error — keep polling + // Network/timeout error — keep polling + } finally { + clearTimeout(timeoutId); + polling = false; + abortController = null; } } - // Also check recently completed jobs (last 30s) that haven't been added + // Also check recently completed jobs that haven't been added async function checkCompleted(): Promise { + const ac = new AbortController(); + const timeoutId = setTimeout(() => ac.abort(), 8000); try { - const res = await fetch('/api/jobs?status=completed&limit=50'); + const res = await fetch('/api/jobs?status=completed&limit=50', { signal: ac.signal }); if (!res.ok || destroyed) return; const data = (await res.json()) as JobListResponse; const unadded = data.jobs.filter((j) => !addedRef.current.has(j.jobId)); @@ -163,6 +176,8 @@ export function usePendingRepaintJobs( ); } catch { // Ignore + } finally { + clearTimeout(timeoutId); } } @@ -176,6 +191,7 @@ export function usePendingRepaintJobs( return () => { destroyed = true; + abortController?.abort(); stopPolling(); }; }, [stopPolling]); diff --git a/studio/src/hooks/useRepaintBatch.ts b/studio/src/hooks/useRepaintBatch.ts index 5561812c..7d75040d 100644 --- a/studio/src/hooks/useRepaintBatch.ts +++ b/studio/src/hooks/useRepaintBatch.ts @@ -105,6 +105,7 @@ export function useRepaintBatch(): { slots: SlotState[]; isRunning: boolean; run: (req: RepaintRequest, count: number) => void; + clear: () => void; } { const [slots, setSlots] = useState([]); const pendingRef = useRef([]); @@ -224,9 +225,17 @@ export function useRepaintBatch(): { [stopPolling, startPolling], ); + const clear = useCallback(() => { + stopPolling(); + pendingRef.current = []; + savePending([]); + setSlots([]); + }, [stopPolling]); + return { slots, isRunning: slots.some((s) => s.status === 'pending'), run, + clear, }; } diff --git a/studio/src/hooks/useShootBatch.ts b/studio/src/hooks/useShootBatch.ts new file mode 100644 index 00000000..863276b5 --- /dev/null +++ b/studio/src/hooks/useShootBatch.ts @@ -0,0 +1,255 @@ +import { useCallback, useEffect, useRef, useState } from 'react'; +import type { GeneratedImage, MaturityRating, ModelId } from '../types'; + +export interface ShootRequest { + prompt: string; + negativePrompt?: string; + model: ModelId; + layout: string; + steps: number; + guidanceScale: number; + identityId?: string; + identityStrength: number; + ipAdapterScale: number; + bodyImageOverride?: string; + bodyIpAdapterScale: number; + faceImageOverride?: string; + maturityRating: MaturityRating; +} + +export type SlotState = + | { status: 'pending' } + | { status: 'done'; image: GeneratedImage } + | { status: 'error'; error: Error }; + +interface PendingJob { + jobId: string; + slotIndex: number; + prompt: string; + model: ModelId; + maturityRating: MaturityRating; + seed: number; +} + +const STORAGE_KEY = 'shoot:pending-jobs'; +const POLL_INTERVAL_MS = 3000; + +function loadPending(): PendingJob[] { + try { + return JSON.parse(localStorage.getItem(STORAGE_KEY) ?? '[]') as PendingJob[]; + } catch { + return []; + } +} + +function savePending(jobs: PendingJob[]): void { + if (jobs.length === 0) { + localStorage.removeItem(STORAGE_KEY); + } else { + localStorage.setItem(STORAGE_KEY, JSON.stringify(jobs)); + } +} + +async function submitShootJob(req: ShootRequest, seed: number): Promise { + const response = await fetch('/api/generate/async', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + prompt: req.prompt, + negativePrompt: req.negativePrompt, + model: req.model, + layout: req.layout, + steps: req.steps, + guidanceScale: req.guidanceScale, + seed, + maturityRating: req.maturityRating, + identityId: req.identityId, + identityStrength: req.identityStrength, + ipAdapterScale: req.ipAdapterScale, + bodyImageOverride: req.bodyImageOverride, + bodyIpAdapterScale: req.bodyIpAdapterScale, + faceImageOverride: req.faceImageOverride, + enableModeration: false, + }), + }); + + const data = await response.json() as { success: boolean; jobId?: string; error?: string; detail?: string }; + if (!response.ok || !data.success || !data.jobId) { + throw new Error(data.detail ?? data.error ?? `HTTP ${response.status}`); + } + return data.jobId; +} + +async function pollJob(jobId: string): Promise<{ status: string; error?: string }> { + const response = await fetch(`/api/jobs/${jobId}`); + if (!response.ok) throw new Error(`Poll failed: HTTP ${response.status}`); + return response.json() as Promise<{ status: string; error?: string }>; +} + +async function fetchResult(job: PendingJob): Promise { + const response = await fetch(`/api/jobs/${job.jobId}/result`); + if (!response.ok) throw new Error(`Result fetch failed: HTTP ${response.status}`); + const data = await response.json() as { + status: string; + result?: { output_base64?: string; total_duration_ms?: number; quality_score?: number }; + error?: string; + }; + if (!data.result?.output_base64) throw new Error(data.error ?? 'No image in result'); + return { + id: crypto.randomUUID(), + imageBase64: data.result.output_base64, + prompt: job.prompt, + model: job.model, + maturityRating: job.maturityRating, + durationMs: data.result.total_duration_ms, + qualityScore: data.result.quality_score, + createdAt: new Date().toISOString(), + }; +} + +/** + * Async-job shoot batch. + * + * Each call to `run()` submits N jobs to /generate/async with body/face references, + * persists job IDs in localStorage, and polls until all slots resolve. + * On page refresh, pending slots are restored and polling resumes. + */ +export function useShootBatch(): { + slots: SlotState[]; + isRunning: boolean; + run: (req: ShootRequest, count: number, startSeed?: number) => void; + clear: () => void; +} { + const [slots, setSlots] = useState([]); + const pendingRef = useRef([]); + const intervalRef = useRef | null>(null); + + const stopPolling = useCallback(() => { + if (intervalRef.current !== null) { + clearInterval(intervalRef.current); + intervalRef.current = null; + } + }, []); + + const startPolling = useCallback(() => { + stopPolling(); + intervalRef.current = setInterval(() => { + void (async () => { + const jobs = [...pendingRef.current]; + if (jobs.length === 0) { + stopPolling(); + return; + } + + await Promise.allSettled( + jobs.map(async (job) => { + try { + const { status, error } = await pollJob(job.jobId); + + if (status === 'completed') { + try { + const image = await fetchResult(job); + setSlots((prev) => + prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'done', image } : s)), + ); + } catch (e) { + const err = e instanceof Error ? e : new Error(String(e)); + setSlots((prev) => + prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'error', error: err } : s)), + ); + } + pendingRef.current = pendingRef.current.filter((j) => j.jobId !== job.jobId); + savePending(pendingRef.current); + } else if (status === 'failed') { + const err = new Error(error ?? 'Job failed'); + setSlots((prev) => + prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'error', error: err } : s)), + ); + pendingRef.current = pendingRef.current.filter((j) => j.jobId !== job.jobId); + savePending(pendingRef.current); + } + } catch { + // Network error — keep polling + } + }), + ); + })(); + }, POLL_INTERVAL_MS); + }, [stopPolling]); + + // Restore pending jobs on mount + useEffect(() => { + const pending = loadPending(); + if (pending.length === 0) return stopPolling; + + const maxIndex = pending.reduce((max, j) => Math.max(max, j.slotIndex), 0); + setSlots(Array.from({ length: maxIndex + 1 }, () => ({ status: 'pending' as const }))); + pendingRef.current = pending; + startPolling(); + + return stopPolling; + }, []); // eslint-disable-line react-hooks/exhaustive-deps + + const run = useCallback( + (req: ShootRequest, count: number, startSeed?: number) => { + stopPolling(); + pendingRef.current = []; + savePending([]); + + const seeds = Array.from({ length: count }, (_, i) => + startSeed !== undefined ? startSeed + i : Math.floor(Math.random() * 2 ** 31), + ); + + setSlots(seeds.map(() => ({ status: 'pending' as const }))); + + void (async () => { + const results = await Promise.allSettled( + seeds.map((seed, slotIndex) => + submitShootJob(req, seed).then((jobId) => ({ jobId, slotIndex, seed })), + ), + ); + + const jobs: PendingJob[] = []; + results.forEach((result, slotIndex) => { + if (result.status === 'fulfilled') { + jobs.push({ + jobId: result.value.jobId, + slotIndex, + prompt: req.prompt, + model: req.model, + maturityRating: req.maturityRating, + seed: result.value.seed, + }); + } else { + const err = + result.reason instanceof Error ? result.reason : new Error(String(result.reason)); + setSlots((prev) => + prev.map((s, i): SlotState => (i === slotIndex ? { status: 'error', error: err } : s)), + ); + } + }); + + if (jobs.length > 0) { + pendingRef.current = jobs; + savePending(jobs); + startPolling(); + } + })(); + }, + [stopPolling, startPolling], + ); + + const clear = useCallback(() => { + stopPolling(); + pendingRef.current = []; + savePending([]); + setSlots([]); + }, [stopPolling]); + + return { + slots, + isRunning: slots.some((s) => s.status === 'pending'), + run, + clear, + }; +}