diff --git a/caisar.opam b/caisar.opam index e3397717f6757d1b74cd46a4264e62c5762d3ca5..403bae2efa898c4683b5ceddc3f3f86c883e61ab 100644 --- a/caisar.opam +++ b/caisar.opam @@ -25,7 +25,7 @@ depends: [ "odoc" {with-doc} ] build: [ - ["dune" "subst" "--root" "."] {dev} + ["dune" "subst"] {dev} [ "dune" "build" @@ -33,8 +33,7 @@ build: [ name "-j" jobs - "--promote-install-files" - "false" + "--promote-install-files=false" "@install" "@runtest" {with-test} "@doc" {with-doc} diff --git a/nnet.opam b/nnet.opam index dbc488d7ebd966482f56e88b31b7cb41a4d09c41..aa2ffd8b08b83e5d09134e0904ee1293c9d9959a 100644 --- a/nnet.opam +++ b/nnet.opam @@ -2,14 +2,6 @@ opam-version: "2.0" version: "0.1" synopsis: "NNet parser" -maintainer: ["Michele Alberti" "François Bobot" "Julien Girard-Satabin"] -authors: [ - "Michele Alberti" - "François Bobot" - "Julien Girard-Satabin" - "Zakaria Chihani" -] -bug-reports: "julien.girard2@cea.fr" depends: [ "ocaml" {>= "4.10"} "dune" {>= "2.9" & >= "2.9.1"} @@ -17,7 +9,7 @@ depends: [ "odoc" {with-doc} ] build: [ - ["dune" "subst" "--root" "."] {dev} + ["dune" "subst"] {dev} [ "dune" "build" @@ -25,8 +17,7 @@ build: [ name "-j" jobs - "--promote-install-files" - "false" + "--promote-install-files=false" "@install" "@runtest" {with-test} "@doc" {with-doc} diff --git a/onnx.opam b/onnx.opam index a37a869ed2dca7dcdc97c6707cfcf7e753c5b20b..dd5558135befdf02b5def8c6aae20d3f34470fb0 100644 --- a/onnx.opam +++ b/onnx.opam @@ -2,14 +2,6 @@ opam-version: "2.0" version: "0.1" synopsis: "ONNX parser" -maintainer: ["Michele Alberti" "François Bobot" "Julien Girard-Satabin"] -authors: [ - "Michele Alberti" - "François Bobot" - "Julien Girard-Satabin" - "Zakaria Chihani" -] -bug-reports: "julien.girard2@cea.fr" depends: [ "ocaml" {>= "4.10"} "dune" {>= "2.9" & >= "2.9.1"} diff --git a/src/language.ml b/src/language.ml index f774f546003a913a74b8235728b83d9a018fb571..1b8a2773f450bc0feba571807646e41e6d64dd40 100644 --- a/src/language.ml +++ b/src/language.ml @@ -8,7 +8,7 @@ open Base (* -- Support for the NNet and ONNX neural network format. *) -type net = { +type ioshape = { nb_inputs : int; nb_outputs : int; ty_data : Why3.Ty.ty; @@ -19,13 +19,30 @@ let loaded_nets = Why3.Term.Hls.create 10 let lookup_loaded_nets = Why3.Term.Hls.find_opt loaded_nets +let register_ioshape nb_inputs nb_outputs ioshape_input_type filename th_uc = + let open Why3 in + let ls_net_apply = + let f _ = ioshape_input_type in + Term.create_fsymbol + (Ident.id_fresh "net_apply") + (List.init nb_inputs ~f) + (Ty.ty_tuple (List.init nb_outputs ~f)) + in + Why3.Term.Hls.add loaded_nets ls_net_apply + { filename; nb_inputs; nb_outputs; ty_data = ioshape_input_type }; + let th_uc = + Pmodule.add_pdecl ~vc:false th_uc + (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) + in + Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + let nnet_parser env _ filename _ = let open Why3 in let header = Nnet.parse filename in match header with | Error s -> Loc.errorm "%s" s | Ok header -> - let nnet = Pmodule.read_module env [ "caisar" ] "NNet" in + let nnet = Pmodule.read_module env [ "caisar" ] "IOShape" in let nnet_input_type = Ty.ty_app Theory.(ns_find_ts nnet.mod_theory.th_export [ "input_type" ]) @@ -34,25 +51,8 @@ let nnet_parser env _ filename _ = let id_as_tuple = Ident.id_fresh "AsTuple" in let th_uc = Pmodule.create_module env id_as_tuple in let th_uc = Pmodule.use_export th_uc nnet in - let ls_nnet_apply = - let f _ = nnet_input_type in - Term.create_fsymbol - (Ident.id_fresh "nnet_apply") - (List.init header.n_inputs ~f) - (Ty.ty_tuple (List.init header.n_outputs ~f)) - in - Why3.Term.Hls.add loaded_nets ls_nnet_apply - { - filename; - nb_inputs = header.n_inputs; - nb_outputs = header.n_outputs; - ty_data = nnet_input_type; - }; - let th_uc = - Pmodule.add_pdecl ~vc:false th_uc - (Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply)) - in - Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + register_ioshape header.n_inputs header.n_outputs nnet_input_type filename + th_uc let onnx_parser env _ filename _ = let open Why3 in @@ -60,7 +60,7 @@ let onnx_parser env _ filename _ = match header with | Error s -> Loc.errorm "%s" s | Ok model -> - let onnx = Pmodule.read_module env [ "caisar" ] "NNet" in + let onnx = Pmodule.read_module env [ "caisar" ] "IOShape" in let onnx_input_type = Ty.ty_app Theory.(ns_find_ts onnx.mod_theory.th_export [ "input_type" ]) @@ -96,26 +96,8 @@ let onnx_parser env _ filename _ = in let input_flat_dim = flattened_dim input_shape in let output_flat_dim = flattened_dim output_shape in - let ls_onnx_apply = - (* TODO: find out input and output size for ONNX. *) - let f _ = onnx_input_type in - Term.create_fsymbol - (Ident.id_fresh "onnx_apply") - (List.init input_flat_dim ~f) - (Ty.ty_tuple (List.init output_flat_dim ~f)) - in - Why3.Term.Hls.add loaded_nets ls_onnx_apply - { - filename; - nb_inputs = input_flat_dim; - nb_outputs = output_flat_dim; - ty_data = onnx_input_type; - }; - let th_uc = - Pmodule.add_pdecl ~vc:false th_uc - (Pdecl.create_pure_decl (Decl.create_param_decl ls_onnx_apply)) - in - Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc) + register_ioshape input_flat_dim output_flat_dim onnx_input_type filename + th_uc let register_nnet_support () = Why3.( diff --git a/src/language.mli b/src/language.mli index 33606475a1d214098bc012851295a93b69bf9b7e..f1d1fcabeff3011ebe2235a8ebac2dd91fbf3dd4 100644 --- a/src/language.mli +++ b/src/language.mli @@ -4,14 +4,14 @@ (* *) (**************************************************************************) -type net = { +type ioshape = { nb_inputs : int; nb_outputs : int; ty_data : Why3.Ty.ty; filename : string; } -val lookup_loaded_nets : Why3.Term.lsymbol -> net option +val lookup_loaded_nets : Why3.Term.lsymbol -> ioshape option (** @return the filename of a nnet Why3 representation. *) val register_nnet_support : unit -> unit diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index c29586772de968684847674cfe0347514ebc3204..9166a2a923317ff4bd192e7a27bba880dd625628 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -1,4 +1,4 @@ -theory NNet +theory IOShape use ieee_float.Float64 type input_type = t end