mirror of
https://github.com/blakeblackshear/frigate.git
synced 2025-09-26 19:41:29 +08:00
Live classification model training (#18583)
* Implement model training via ZMQ and add model states to represent training * Get model updates working * Improve toasts and model state * Clean up logging * Add back in
This commit is contained in:

committed by
Blake Blackshear

parent
1c75ff59f1
commit
765a28d812
@@ -9,12 +9,15 @@
|
||||
"success": {
|
||||
"deletedCategory": "Deleted Class",
|
||||
"deletedImage": "Deleted Images",
|
||||
"categorizedImage": "Successfully Classified Image"
|
||||
"categorizedImage": "Successfully Classified Image",
|
||||
"trainedModel": "Successfully trained model.",
|
||||
"trainingModel": "Successfully started model training."
|
||||
},
|
||||
"error": {
|
||||
"deleteImageFailed": "Failed to delete: {{errorMessage}}",
|
||||
"deleteCategoryFailed": "Failed to delete class: {{errorMessage}}",
|
||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}"
|
||||
"categorizeFailed": "Failed to categorize image: {{errorMessage}}",
|
||||
"trainingFailed": "Failed to start model training: {{errorMessage}}"
|
||||
}
|
||||
},
|
||||
"deleteCategory": {
|
||||
|
@@ -73,7 +73,9 @@ export type ModelState =
|
||||
| "not_downloaded"
|
||||
| "downloading"
|
||||
| "downloaded"
|
||||
| "error";
|
||||
| "error"
|
||||
| "training"
|
||||
| "complete";
|
||||
|
||||
export type EmbeddingsReindexProgressType = {
|
||||
thumbnails: number;
|
||||
|
@@ -45,6 +45,9 @@ import { toast } from "sonner";
|
||||
import useSWR from "swr";
|
||||
import ClassificationSelectionDialog from "@/components/overlay/ClassificationSelectionDialog";
|
||||
import { TbCategoryPlus } from "react-icons/tb";
|
||||
import { useModelState } from "@/api/ws";
|
||||
import { ModelState } from "@/types/ws";
|
||||
import ActivityIndicator from "@/components/indicators/activity-indicator";
|
||||
|
||||
type ModelTrainingViewProps = {
|
||||
model: CustomClassificationModelConfig;
|
||||
@@ -54,6 +57,33 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
const [page, setPage] = useState<string>("train");
|
||||
const [pageToggle, setPageToggle] = useOptimisticState(page, setPage, 100);
|
||||
|
||||
// model state
|
||||
|
||||
const [wasTraining, setWasTraining] = useState(false);
|
||||
const { payload: lastModelState } = useModelState(model.name, true);
|
||||
const modelState = useMemo<ModelState>(() => {
|
||||
if (!lastModelState || lastModelState == "downloaded") {
|
||||
return "complete";
|
||||
}
|
||||
|
||||
return lastModelState;
|
||||
}, [lastModelState]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!wasTraining) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (modelState == "complete") {
|
||||
toast.success(t("toast.success.trainedModel"), {
|
||||
position: "top-center",
|
||||
});
|
||||
setWasTraining(false);
|
||||
}
|
||||
// only refresh when modelState changes
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [modelState]);
|
||||
|
||||
// dataset
|
||||
|
||||
const { data: trainImages, mutate: refreshTrain } = useSWR<string[]>(
|
||||
@@ -101,8 +131,27 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
// actions
|
||||
|
||||
const trainModel = useCallback(() => {
|
||||
axios.post(`classification/${model.name}/train`);
|
||||
}, [model]);
|
||||
axios
|
||||
.post(`classification/${model.name}/train`)
|
||||
.then((resp) => {
|
||||
if (resp.status == 200) {
|
||||
setWasTraining(true);
|
||||
toast.success(t("toast.success.trainingModel"), {
|
||||
position: "top-center",
|
||||
});
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
const errorMessage =
|
||||
error.response?.data?.message ||
|
||||
error.response?.data?.detail ||
|
||||
"Unknown error";
|
||||
|
||||
toast.error(t("toast.error.trainingFailed", { errorMessage }), {
|
||||
position: "top-center",
|
||||
});
|
||||
});
|
||||
}, [model, t]);
|
||||
|
||||
const [deleteDialogOpen, setDeleteDialogOpen] = useState<string[] | null>(
|
||||
null,
|
||||
@@ -274,7 +323,14 @@ export default function ModelTrainingView({ model }: ModelTrainingViewProps) {
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<Button onClick={trainModel}>Train Model</Button>
|
||||
<Button
|
||||
className="flex justify-center gap-2"
|
||||
onClick={trainModel}
|
||||
disabled={modelState != "complete"}
|
||||
>
|
||||
Train Model
|
||||
{modelState == "training" && <ActivityIndicator size={20} />}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{pageToggle == "train" ? (
|
||||
|
Reference in New Issue
Block a user