From e4fd53c328ef59f062d5a55c89db3883d9151dce Mon Sep 17 00:00:00 2001 From: Michele Alberti <michele.alberti@cea.fr> Date: Thu, 25 May 2023 16:10:50 +0200 Subject: [PATCH] [language] Rework neural networks interface. --- src/interpretation.ml | 4 ++-- src/language.ml | 16 ++++++++-------- src/language.mli | 8 ++++---- src/transformations/native_nn_prover.ml | 4 ++-- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/interpretation.ml b/src/interpretation.ml index b0a68d5..e5dd6c9 100644 --- a/src/interpretation.ml +++ b/src/interpretation.ml @@ -253,8 +253,8 @@ let caisar_builtins : caisar_env CRE.built_in_theories list = let filename = Caml.Filename.concat cwd neural_network in let nn = match id_string with - | "NNet" -> NNet (Language.create_nnet_nn env filename) - | "ONNX" -> ONNX (Language.create_onnx_nn env filename) + | "NNet" -> NNet (Language.create_nn_nnet env filename) + | "ONNX" -> ONNX (Language.create_nn_onnx env filename) | _ -> failwith (Fmt.str "Unrecognized neural network format %s" id_string) in diff --git a/src/language.ml b/src/language.ml index d525b1f..39afa6c 100644 --- a/src/language.ml +++ b/src/language.ml @@ -186,8 +186,8 @@ let mem_vector = Term.Hls.mem vectors (* -- Classifier *) type nn = { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; [@printer fun fmt ty -> Fmt.pf fmt "%a" Pretty.print_ty ty] nn_filename : string; nn_nier : Onnx.G.t option; [@opaque] @@ -204,7 +204,7 @@ let fresh_nn_ls env name = let id = Ident.id_fresh name in Term.create_fsymbol id [] ty -let create_nnet_nn = +let create_nn_nnet = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in let ty_elt = @@ -219,8 +219,8 @@ let create_nnet_nn = | Error s -> Loc.errorm "%s" s | Ok { n_inputs; n_outputs; _ } -> { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = None; @@ -229,7 +229,7 @@ let create_nnet_nn = Term.Hls.add nets ls nn; ls)) -let create_onnx_nn = +let create_nn_onnx = Env.Wenv.memoize 13 (fun env -> let h = Hashtbl.create (module String) in let ty_elt = vector_elt_ty env in @@ -249,8 +249,8 @@ let create_onnx_nn = | Ok nier -> Some nier in { - nn_inputs = n_inputs; - nn_outputs = n_outputs; + nn_nb_inputs = n_inputs; + nn_nb_outputs = n_outputs; nn_ty_elt = ty_elt; nn_filename = filename; nn_nier = nier; diff --git a/src/language.mli b/src/language.mli index 3804657..579255e 100644 --- a/src/language.mli +++ b/src/language.mli @@ -72,15 +72,15 @@ val mem_vector : Term.lsymbol -> bool (** -- Neural Network *) type nn = private { - nn_inputs : int; - nn_outputs : int; + nn_nb_inputs : int; + nn_nb_outputs : int; nn_ty_elt : Ty.ty; nn_filename : string; nn_nier : Onnx.G.t option; } [@@deriving show] -val create_nnet_nn : Env.env -> string -> Term.lsymbol -val create_onnx_nn : Env.env -> string -> Term.lsymbol +val create_nn_nnet : Env.env -> string -> Term.lsymbol +val create_nn_onnx : Env.env -> string -> Term.lsymbol val lookup_nn : Term.lsymbol -> nn option val mem_nn : Term.lsymbol -> bool diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 8a3816f..41b4568 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -37,8 +37,8 @@ let get_input_variables = [ { t_node = Tapp (ls1, _); _ }; { t_node = Tapp (ls2, args); _ } ] ) when String.equal ls_name.id_string (Ident.op_infix "@@") -> ( match (Language.lookup_nn ls1, Language.lookup_vector ls2) with - | Some { nn_inputs; _ }, Some n -> - assert (nn_inputs = n && n = List.length args); + | Some { nn_nb_inputs; _ }, Some n -> + assert (nn_nb_inputs = n && n = List.length args); List.foldi ~init:acc ~f:add args | _ -> acc) | _ -> Term.t_fold aux acc term -- GitLab