diff --git a/src/dataset.ml b/src/dataset.ml index c90c516af8141c53b936a6bfb730226f711ef47d..6ce481846cd19043eea6f79a4265e965fe684d8c 100644 --- a/src/dataset.ml +++ b/src/dataset.ml @@ -52,30 +52,34 @@ let interpret_normalization = (fun decl acc -> match decl.d_node with | Dlogic ldecl -> - List.fold ldecl ~init:acc ~f:(fun acc (_, ls_defn) -> + List.fold ldecl ~init:acc ~f:(fun acc (ls, ls_defn) -> let _, term = Decl.open_ls_defn ls_defn in - match term.t_node with - | Term.Tapp - ( { ls_name = { id_string = "min_max_scale"; _ }; _ }, - [ - { t_node = Tconst (ConstReal rc_min); _ }; - { t_node = Tconst (ConstReal rc_max); _ }; - { t_node = Tapp (_dataset, []); _ }; - ] ) -> - let rc_min = float_of_real_constant rc_min in - let rc_max = float_of_real_constant rc_max in - MinMax (rc_min, rc_max) :: acc - | Term.Tapp - ( { ls_name = { id_string = "z_norm"; _ }; _ }, - [ - { t_node = Tconst (ConstReal mean); _ }; - { t_node = Tconst (ConstReal std_dev); _ }; - { t_node = Tapp (_dataset, []); _ }; - ] ) -> - let mean = float_of_real_constant mean in - let std_dev = float_of_real_constant std_dev in - Znorm (mean, std_dev) :: acc - | _ -> acc) + let normalization = + match term.t_node with + | Term.Tapp + ( { ls_name = { id_string = "min_max_scale"; _ }; _ }, + [ + { t_node = Tconst (ConstReal rc_min); _ }; + { t_node = Tconst (ConstReal rc_max); _ }; + { t_node = Tapp (_dataset, []); _ }; + ] ) -> + let rc_min = float_of_real_constant rc_min in + let rc_max = float_of_real_constant rc_max in + Some (MinMax (rc_min, rc_max)) + | Term.Tapp + ( { ls_name = { id_string = "z_norm"; _ }; _ }, + [ + { t_node = Tconst (ConstReal mean); _ }; + { t_node = Tconst (ConstReal std_dev); _ }; + { t_node = Tapp (_dataset, []); _ }; + ] ) -> + let mean = float_of_real_constant mean in + let std_dev = float_of_real_constant std_dev in + Some (Znorm (mean, std_dev)) + | _ -> None + in + Option.value_map normalization ~default:acc ~f:(fun n -> + (ls, n) :: acc)) | _ -> acc) [] @@ -112,7 +116,7 @@ let failwith_unsupported_ls ls = let interpret_predicate env ~on_model ~on_dataset task = let task = Trans.apply Introduction.introduce_premises task in - let normalization = Trans.apply interpret_normalization task in + let ls_with_normalization = Trans.apply interpret_normalization task in let term = Task.task_goal_fmla task in match term.t_node with | Term.Tapp @@ -137,6 +141,10 @@ let interpret_predicate env ~on_model ~on_dataset task = else failwith_unsupported_ls ls | _ -> failwith_unsupported_term term in + let normalization = + List.filter_map ls_with_normalization ~f:(fun (ls, normalization) -> + if Term.ls_equal ls dataset then Some normalization else None) + in let dataset = on_dataset dataset in let model = on_model model in { model; dataset; normalization; property }