diff --git a/src/interpretation.ml b/src/interpretation.ml index b0a68d578b6755b7bf6494dab693d91bba0e6012..e5dd6c9add0e7b76e3a2d31abf3420d0778c74fe 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -253,8 +253,8 @@ let caisar_builtins : caisar_env CRE.built_in_theories list = let filename = Caml.Filename.concat cwd neural_network in let nn = match id_string with - | "NNet" -> NNet (Language.create_nnet_nn env filename) - | "ONNX" -> ONNX (Language.create_onnx_nn env filename) + | "NNet" -> NNet (Language.create_nn_nnet env filename) + | "ONNX" -> ONNX (Language.create_nn_onnx env filename) | _ -> failwith (Fmt.str "Unrecognized neural network format %s" id_string) in diff --git a/src/language.ml b/src/language.ml index d525b1f8602c0e17e9327574c3a16d7e0fc244e3..39afa6c1b39a9ca36d8dabf0a0db9ffee5cbafa3 100644 --- a/src/language.ml +++ b/src/language.ml @@ -186,8 +186,8 @@ let mem_vector = Term.Hls.mem vectors (* -- Classifier *) type nn = { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty] nn_filename : string; nn_nier : Onnx.G.t option; [@opaque] @@ -204,7 +204,7 @@ let fresh_nn_ls env name = let id = Ident.id_fresh name in Term.create_fsymbol id [] ty -let create_nnet_nn = +let create_nn_nnet = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in let ty_elt = @@ -219,8 +219,8 @@ let create_nnet_nn = | Error s -> Loc.errorm "%s" s | Ok { n_inputs; n_outputs; _ } -> { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = None; @@ -229,7 +229,7 @@ let create_nnet_nn = Term.Hls.add nets ls nn; ls)) -let create_onnx_nn = +let create_nn_onnx = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in let ty_elt = vector_elt_ty env in @@ -249,8 +249,8 @@ let create_onnx_nn = | Ok nier -> Some nier in { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = nier; diff --git a/src/language.mli b/src/language.mli index 380465783b7c74eef609ec0f715de98c68dad43c..579255e573ea75eb655c249fd11ca95eaadd6cd2 100644 --- a/src/language.mli +++ b/src/language.mli @@ -72,15 +72,15 @@ val mem_vector : Term.lsymbol -> bool (** -- Neural Network *) type nn = private { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; nn_filename : string; nn_nier : Onnx.G.t option; } [@@deriving show] -val create_nnet_nn : Env.env -> string -> Term.lsymbol -val create_onnx_nn : Env.env -> string -> Term.lsymbol +val create_nn_nnet : Env.env -> string -> Term.lsymbol +val create_nn_onnx : Env.env -> string -> Term.lsymbol val lookup_nn : Term.lsymbol -> nn option val mem_nn : Term.lsymbol -> bool diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 8a3816f24ea0371dc16de1d6815f97ccb67fe5c5..41b4568f377988f848a3680977568e0d82edfd74 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -37,8 +37,8 @@ let get_input_variables = [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) when String.equal ls_name.id_string (Ident.op_infix "@@") -> ( match (Language.lookup_nn ls1, Language.lookup_vector ls2) with - | Some { nn_inputs; _ }, Some n -> - assert (nn_inputs = n && n = List.length args); + | Some { nn_nb_inputs; _ }, Some n -> + assert (nn_nb_inputs = n && n = List.length args); List.foldi ~init:acc ~f:add args | _ -> acc) | _ -> Term.t_fold aux acc term