diff --git a/src/language.ml b/src/language.ml index 33bd73063a80b3d472aa03f7e4b9fcbbf9945c54..34b21be09adb4860299d9b9d81a7ea55d7dbe409 100644 --- a/src/language.ml +++ b/src/language.ml @@ -8,6 +8,13 @@ open Base (* -- Support for the NNet neural network format. *) +type nnet = { + nb_inputs : int; + nb_outputs : int; + ty_data : Why3.Ty.ty; + filename : string; +} + let loaded_nnets = Why3.Term.Hls.create 10 let lookup_loaded_nnets = Why3.Term.Hls.find_opt loaded_nnets @@ -34,7 +41,13 @@ let nnet_parser env _ filename _ = (List.init header.n_inputs ~f) (Ty.ty_tuple (List.init header.n_outputs ~f)) in - Why3.Term.Hls.add loaded_nnets ls_nnet_apply filename; + Why3.Term.Hls.add loaded_nnets ls_nnet_apply + { + filename; + nb_inputs = header.n_inputs; + nb_outputs = header.n_outputs; + ty_data = nnet_input_type; + }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply)) diff --git a/src/language.mli b/src/language.mli index 5a8c544502e299eca9ebafb2421b9cf94fe6a573..23ef26ce41c38fd194c55736af8953ab786b1a90 100644 --- a/src/language.mli +++ b/src/language.mli @@ -1,4 +1,11 @@ -val lookup_loaded_nnets : Why3.Term.lsymbol -> string option +type nnet = { + nb_inputs : int; + nb_outputs : int; + ty_data : Why3.Ty.ty; + filename : string; +} + +val lookup_loaded_nnets : Why3.Term.lsymbol -> nnet option (** Return the filename of an nnets Why3 representation *) val register_nnet_support : unit -> unit diff --git a/src/transformations.ml b/src/transformations.ml index 94877c2ca39fefb362bc642538e9f4401fb0a074..0067485960b953ab9c75057e7d2407da46544cd6 100644 --- a/src/transformations.ml +++ b/src/transformations.ml @@ -6,7 +6,7 @@ let get_input_variables = | Why3.Term.Tapp (ls, args) -> ( match Language.lookup_loaded_nnets ls with | None -> acc - | Some _name -> + | Some _ -> let add acc = function | { Why3.Term.t_node = Tapp (vs, []); _ } -> Why3.Term.Sls.add vs acc | arg -> @@ -21,11 +21,67 @@ let get_input_variables = (fun decl acc -> Why3.Decl.decl_fold aux acc decl) Why3.Term.Sls.empty -let simplify_goal _input_variables = Why3.Trans.identity +let simplify_goal env _input_variables = + let rec aux hls (term : Why3.Term.term) = + match term.t_node with + | Why3.Term.Tapp (ls, _) -> ( + match Language.lookup_loaded_nnets ls with + | None -> Why3.Term.t_map (aux hls) term + | Some nnet -> + let outputs = + List.init nnet.nb_outputs ~f:(fun _ -> + let open Why3 in + let id = Ident.id_fresh "y" in + let ls = Term.create_fsymbol id [] nnet.ty_data in + hls := Why3.Decl.create_param_decl ls :: !hls; + Term.fs_app ls [] nnet.ty_data) + in + Why3.Term.t_tuple outputs) + | _ -> Why3.Term.t_map (aux hls) term + in + Why3.Trans.fold + (fun task_hd acc -> + match task_hd.task_decl.td_node with + | Use _ | Clone _ | Meta _ -> Why3.Task.add_tdecl acc task_hd.task_decl + | Decl decl -> + let hls = ref [] in + let map term = + let term = aux hls term in + if List.is_empty !hls + then term + else + let known = + List.fold (List.rev !hls) ~init:task_hd.task_known + ~f:Why3.Decl.known_add_decl + in + let engine = + Why3.Reduction_engine.create + { + compute_defs = false; + compute_builtin = true; + compute_def_set = Why3.Term.Sls.empty; + } + env known + in + Why3.Reduction_engine.normalize ~limit:100 engine + Why3.Term.Mvs.empty term + in + let decl = Why3.Decl.decl_map map decl in + let acc = + List.fold (List.rev !hls) ~init:acc ~f:(fun acc ls -> + Why3.Task.add_decl acc ls) + in + Why3.Task.add_decl acc decl) + None -let caisar_native_prover = Why3.Trans.bind get_input_variables simplify_goal +let caisar_native_prover env = + Why3.Trans.seq + [ + Why3.Trans.bind get_input_variables (simplify_goal env); + (* Why3.Simplify_formula.simplify_; *) + ] let init () = - Why3.Trans.register_transform + Why3.Trans.register_env_transform ~desc:"Transformation for provers that support loading neural networks." "caisar_native_prover" caisar_native_prover diff --git a/tests/simple.t b/tests/simple.t index b608821eed730a611d876e8e6d9c027290516a8a..c8b5c220e5bccac74d6f1db27859d31285ec3235 100644 --- a/tests/simple.t +++ b/tests/simple.t @@ -1172,8 +1172,16 @@ Test verify axiom H1 [@introduced] : lt x1 0.5 - goal G : match nnet_apply x1 x2 x3 x4 x5 with - | Tuple5 y1 _ _ _ _ -> lt 0.0 y1 /\ lt y1 0.5 - end + function y : t19 + + function y1 : t19 + + function y2 : t19 + + function y3 : t19 + + function y4 : t19 + + goal G : lt 0.0 y4 /\ lt y4 0.5 end