From 9aceca1abd6c04c525e81947608d3e250e7a425c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr> Date: Wed, 30 Jun 2021 10:04:51 +0200 Subject: [PATCH] Start version using Why3 and move out nnet parser. --- .ocamlformat | 0 dune-project | 10 +++ lib/nnet/dune | 5 ++ lib/nnet/nnet.ml | 156 ++++++++++++++++++++++++++++++++++++++++++++++ lib/nnet/nnet.mli | 24 +++++++ nnet.opam | 24 +++++++ src/dune | 7 +++ src/main.ml | 0 standalone/dune | 6 +- 9 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 .ocamlformat create mode 100644 lib/nnet/dune create mode 100644 lib/nnet/nnet.ml create mode 100644 lib/nnet/nnet.mli create mode 100644 nnet.opam create mode 100644 src/dune create mode 100644 src/main.ml diff --git a/.ocamlformat b/.ocamlformat new file mode 100644 index 00000000..e69de29b diff --git a/dune-project b/dune-project index d1653452..76984936 100644 --- a/dune-project +++ b/dune-project @@ -22,3 +22,13 @@ (ppx_deriving_yojson (>= 3.6.1)) ) ) + +(package + (name nnet) + (synopsis "NNet parser") + (depends + (ocaml (>= 4.10)) + (dune (>= 2.7.1)) + (base (>= v0.14.0)) + ) +) diff --git a/lib/nnet/dune b/lib/nnet/dune new file mode 100644 index 00000000..c15dbf76 --- /dev/null +++ b/lib/nnet/dune @@ -0,0 +1,5 @@ +(library + (name nnet) + (public_name nnet) + (libraries base) + (synopsis "NNet parser")) diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml new file mode 100644 index 00000000..be5bf793 --- /dev/null +++ b/lib/nnet/nnet.ml @@ -0,0 +1,156 @@ +(**************************************************************************) +(* *) +(* This file is part of Caisar. *) +(* *) +(**************************************************************************) + +open Base +module Format = Caml.Format +module Sys = Caml.Sys +module Filename = Caml.Filename + +type t = { + n_layers : int; + n_inputs : int; + n_outputs : int; + max_layer_size : int; + layer_sizes : int list; + min_input_values : float list; + max_input_values : float list; + mean_values : float list * float; + range_values : float list * float; +} +[@@deriving show { with_path = false }] + +(* NNet format handling. *) + +let nnet_format_error s = + Error (Format.sprintf "NNet format error: %s condition not satisfied." s) + +let nnet_delimiter = Str.regexp "," + +(* Parse a single NNet format line: split line using [nnet_delimiter] as + delimiter, and convert each string into a number by means of converter [f]. *) +let handle_nnet_line ~f line = + List.filter_map + ~f:(fun s -> try Some (f (String.strip s)) with _ -> None) + (Str.split nnet_delimiter line) + +(* Skip the header part, ie comments, of the NNet format. *) +let handle_nnet_header filename in_channel = + let exception End_of_header in + let pos_in = ref (Stdlib.pos_in in_channel) in + try + while true do + let line = Stdlib.input_line in_channel in + if not (Str.string_match (Str.regexp "//") line 0) then + raise End_of_header + else pos_in := Stdlib.pos_in in_channel + done; + assert false + with + | End_of_header -> + (* At this point we have read one line past the header part: seek back. *) + Stdlib.seek_in in_channel !pos_in; + Ok () + | End_of_file -> + Error (Format.sprintf "NNet model not found in file `%s'." filename) + +(* Retrieve number of layers, inputs, outputs and maximum layer size. *) +let handle_nnet_basic_info in_channel = + try + let line = Stdlib.input_line in_channel in + match handle_nnet_line ~f:Stdlib.int_of_string line with + | [ n_layers; n_inputs; n_outputs; max_layer_size ] -> + Ok (n_layers, n_inputs, n_outputs, max_layer_size) + | _ -> nnet_format_error "second" + with End_of_file -> nnet_format_error "second" + +(* Retrieve size of each layer, including inputs and outputs. *) +let handle_nnet_layer_sizes n_layers in_channel = + try + let line = Stdlib.input_line in_channel in + let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in + if List.length layer_sizes = n_layers + 1 then Ok layer_sizes + else nnet_format_error "third" + with End_of_file -> nnet_format_error "third" + +(* Skip unused flag. *) +let handle_nnet_unused_flag in_channel = + try + let _ = Stdlib.input_line in_channel in + Ok () + with End_of_file -> nnet_format_error "forth" + +(* Retrive minimum values of inputs. *) +let handle_nnet_min_input_values n_inputs in_channel = + try + let line = Stdlib.input_line in_channel in + let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in + if List.length min_input_values = n_inputs then Ok min_input_values + else nnet_format_error "fifth" + with End_of_file -> nnet_format_error "fifth" + +(* Retrive maximum values of inputs. *) +let handle_nnet_max_input_values n_inputs in_channel = + try + let line = Stdlib.input_line in_channel in + let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in + if List.length max_input_values = n_inputs then Ok max_input_values + else nnet_format_error "sixth" + with End_of_file -> nnet_format_error "sixth" + +(* Retrieve mean values of inputs and one value for all outputs. *) +let handle_nnet_mean_values n_inputs in_channel = + try + let line = Stdlib.input_line in_channel in + let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in + if List.length mean_values = n_inputs + 1 then + let mean_input_values, mean_output_value = + List.split_n mean_values n_inputs + in + Ok (mean_input_values, List.hd_exn mean_output_value) + else nnet_format_error "seventh" + with End_of_file -> nnet_format_error "seventh" + +(* Retrieve range values of inputs and one value for all outputs. *) +let handle_nnet_range_values n_inputs in_channel = + try + let line = Stdlib.input_line in_channel in + let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in + if List.length range_values = n_inputs + 1 then + let range_input_values, range_output_value = + List.split_n range_values n_inputs + in + Ok (range_input_values, List.hd_exn range_output_value) + else nnet_format_error "eighth" + with End_of_file -> nnet_format_error "eighth" + +(* Retrieves [filename] NNet model metadata wrt NNet format specification (see + https://github.com/sisl/NNet for details.) *) +let parse_metadata filename = + let open Result in + let in_channel = Stdlib.open_in filename in + try + handle_nnet_header filename in_channel >>= fun () -> + handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> + handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> + handle_nnet_unused_flag in_channel >>= fun () -> + handle_nnet_min_input_values n_is in_channel >>= fun min_input_values -> + handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> + handle_nnet_mean_values n_is in_channel >>= fun mean_values -> + handle_nnet_range_values n_is in_channel >>= fun range_values -> + Stdlib.close_in in_channel; + Ok + { + n_layers = n_ls; + n_inputs = n_is; + n_outputs = n_os; + max_layer_size = max_l_size; + layer_sizes; + min_input_values; + max_input_values; + mean_values; + range_values; + } + with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli new file mode 100644 index 00000000..b718de82 --- /dev/null +++ b/lib/nnet/nnet.mli @@ -0,0 +1,24 @@ +(**************************************************************************) +(* *) +(* This file is part of Caisar. *) +(* *) +(**************************************************************************) + +type t = private { + n_layers : int; (** Number of layers. *) + n_inputs : int; (** Number of inputs. *) + n_outputs : int; (** Number of outputs. *) + max_layer_size : int; (** Maximum layer size. *) + layer_sizes : int list; (** Size of each layer. *) + min_input_values : float list; (** Minimum values of inputs. *) + max_input_values : float list; (** Maximum values of inputs. *) + mean_values : float list * float; + (** Mean values of inputs and one value for all outputs. *) + range_values : float list * float; + (** Range values of inputs and one value for all outputs. *) +} +[@@deriving show { with_path = false }] +(** NNet model metadata. *) + +val parse_metadata : string -> (t, string) Result.t +(** Parse an NNet file for metadata. *) diff --git a/nnet.opam b/nnet.opam new file mode 100644 index 00000000..af0a8be5 --- /dev/null +++ b/nnet.opam @@ -0,0 +1,24 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +version: "0.1" +synopsis: "NNet parser" +depends: [ + "ocaml" {>= "4.10"} + "dune" {>= "2.7" & >= "2.7.1"} + "base" {>= "v0.14.0"} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] diff --git a/src/dune b/src/dune new file mode 100644 index 00000000..0c6fcbae --- /dev/null +++ b/src/dune @@ -0,0 +1,7 @@ +(executable + (name main) + (public_name caisar) + (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3) + (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) + (package caisar) +) diff --git a/src/main.ml b/src/main.ml new file mode 100644 index 00000000..e69de29b diff --git a/standalone/dune b/standalone/dune index 4ec3915a..7eea4820 100644 --- a/standalone/dune +++ b/standalone/dune @@ -7,7 +7,9 @@ (executable (name main) - (public_name caisar) + (public_name caisar-standalone) (modules_without_implementation property_types) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime) - (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))) + (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) + (package caisar) +) -- GitLab