feat(studio): ✨ Optimize repaint/shoot batching with improved usePendingRepaintJobs, useRepaintBatch, and useShootBatch hooks
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
2746fc3a7e
commit
d4802e643d
3 changed files with 287 additions and 7 deletions
|
|
@ -49,10 +49,10 @@ interface JobResultResponse {
|
|||
error: string | null;
|
||||
}
|
||||
|
||||
async function fetchActiveJobs(): Promise<JobStatusResponse[]> {
|
||||
async function fetchActiveJobs(signal?: AbortSignal): Promise<JobStatusResponse[]> {
|
||||
const [queuedRes, runningRes] = await Promise.all([
|
||||
fetch('/api/jobs?status=queued&limit=50'),
|
||||
fetch('/api/jobs?status=running&limit=50'),
|
||||
fetch('/api/jobs?status=queued&limit=50', { signal }),
|
||||
fetch('/api/jobs?status=running&limit=50', { signal }),
|
||||
]);
|
||||
const [queued, running] = await Promise.all([
|
||||
queuedRes.ok ? (queuedRes.json() as Promise<JobListResponse>) : Promise.resolve({ jobs: [], total: 0 }),
|
||||
|
|
@ -110,10 +110,17 @@ export function usePendingRepaintJobs(
|
|||
|
||||
useEffect(() => {
|
||||
let destroyed = false;
|
||||
let polling = false;
|
||||
let abortController: AbortController | null = null;
|
||||
|
||||
async function poll(): Promise<void> {
|
||||
if (polling || destroyed) return; // Prevent overlapping polls
|
||||
polling = true;
|
||||
abortController = new AbortController();
|
||||
const timeoutId = setTimeout(() => abortController?.abort(), 8000);
|
||||
|
||||
try {
|
||||
const jobs = await fetchActiveJobs();
|
||||
const jobs = await fetchActiveJobs(abortController.signal);
|
||||
if (destroyed) return;
|
||||
|
||||
setActiveJobs(
|
||||
|
|
@ -130,14 +137,20 @@ export function usePendingRepaintJobs(
|
|||
|
||||
if (jobs.length === 0) stopPolling();
|
||||
} catch {
|
||||
// Network error — keep polling
|
||||
// Network/timeout error — keep polling
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
polling = false;
|
||||
abortController = null;
|
||||
}
|
||||
}
|
||||
|
||||
// Also check recently completed jobs (last 30s) that haven't been added
|
||||
// Also check recently completed jobs that haven't been added
|
||||
async function checkCompleted(): Promise<void> {
|
||||
const ac = new AbortController();
|
||||
const timeoutId = setTimeout(() => ac.abort(), 8000);
|
||||
try {
|
||||
const res = await fetch('/api/jobs?status=completed&limit=50');
|
||||
const res = await fetch('/api/jobs?status=completed&limit=50', { signal: ac.signal });
|
||||
if (!res.ok || destroyed) return;
|
||||
const data = (await res.json()) as JobListResponse;
|
||||
const unadded = data.jobs.filter((j) => !addedRef.current.has(j.jobId));
|
||||
|
|
@ -163,6 +176,8 @@ export function usePendingRepaintJobs(
|
|||
);
|
||||
} catch {
|
||||
// Ignore
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -176,6 +191,7 @@ export function usePendingRepaintJobs(
|
|||
|
||||
return () => {
|
||||
destroyed = true;
|
||||
abortController?.abort();
|
||||
stopPolling();
|
||||
};
|
||||
}, [stopPolling]);
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ export function useRepaintBatch(): {
|
|||
slots: SlotState[];
|
||||
isRunning: boolean;
|
||||
run: (req: RepaintRequest, count: number) => void;
|
||||
clear: () => void;
|
||||
} {
|
||||
const [slots, setSlots] = useState<SlotState[]>([]);
|
||||
const pendingRef = useRef<PendingJob[]>([]);
|
||||
|
|
@ -224,9 +225,17 @@ export function useRepaintBatch(): {
|
|||
[stopPolling, startPolling],
|
||||
);
|
||||
|
||||
const clear = useCallback(() => {
|
||||
stopPolling();
|
||||
pendingRef.current = [];
|
||||
savePending([]);
|
||||
setSlots([]);
|
||||
}, [stopPolling]);
|
||||
|
||||
return {
|
||||
slots,
|
||||
isRunning: slots.some((s) => s.status === 'pending'),
|
||||
run,
|
||||
clear,
|
||||
};
|
||||
}
|
||||
|
|
|
|||
255
studio/src/hooks/useShootBatch.ts
Normal file
255
studio/src/hooks/useShootBatch.ts
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import type { GeneratedImage, MaturityRating, ModelId } from '../types';
|
||||
|
||||
export interface ShootRequest {
|
||||
prompt: string;
|
||||
negativePrompt?: string;
|
||||
model: ModelId;
|
||||
layout: string;
|
||||
steps: number;
|
||||
guidanceScale: number;
|
||||
identityId?: string;
|
||||
identityStrength: number;
|
||||
ipAdapterScale: number;
|
||||
bodyImageOverride?: string;
|
||||
bodyIpAdapterScale: number;
|
||||
faceImageOverride?: string;
|
||||
maturityRating: MaturityRating;
|
||||
}
|
||||
|
||||
export type SlotState =
|
||||
| { status: 'pending' }
|
||||
| { status: 'done'; image: GeneratedImage }
|
||||
| { status: 'error'; error: Error };
|
||||
|
||||
interface PendingJob {
|
||||
jobId: string;
|
||||
slotIndex: number;
|
||||
prompt: string;
|
||||
model: ModelId;
|
||||
maturityRating: MaturityRating;
|
||||
seed: number;
|
||||
}
|
||||
|
||||
const STORAGE_KEY = 'shoot:pending-jobs';
|
||||
const POLL_INTERVAL_MS = 3000;
|
||||
|
||||
function loadPending(): PendingJob[] {
|
||||
try {
|
||||
return JSON.parse(localStorage.getItem(STORAGE_KEY) ?? '[]') as PendingJob[];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
function savePending(jobs: PendingJob[]): void {
|
||||
if (jobs.length === 0) {
|
||||
localStorage.removeItem(STORAGE_KEY);
|
||||
} else {
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(jobs));
|
||||
}
|
||||
}
|
||||
|
||||
async function submitShootJob(req: ShootRequest, seed: number): Promise<string> {
|
||||
const response = await fetch('/api/generate/async', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
prompt: req.prompt,
|
||||
negativePrompt: req.negativePrompt,
|
||||
model: req.model,
|
||||
layout: req.layout,
|
||||
steps: req.steps,
|
||||
guidanceScale: req.guidanceScale,
|
||||
seed,
|
||||
maturityRating: req.maturityRating,
|
||||
identityId: req.identityId,
|
||||
identityStrength: req.identityStrength,
|
||||
ipAdapterScale: req.ipAdapterScale,
|
||||
bodyImageOverride: req.bodyImageOverride,
|
||||
bodyIpAdapterScale: req.bodyIpAdapterScale,
|
||||
faceImageOverride: req.faceImageOverride,
|
||||
enableModeration: false,
|
||||
}),
|
||||
});
|
||||
|
||||
const data = await response.json() as { success: boolean; jobId?: string; error?: string; detail?: string };
|
||||
if (!response.ok || !data.success || !data.jobId) {
|
||||
throw new Error(data.detail ?? data.error ?? `HTTP ${response.status}`);
|
||||
}
|
||||
return data.jobId;
|
||||
}
|
||||
|
||||
async function pollJob(jobId: string): Promise<{ status: string; error?: string }> {
|
||||
const response = await fetch(`/api/jobs/${jobId}`);
|
||||
if (!response.ok) throw new Error(`Poll failed: HTTP ${response.status}`);
|
||||
return response.json() as Promise<{ status: string; error?: string }>;
|
||||
}
|
||||
|
||||
async function fetchResult(job: PendingJob): Promise<GeneratedImage> {
|
||||
const response = await fetch(`/api/jobs/${job.jobId}/result`);
|
||||
if (!response.ok) throw new Error(`Result fetch failed: HTTP ${response.status}`);
|
||||
const data = await response.json() as {
|
||||
status: string;
|
||||
result?: { output_base64?: string; total_duration_ms?: number; quality_score?: number };
|
||||
error?: string;
|
||||
};
|
||||
if (!data.result?.output_base64) throw new Error(data.error ?? 'No image in result');
|
||||
return {
|
||||
id: crypto.randomUUID(),
|
||||
imageBase64: data.result.output_base64,
|
||||
prompt: job.prompt,
|
||||
model: job.model,
|
||||
maturityRating: job.maturityRating,
|
||||
durationMs: data.result.total_duration_ms,
|
||||
qualityScore: data.result.quality_score,
|
||||
createdAt: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Async-job shoot batch.
|
||||
*
|
||||
* Each call to `run()` submits N jobs to /generate/async with body/face references,
|
||||
* persists job IDs in localStorage, and polls until all slots resolve.
|
||||
* On page refresh, pending slots are restored and polling resumes.
|
||||
*/
|
||||
export function useShootBatch(): {
|
||||
slots: SlotState[];
|
||||
isRunning: boolean;
|
||||
run: (req: ShootRequest, count: number, startSeed?: number) => void;
|
||||
clear: () => void;
|
||||
} {
|
||||
const [slots, setSlots] = useState<SlotState[]>([]);
|
||||
const pendingRef = useRef<PendingJob[]>([]);
|
||||
const intervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
|
||||
const stopPolling = useCallback(() => {
|
||||
if (intervalRef.current !== null) {
|
||||
clearInterval(intervalRef.current);
|
||||
intervalRef.current = null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const startPolling = useCallback(() => {
|
||||
stopPolling();
|
||||
intervalRef.current = setInterval(() => {
|
||||
void (async () => {
|
||||
const jobs = [...pendingRef.current];
|
||||
if (jobs.length === 0) {
|
||||
stopPolling();
|
||||
return;
|
||||
}
|
||||
|
||||
await Promise.allSettled(
|
||||
jobs.map(async (job) => {
|
||||
try {
|
||||
const { status, error } = await pollJob(job.jobId);
|
||||
|
||||
if (status === 'completed') {
|
||||
try {
|
||||
const image = await fetchResult(job);
|
||||
setSlots((prev) =>
|
||||
prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'done', image } : s)),
|
||||
);
|
||||
} catch (e) {
|
||||
const err = e instanceof Error ? e : new Error(String(e));
|
||||
setSlots((prev) =>
|
||||
prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'error', error: err } : s)),
|
||||
);
|
||||
}
|
||||
pendingRef.current = pendingRef.current.filter((j) => j.jobId !== job.jobId);
|
||||
savePending(pendingRef.current);
|
||||
} else if (status === 'failed') {
|
||||
const err = new Error(error ?? 'Job failed');
|
||||
setSlots((prev) =>
|
||||
prev.map((s, i): SlotState => (i === job.slotIndex ? { status: 'error', error: err } : s)),
|
||||
);
|
||||
pendingRef.current = pendingRef.current.filter((j) => j.jobId !== job.jobId);
|
||||
savePending(pendingRef.current);
|
||||
}
|
||||
} catch {
|
||||
// Network error — keep polling
|
||||
}
|
||||
}),
|
||||
);
|
||||
})();
|
||||
}, POLL_INTERVAL_MS);
|
||||
}, [stopPolling]);
|
||||
|
||||
// Restore pending jobs on mount
|
||||
useEffect(() => {
|
||||
const pending = loadPending();
|
||||
if (pending.length === 0) return stopPolling;
|
||||
|
||||
const maxIndex = pending.reduce((max, j) => Math.max(max, j.slotIndex), 0);
|
||||
setSlots(Array.from({ length: maxIndex + 1 }, () => ({ status: 'pending' as const })));
|
||||
pendingRef.current = pending;
|
||||
startPolling();
|
||||
|
||||
return stopPolling;
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const run = useCallback(
|
||||
(req: ShootRequest, count: number, startSeed?: number) => {
|
||||
stopPolling();
|
||||
pendingRef.current = [];
|
||||
savePending([]);
|
||||
|
||||
const seeds = Array.from({ length: count }, (_, i) =>
|
||||
startSeed !== undefined ? startSeed + i : Math.floor(Math.random() * 2 ** 31),
|
||||
);
|
||||
|
||||
setSlots(seeds.map(() => ({ status: 'pending' as const })));
|
||||
|
||||
void (async () => {
|
||||
const results = await Promise.allSettled(
|
||||
seeds.map((seed, slotIndex) =>
|
||||
submitShootJob(req, seed).then((jobId) => ({ jobId, slotIndex, seed })),
|
||||
),
|
||||
);
|
||||
|
||||
const jobs: PendingJob[] = [];
|
||||
results.forEach((result, slotIndex) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
jobs.push({
|
||||
jobId: result.value.jobId,
|
||||
slotIndex,
|
||||
prompt: req.prompt,
|
||||
model: req.model,
|
||||
maturityRating: req.maturityRating,
|
||||
seed: result.value.seed,
|
||||
});
|
||||
} else {
|
||||
const err =
|
||||
result.reason instanceof Error ? result.reason : new Error(String(result.reason));
|
||||
setSlots((prev) =>
|
||||
prev.map((s, i): SlotState => (i === slotIndex ? { status: 'error', error: err } : s)),
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
if (jobs.length > 0) {
|
||||
pendingRef.current = jobs;
|
||||
savePending(jobs);
|
||||
startPolling();
|
||||
}
|
||||
})();
|
||||
},
|
||||
[stopPolling, startPolling],
|
||||
);
|
||||
|
||||
const clear = useCallback(() => {
|
||||
stopPolling();
|
||||
pendingRef.current = [];
|
||||
savePending([]);
|
||||
setSlots([]);
|
||||
}, [stopPolling]);
|
||||
|
||||
return {
|
||||
slots,
|
||||
isRunning: slots.some((s) => s.status === 'pending'),
|
||||
run,
|
||||
clear,
|
||||
};
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue