From d752c75e416f297a0e1d08586502ad15fc9b4d03 Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Fri, 26 Nov 2021 15:31:34 +0100 Subject: [PATCH] Simplified languages registration --- caisar.opam | 5 ++-- nnet.opam | 13 ++-------- onnx.opam | 8 ------ src/language.ml | 66 +++++++++++++++++------------------------------ src/language.mli | 4 +-- stdlib/caisar.mlw | 2 +- 6 files changed, 31 insertions(+), 67 deletions(-) diff --git a/caisar.opam b/caisar.opam index e3397717..403bae2e 100644 --- a/caisar.opam +++ b/caisar.opam @@ -25,7 +25,7 @@ depends: [ "odoc" {with-doc} ] build: [ - ["dune" "subst" "--root" "."] {dev} + ["dune" "subst"] {dev} [ "dune" "build" @@ -33,8 +33,7 @@ build: [ name "-j" jobs - "--promote-install-files" - "false" + "--promote-install-files=false" "@install" "@runtest" {with-test} "@doc" {with-doc} diff --git a/nnet.opam b/nnet.opam index dbc488d7..aa2ffd8b 100644 --- a/nnet.opam +++ b/nnet.opam @@ -2,14 +2,6 @@ opam-version: "2.0" version: "0.1" synopsis: "NNet parser" -maintainer: ["Michele Alberti" "François Bobot" "Julien Girard-Satabin"] -authors: [ - "Michele Alberti" - "François Bobot" - "Julien Girard-Satabin" - "Zakaria Chihani" -] -bug-reports: "julien.girard2@cea.fr" depends: [ "ocaml" {>= "4.10"} "dune" {>= "2.9" & >= "2.9.1"} @@ -17,7 +9,7 @@ depends: [ "odoc" {with-doc} ] build: [ - ["dune" "subst" "--root" "."] {dev} + ["dune" "subst"] {dev} [ "dune" "build" @@ -25,8 +17,7 @@ build: [ name "-j" jobs - "--promote-install-files" - "false" + "--promote-install-files=false" "@install" "@runtest" {with-test} "@doc" {with-doc} diff --git a/onnx.opam b/onnx.opam index a37a869e..dd555813 100644 --- a/onnx.opam +++ b/onnx.opam @@ -2,14 +2,6 @@ opam-version: "2.0" version: "0.1" synopsis: "ONNX parser" -maintainer: ["Michele Alberti" "François Bobot" "Julien Girard-Satabin"] -authors: [ - "Michele Alberti" - "François Bobot" - "Julien Girard-Satabin" - "Zakaria Chihani" -] -bug-reports: "julien.girard2@cea.fr" depends: [ "ocaml" {>= "4.10"} "dune" {>= "2.9" & >= "2.9.1"} diff --git a/src/language.ml b/src/language.ml index f774f546..1b8a2773 100644 --- a/src/language.ml +++ b/src/language.ml @@ -8,7 +8,7 @@ open Base (* -- Support for the NNet and ONNX neural network format. *) -type net = { +type ioshape = { nb_inputs : int; nb_outputs : int; ty_data : Why3.Ty.ty; @@ -19,13 +19,30 @@ let loaded_nets = Why3.Term.Hls.create 10 let lookup_loaded_nets = Why3.Term.Hls.find_opt loaded_nets +let register_ioshape nb_inputs nb_outputs ioshape_input_type filename th_uc = + let open Why3 in + let ls_net_apply = + let f _ = ioshape_input_type in + Term.create_fsymbol + (Ident.id_fresh "net_apply") + (List.init nb_inputs ~f) + (Ty.ty_tuple (List.init nb_outputs ~f)) + in + Why3.Term.Hls.add loaded_nets ls_net_apply + { filename; nb_inputs; nb_outputs; ty_data = ioshape_input_type }; + let th_uc = + Pmodule.add_pdecl ~vc:false th_uc + (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) + in + Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + let nnet_parser env _ filename _ = let open Why3 in let header = Nnet.parse filename in match header with | Error s -> Loc.errorm "%s" s | Ok header -> - let nnet = Pmodule.read_module env [ "caisar" ] "NNet" in + let nnet = Pmodule.read_module env [ "caisar" ] "IOShape" in let nnet_input_type = Ty.ty_app Theory.(ns_find_ts nnet.mod_theory.th_export [ "input_type" ]) @@ -34,25 +51,8 @@ let nnet_parser env _ filename _ = let id_as_tuple = Ident.id_fresh "AsTuple" in let th_uc = Pmodule.create_module env id_as_tuple in let th_uc = Pmodule.use_export th_uc nnet in - let ls_nnet_apply = - let f _ = nnet_input_type in - Term.create_fsymbol - (Ident.id_fresh "nnet_apply") - (List.init header.n_inputs ~f) - (Ty.ty_tuple (List.init header.n_outputs ~f)) - in - Why3.Term.Hls.add loaded_nets ls_nnet_apply - { - filename; - nb_inputs = header.n_inputs; - nb_outputs = header.n_outputs; - ty_data = nnet_input_type; - }; - let th_uc = - Pmodule.add_pdecl ~vc:false th_uc - (Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply)) - in - Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + register_ioshape header.n_inputs header.n_outputs nnet_input_type filename + th_uc let onnx_parser env _ filename _ = let open Why3 in @@ -60,7 +60,7 @@ let onnx_parser env _ filename _ = match header with | Error s -> Loc.errorm "%s" s | Ok model -> - let onnx = Pmodule.read_module env [ "caisar" ] "NNet" in + let onnx = Pmodule.read_module env [ "caisar" ] "IOShape" in let onnx_input_type = Ty.ty_app Theory.(ns_find_ts onnx.mod_theory.th_export [ "input_type" ]) @@ -96,26 +96,8 @@ let onnx_parser env _ filename _ = in let input_flat_dim = flattened_dim input_shape in let output_flat_dim = flattened_dim output_shape in - let ls_onnx_apply = - (* TODO: find out input and output size for ONNX. *) - let f _ = onnx_input_type in - Term.create_fsymbol - (Ident.id_fresh "onnx_apply") - (List.init input_flat_dim ~f) - (Ty.ty_tuple (List.init output_flat_dim ~f)) - in - Why3.Term.Hls.add loaded_nets ls_onnx_apply - { - filename; - nb_inputs = input_flat_dim; - nb_outputs = output_flat_dim; - ty_data = onnx_input_type; - }; - let th_uc = - Pmodule.add_pdecl ~vc:false th_uc - (Pdecl.create_pure_decl (Decl.create_param_decl ls_onnx_apply)) - in - Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + register_ioshape input_flat_dim output_flat_dim onnx_input_type filename + th_uc let register_nnet_support () = Why3.( diff --git a/src/language.mli b/src/language.mli index 33606475..f1d1fcab 100644 --- a/src/language.mli +++ b/src/language.mli @@ -4,14 +4,14 @@ (* *) (**************************************************************************) -type net = { +type ioshape = { nb_inputs : int; nb_outputs : int; ty_data : Why3.Ty.ty; filename : string; } -val lookup_loaded_nets : Why3.Term.lsymbol -> net option +val lookup_loaded_nets : Why3.Term.lsymbol -> ioshape option (** @return the filename of a nnet Why3 representation. *) val register_nnet_support : unit -> unit diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index c2958677..9166a2a9 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -1,4 +1,4 @@ -theory NNet +theory IOShape use ieee_float.Float64 type input_type = t end -- GitLab