diff --git a/src/proof_strategy.ml b/src/proof_strategy.ml index ee690421141a1a2f910ef1772862bd122737ece2..1acd424808fd1da9efbd4eb22708b341d87a19d1 100644 --- a/src/proof_strategy.ml +++ b/src/proof_strategy.ml @@ -50,14 +50,11 @@ let apply_classic_prover env task = let trans = Nn2smt.trans env in do_apply_prover ~lookup ~trans [ task ] -let apply_native_nn_prover env task = +let apply_native_nn_prover task = let lookup = Language.lookup_nn in let trans = Trans.seq - [ - Introduction.introduce_premises; - Native_nn_prover.trans_nn_application env; - ] + [ Introduction.introduce_premises; Native_nn_prover.trans_nn_application ] in let tasks = Trans.apply Split_goal.split_goal_full task in do_apply_prover ~lookup ~trans tasks diff --git a/src/proof_strategy.mli b/src/proof_strategy.mli index 1ec1f823061f8bf766216d777309cc68601a5888..3d9717be700498eaec69960f83488e035d407d0f 100644 --- a/src/proof_strategy.mli +++ b/src/proof_strategy.mli @@ -25,5 +25,5 @@ open Why3 val apply_classic_prover : Env.env -> Task.task -> Task.task list (** Detect and translate applications of neural networks into SMT-LIB. *) -val apply_native_nn_prover : Env.env -> Task.task -> Task.task list +val apply_native_nn_prover : Task.task -> Task.task list (** Detect and execute applications of neural networks. *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 41b4568f377988f848a3680977568e0d82edfd74..b28e8e8a40c398712859db1ea3b9d15534306d2b 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -23,115 +23,131 @@ open Why3 open Base -let get_input_variables = - let add i acc = function - | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls i acc - | arg -> - invalid_arg - (Fmt.str "No direct variable in application: %a" Pretty.print_term arg) +let collect_input_vars = + let hls = Term.Hls.create 13 in + let add index mls = function + | { Term.t_node = Tapp (ls, []); _ } -> Term.Mls.add ls index mls + | t -> failwith (Fmt.str "No input variable: %a" Pretty.print_term t) in - let rec aux acc (term : Term.term) = + let rec do_collect mls (term : Term.term) = match term.t_node with | Term.Tapp - ( { ls_name; _ }, - [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) - when String.equal ls_name.id_string (Ident.op_infix "@@") -> ( - match (Language.lookup_nn ls1, Language.lookup_vector ls2) with - | Some { nn_nb_inputs; _ }, Some n -> - assert (nn_nb_inputs = n && n = List.length args); - List.foldi ~init:acc ~f:add args - | _ -> acc) - | _ -> Term.t_fold aux acc term + ( ls1 (* @@ *), + [ + { t_node = Tapp (ls2 (* nn *), _); _ }; + { t_node = Tapp (ls3 (* input vector *), tl (* input vars *)); _ }; + ] ) + when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> ( + match (Language.lookup_nn ls2, Language.lookup_vector ls3) with + | Some { nn_nb_inputs; _ }, Some vector_length -> ( + assert (nn_nb_inputs = vector_length && vector_length = List.length tl); + match Term.Hls.find_opt hls ls3 with + | None -> List.foldi ~init:mls ~f:add tl + | Some _ -> mls) + | _, _ -> mls) + | _ -> Term.t_fold do_collect mls term + in + Trans.fold_decl + (fun decl mls -> Decl.decl_fold do_collect mls decl) + Term.Mls.empty + +let create_output_vars = + let rec do_create mt (term : Term.term) = + match term.t_node with + | Term.Tapp (ls1 (* @@ *), [ { t_node = Tapp (ls2 (* nn *), _); _ }; _ ]) + when String.equal ls1.ls_name.id_string (Ident.op_infix "@@") -> ( + match Language.lookup_nn ls2 with + | Some { nn_nb_outputs; nn_ty_elt; _ } -> ( + match Term.Mterm.find_opt term mt with + | None -> + let output_vars = + List.init nn_nb_outputs ~f:(fun index -> + (index, Term.create_fsymbol (Ident.id_fresh "y") [] nn_ty_elt)) + in + Term.Mterm.add term output_vars mt + | Some _ -> mt) + | _ -> mt) + | _ -> Term.t_fold do_create mt term in - Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty + Trans.fold_decl + (fun decl mt -> Decl.decl_fold do_create mt decl) + Term.Mterm.empty -(* Create logic symbols for output variables and simplify the formula. *) -let simplify_goal _env input_variables = - let rec aux hls (term : Term.term) = +let simplify_nn_application input_vars output_vars = + let rec do_simplify nn_filenames (term : Term.term) = match term.t_node with | Term.Tapp - ( ls_vget, + ( ls1 (* [_] *), [ ({ t_node = Tapp - ( ls_apply_nn, + ( ls2 (* @@ *), [ - { t_node = Tapp (ls_nn, _); _ }; - { t_node = Tapp (ls_vector, _); _ }; + { t_node = Tapp (ls3 (* nn *), _); _ }; + _ (* input vector *); ] ); _; - } as _t1); - ({ t_node = Tconst (ConstInt i); _ } as _t2); + } as t1); + ({ t_node = Tconst (ConstInt index); _ } as _t2); ] ) - when String.equal ls_vget.ls_name.id_string (Ident.op_get "") - && String.equal ls_apply_nn.ls_name.id_string (Ident.op_infix "@@") - -> ( - match (Language.lookup_nn ls_nn, Language.lookup_vector ls_vector) with - | Some nn, Some _ -> - let index = Number.to_small_integer i in - let hout = - Hashtbl.update_and_return hls nn.nn_filename ~f:(fun hout -> - let create_ls_output () = - let id = Ident.id_fresh "y" in - Term.create_fsymbol id [] nn.nn_ty_elt - in - match hout with - | None -> - let hout = Hashtbl.create (module Int) in - let ls = create_ls_output () in - Hashtbl.add_exn hout ~key:index ~data:ls; - hout - | Some hout -> - Hashtbl.update hout index ~f:(fun lsout -> - match lsout with - | None -> - let ls = create_ls_output () in - Hashtbl.add_exn hout ~key:index ~data:ls; - ls - | Some ls -> ls); - hout) - in - let ls_output = Hashtbl.find_exn hout index in - Term.fs_app ls_output [] nn.nn_ty_elt - | _ -> Term.t_map (aux hls) term) - | _ -> Term.t_map (aux hls) term + when String.equal ls1.ls_name.id_string (Ident.op_get "") + && String.equal ls2.ls_name.id_string (Ident.op_infix "@@") -> ( + match Term.Mterm.find_opt t1 output_vars with + | None -> Term.t_map (do_simplify nn_filenames) term + | Some output_vars -> + let nn = Option.value_exn (Language.lookup_nn ls3) in + nn_filenames := nn.nn_filename :: !nn_filenames; + let index = Number.to_small_integer index in + assert (index <= nn.Language.nn_nb_outputs); + let ls = Caml.List.assoc index output_vars in + Term.fs_app ls [] nn.nn_ty_elt) + | _ -> Term.t_map (do_simplify nn_filenames) term in - let htbl = Hashtbl.create (module String) in Trans.fold - (fun task_hd acc -> + (fun task_hd task -> match task_hd.task_decl.td_node with - | Use _ | Clone _ | Meta _ -> Task.add_tdecl acc task_hd.task_decl | Decl { d_node = Dparam ls; _ } -> - (* Add the meta first, then the actual input variable declaration. - - This is mandatory for allowing some printers to properly work. Such + (* Add the meta first, then the actual input variable declaration. This + is mandatory for allowing some printers to properly work. Such printers need to collect metas and print the actual declarations if a corresponding meta has been already collected. Hence metas must appear before corresponding declarations. *) - let acc = - match Term.Mls.find_opt ls input_variables with - | None -> acc - | Some pos -> - Task.add_meta acc Utils.meta_input [ MAls ls; MAint pos ] + let task = + match Term.Mls.find_opt ls input_vars with + | None -> task + | Some index -> + Task.add_meta task Utils.meta_input [ MAls ls; MAint index ] + in + Task.add_tdecl task task_hd.task_decl + | Decl ({ d_node = Dprop (Pgoal, _, _); _ } as decl) -> + let nn_filename = ref [] in + let decl = Decl.decl_map (fun t -> do_simplify nn_filename t) decl in + let task = + List.fold !nn_filename ~init:task ~f:(fun task nn_filename -> + Task.add_meta task Utils.meta_nn_filename [ MAstr nn_filename ]) in - Task.add_tdecl acc task_hd.task_decl - | Decl decl -> - let decl = Decl.decl_map (fun term -> aux htbl term) decl in - let acc = - Hashtbl.fold htbl ~init:acc ~f:(fun ~key ~data acc -> - let acc = Task.add_meta acc Utils.meta_nn_filename [ MAstr key ] in - Hashtbl.fold data ~init:acc ~f:(fun ~key ~data acc -> + let task = + Term.Mterm.fold + (fun _t output_vars task -> (* Add the meta first, then the actual output variable declaration. Same reason as for input variable declarations. *) - let acc = - Task.add_meta acc Utils.meta_output [ MAls data; MAint key ] - in - let decl = Decl.create_param_decl data in - Task.add_decl acc decl)) + List.fold output_vars ~init:task + ~f:(fun task (index, output_var) -> + let task = + Task.add_meta task Utils.meta_output + [ MAls output_var; MAint index ] + in + let decl = Decl.create_param_decl output_var in + Task.add_decl task decl)) + output_vars task in - Task.add_decl acc decl) + Task.add_decl task decl + | Use _ | Clone _ | Meta _ | Decl _ -> + Task.add_tdecl task task_hd.task_decl) None -let trans_nn_application env = - Trans.bind get_input_variables (simplify_goal env) +let trans_nn_application = + Trans.bind collect_input_vars (fun input_vars -> + Trans.bind create_output_vars (fun output_vars -> + simplify_nn_application input_vars output_vars)) diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index 82cc9c71c748345fbdaa1ec8f40c6f9be708fc3c..090612fbdd8423b21677a1268c08ce7398cff807 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,4 +20,4 @@ (* *) (**************************************************************************) -val trans_nn_application : Why3.Env.env -> Why3.Task.task Why3.Trans.trans +val trans_nn_application : Why3.Task.task Why3.Trans.trans diff --git a/src/verification.ml b/src/verification.ml index dda279e25450552e858d16965c5abb978f978d90..ce6f2f785397ca3674d272bdd8b631c13fc366a6 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -223,9 +223,9 @@ let answer_dataset limit config env prover config_prover driver dataset task = in (prover_answer, additional_info) -let answer_generic limit config env prover config_prover driver ~proof_strategy - task = - let tasks = proof_strategy env task in +let answer_generic limit config prover config_prover driver ~proof_strategy task + = + let tasks = proof_strategy task in let answers = List.concat_map tasks ~f:(fun task -> let task = Driver.prepare_task driver task in @@ -265,13 +265,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset task | Marabou | Pyrat | Nnenum | ABCrown -> let task = Interpretation.interpret_task ~cwd env task in let proof_strategy = Proof_strategy.apply_native_nn_prover in - answer_generic limit config env prover config_prover driver - ~proof_strategy task + answer_generic limit config prover config_prover driver ~proof_strategy + task | CVC5 -> let task = Interpretation.interpret_task ~cwd env task in - let proof_strategy = Proof_strategy.apply_classic_prover in - answer_generic limit config env prover config_prover driver - ~proof_strategy task + let proof_strategy = Proof_strategy.apply_classic_prover env in + answer_generic limit config prover config_prover driver ~proof_strategy + task in let id = Task.task_goal task in { id; prover_answer; additional_info }