Live classification model training (#18583)

* Implement model training via ZMQ and add model states to represent training

* Get model updates working

* Improve toasts and model state

* Clean up logging

* Add back in
This commit is contained in:
Nicolas Mowen
2025-06-05 09:13:12 -06:00
committed by Blake Blackshear
parent 1c75ff59f1
commit 765a28d812
10 changed files with 168 additions and 22 deletions

View File

@@ -9,12 +9,15 @@
"success": {
"deletedCategory": "Deleted Class",
"deletedImage": "Deleted Images",
"categorizedImage": "Successfully Classified Image"
"categorizedImage": "Successfully Classified Image",
"trainedModel": "Successfully trained model.",
"trainingModel": "Successfully started model training."
},
"error": {
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
"trainingFailed": "Failed to start model training: {{errorMessage}}"
}
},
"deleteCategory": {

View File

@@ -73,7 +73,9 @@ export type ModelState =
| "not_downloaded"
| "downloading"
| "downloaded"
| "error";
| "error"
| "training"
| "complete";
export type EmbeddingsReindexProgressType = {
thumbnails: number;

View File

@@ -45,6 +45,9 @@ import { toast } from "sonner";
import useSWR from "swr";
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
import { TbCategoryPlus } from "react-icons/tb";
import { useModelState } from "@/api/ws";
import { ModelState } from "@/types/ws";
import ActivityIndicator from "@/components/indicators/activity-indicator";
type ModelTrainingViewProps = {
model: CustomClassificationModelConfig;
@@ -54,6 +57,33 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
const [page, setPage] = useState<string>("train");
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
// model state
const [wasTraining, setWasTraining] = useState(false);
const { payload: lastModelState } = useModelState(model.name, true);
const modelState = useMemo<ModelState>(() => {
if (!lastModelState || lastModelState == "downloaded") {
return "complete";
}
return lastModelState;
}, [lastModelState]);
useEffect(() => {
if (!wasTraining) {
return;
}
if (modelState == "complete") {
toast.success(t("toast.success.trainedModel"), {
position: "top-center",
});
setWasTraining(false);
}
// only refresh when modelState changes
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [modelState]);
// dataset
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
@@ -101,8 +131,27 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
// actions
const trainModel = useCallback(() => {
axios.post(`classification/${model.name}/train`);
}, [model]);
axios
.post(`classification/${model.name}/train`)
.then((resp) => {
if (resp.status == 200) {
setWasTraining(true);
toast.success(t("toast.success.trainingModel"), {
position: "top-center",
});
}
})
.catch((error) => {
const errorMessage =
error.response?.data?.message ||
error.response?.data?.detail ||
"Unknown error";
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
position: "top-center",
});
});
}, [model, t]);
const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
null,
@@ -274,7 +323,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
</Button>
</div>
) : (
<Button onClick={trainModel}>Train Model</Button>
<Button
className="flex justify-center gap-2"
onClick={trainModel}
disabled={modelState != "complete"}
>
Train Model
{modelState == "training" && <ActivityIndicator size={20} />}
</Button>
)}
</div>
{pageToggle == "train" ? (