diff --git a/src/verification.ml b/src/verification.ml index a0ed2f5a3de229708ffdb9591da41d32124da38d..b38e5c7fc68b37cc34b0ad8b6f93f0ef90b5094f 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -8,6 +8,7 @@ open Base module Filename = Caml.Filename let () = Language.register_nnet_support () + let () = Language.register_onnx_support () let create_env loadpath = @@ -29,6 +30,19 @@ let create_env loadpath = let nnet_or_onnx = Re.compile (Re.str "%{nnet-onnx}") +let combine_prover_answers answers = + let open Why3 in + List.fold_left answers ~init:Call_provers.Valid ~f:(fun r1 l2 -> + let r2 = + List.fold_left l2 ~init:Call_provers.Invalid ~f:(fun r1 r2 -> + match (r1, r2) with + | Call_provers.Valid, _ | _, Call_provers.Valid -> Call_provers.Valid + | _ -> r2) + in + match (r1, r2) with + | Call_provers.Valid, r | r, Call_provers.Valid -> r + | _ -> r1) + let call_prover ~limit (prover : Why3.Whyconf.config_prover) driver task = let open Why3 in let task_prepared = Driver.prepare_task driver task in @@ -36,10 +50,7 @@ let call_prover ~limit (prover : Why3.Whyconf.config_prover) driver task = if String.equal prover.prover.prover_name "Marabou" then let conjs = Trans.apply Split_goal.split_goal_full task_prepared in - let disjs = - List.map ~f:(Trans.apply Split_disjunction.split_disjunction) conjs - in - disjs + List.map ~f:(Trans.apply Split_disjunction.split_disjunction) conjs else [ [ task_prepared ] ] in let command = Whyconf.get_complete_command ~with_steps:false prover in @@ -50,26 +61,15 @@ let call_prover ~limit (prover : Why3.Whyconf.config_prover) driver task = | None -> invalid_arg (Fmt.str "No neural network model found in task") in let command = Re.replace_string nnet_or_onnx ~by:nn_file command in - let call_task task_prepared = + let call_prover_on_task task_prepared = let prover_call = Driver.prove_task_prepared ~command ~limit driver task_prepared in let prover_result = Call_provers.wait_on_call prover_call in prover_result.pr_answer in - let results = List.map tasks ~f:(List.map ~f:call_task) in - let answer = - List.fold_left results ~init:Call_provers.Valid ~f:(fun r1 l2 -> - let r2 = - List.fold_left l2 ~init:Call_provers.Invalid ~f:(fun r1 r2 -> - match (r1, r2) with - | Call_provers.Valid, _ | _, Call_provers.Valid -> Call_provers.Valid - | _ -> r2) - in - match (r1, r2) with - | Call_provers.Valid, r | r, Call_provers.Valid -> r - | _ -> r1) - in + let answers = List.map tasks ~f:(List.map ~f:call_prover_on_task) in + let answer = combine_prover_answers answers in Fmt.pr "Goal %a: %a@." Pretty.print_pr (Task.task_goal task) Call_provers.print_prover_answer answer