From 9aceca1abd6c04c525e81947608d3e250e7a425c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Bobot?= <francois.bobot@cea.fr>
Date: Wed, 30 Jun 2021 10:04:51 +0200
Subject: [PATCH] Start version using Why3 and move out nnet parser.

---
 .ocamlformat      |   0
 dune-project      |  10 +++
 lib/nnet/dune     |   5 ++
 lib/nnet/nnet.ml  | 156 ++++++++++++++++++++++++++++++++++++++++++++++
 lib/nnet/nnet.mli |  24 +++++++
 nnet.opam         |  24 +++++++
 src/dune          |   7 +++
 src/main.ml       |   0
 standalone/dune   |   6 +-
 9 files changed, 230 insertions(+), 2 deletions(-)
 create mode 100644 .ocamlformat
 create mode 100644 lib/nnet/dune
 create mode 100644 lib/nnet/nnet.ml
 create mode 100644 lib/nnet/nnet.mli
 create mode 100644 nnet.opam
 create mode 100644 src/dune
 create mode 100644 src/main.ml

diff --git a/.ocamlformat b/.ocamlformat
new file mode 100644
index 00000000..e69de29b
diff --git a/dune-project b/dune-project
index d1653452..76984936 100644
--- a/dune-project
+++ b/dune-project
@@ -22,3 +22,13 @@
    (ppx_deriving_yojson (>= 3.6.1))
   )
 )
+
+(package
+  (name nnet)
+  (synopsis "NNet parser")
+  (depends
+   (ocaml (>= 4.10))
+   (dune (>= 2.7.1))
+   (base (>= v0.14.0))
+  )
+)
diff --git a/lib/nnet/dune b/lib/nnet/dune
new file mode 100644
index 00000000..c15dbf76
--- /dev/null
+++ b/lib/nnet/dune
@@ -0,0 +1,5 @@
+(library
+ (name        nnet)
+ (public_name nnet)
+ (libraries base)
+ (synopsis "NNet parser"))
diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml
new file mode 100644
index 00000000..be5bf793
--- /dev/null
+++ b/lib/nnet/nnet.ml
@@ -0,0 +1,156 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of Caisar.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Base
+module Format = Caml.Format
+module Sys = Caml.Sys
+module Filename = Caml.Filename
+
+type t = {
+  n_layers : int;
+  n_inputs : int;
+  n_outputs : int;
+  max_layer_size : int;
+  layer_sizes : int list;
+  min_input_values : float list;
+  max_input_values : float list;
+  mean_values : float list * float;
+  range_values : float list * float;
+}
+[@@deriving show { with_path = false }]
+
+(* NNet format handling. *)
+
+let nnet_format_error s =
+  Error (Format.sprintf "NNet format error: %s condition not satisfied." s)
+
+let nnet_delimiter = Str.regexp ","
+
+(* Parse a single NNet format line: split line using [nnet_delimiter] as
+   delimiter, and convert each string into a number by means of converter [f]. *)
+let handle_nnet_line ~f line =
+  List.filter_map
+    ~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
+    (Str.split nnet_delimiter line)
+
+(* Skip the header part, ie comments, of the NNet format. *)
+let handle_nnet_header filename in_channel =
+  let exception End_of_header in
+  let pos_in = ref (Stdlib.pos_in in_channel) in
+  try
+    while true do
+      let line = Stdlib.input_line in_channel in
+      if not (Str.string_match (Str.regexp "//") line 0) then
+        raise End_of_header
+      else pos_in := Stdlib.pos_in in_channel
+    done;
+    assert false
+  with
+  | End_of_header ->
+      (* At this point we have read one line past the header part: seek back. *)
+      Stdlib.seek_in in_channel !pos_in;
+      Ok ()
+  | End_of_file ->
+      Error (Format.sprintf "NNet model not found in file `%s'." filename)
+
+(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
+let handle_nnet_basic_info in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    match handle_nnet_line ~f:Stdlib.int_of_string line with
+    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+        Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+    | _ -> nnet_format_error "second"
+  with End_of_file -> nnet_format_error "second"
+
+(* Retrieve size of each layer, including inputs and outputs. *)
+let handle_nnet_layer_sizes n_layers in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    if List.length layer_sizes = n_layers + 1 then Ok layer_sizes
+    else nnet_format_error "third"
+  with End_of_file -> nnet_format_error "third"
+
+(* Skip unused flag. *)
+let handle_nnet_unused_flag in_channel =
+  try
+    let _ = Stdlib.input_line in_channel in
+    Ok ()
+  with End_of_file -> nnet_format_error "forth"
+
+(* Retrive minimum values of inputs. *)
+let handle_nnet_min_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length min_input_values = n_inputs then Ok min_input_values
+    else nnet_format_error "fifth"
+  with End_of_file -> nnet_format_error "fifth"
+
+(* Retrive maximum values of inputs. *)
+let handle_nnet_max_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length max_input_values = n_inputs then Ok max_input_values
+    else nnet_format_error "sixth"
+  with End_of_file -> nnet_format_error "sixth"
+
+(* Retrieve mean values of inputs and one value for all outputs. *)
+let handle_nnet_mean_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length mean_values = n_inputs + 1 then
+      let mean_input_values, mean_output_value =
+        List.split_n mean_values n_inputs
+      in
+      Ok (mean_input_values, List.hd_exn mean_output_value)
+    else nnet_format_error "seventh"
+  with End_of_file -> nnet_format_error "seventh"
+
+(* Retrieve range values of inputs and one value for all outputs. *)
+let handle_nnet_range_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length range_values = n_inputs + 1 then
+      let range_input_values, range_output_value =
+        List.split_n range_values n_inputs
+      in
+      Ok (range_input_values, List.hd_exn range_output_value)
+    else nnet_format_error "eighth"
+  with End_of_file -> nnet_format_error "eighth"
+
+(* Retrieves [filename] NNet model metadata wrt NNet format specification (see
+   https://github.com/sisl/NNet for details.) *)
+let parse_metadata filename =
+  let open Result in
+  let in_channel = Stdlib.open_in filename in
+  try
+    handle_nnet_header filename in_channel >>= fun () ->
+    handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
+    handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
+    handle_nnet_unused_flag in_channel >>= fun () ->
+    handle_nnet_min_input_values n_is in_channel >>= fun min_input_values ->
+    handle_nnet_max_input_values n_is in_channel >>= fun max_input_values ->
+    handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
+    handle_nnet_range_values n_is in_channel >>= fun range_values ->
+    Stdlib.close_in in_channel;
+    Ok
+      {
+        n_layers = n_ls;
+        n_inputs = n_is;
+        n_outputs = n_os;
+        max_layer_size = max_l_size;
+        layer_sizes;
+        min_input_values;
+        max_input_values;
+        mean_values;
+        range_values;
+      }
+  with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli
new file mode 100644
index 00000000..b718de82
--- /dev/null
+++ b/lib/nnet/nnet.mli
@@ -0,0 +1,24 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of Caisar.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+type t = private {
+  n_layers : int;  (** Number of layers. *)
+  n_inputs : int;  (** Number of inputs. *)
+  n_outputs : int;  (** Number of outputs. *)
+  max_layer_size : int;  (** Maximum layer size. *)
+  layer_sizes : int list;  (** Size of each layer. *)
+  min_input_values : float list;  (** Minimum values of inputs. *)
+  max_input_values : float list;  (** Maximum values of inputs. *)
+  mean_values : float list * float;
+      (** Mean values of inputs and one value for all outputs. *)
+  range_values : float list * float;
+      (** Range values of inputs and one value for all outputs. *)
+}
+[@@deriving show { with_path = false }]
+(** NNet model metadata. *)
+
+val parse_metadata : string -> (t, string) Result.t
+(** Parse an NNet file for metadata. *)
diff --git a/nnet.opam b/nnet.opam
new file mode 100644
index 00000000..af0a8be5
--- /dev/null
+++ b/nnet.opam
@@ -0,0 +1,24 @@
+# This file is generated by dune, edit dune-project instead
+opam-version: "2.0"
+version: "0.1"
+synopsis: "NNet parser"
+depends: [
+  "ocaml" {>= "4.10"}
+  "dune" {>= "2.7" & >= "2.7.1"}
+  "base" {>= "v0.14.0"}
+  "odoc" {with-doc}
+]
+build: [
+  ["dune" "subst"] {dev}
+  [
+    "dune"
+    "build"
+    "-p"
+    name
+    "-j"
+    jobs
+    "@install"
+    "@runtest" {with-test}
+    "@doc" {with-doc}
+  ]
+]
diff --git a/src/dune b/src/dune
new file mode 100644
index 00000000..0c6fcbae
--- /dev/null
+++ b/src/dune
@@ -0,0 +1,7 @@
+(executable
+  (name main)
+  (public_name caisar)
+  (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
+ (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
+ (package caisar)
+)
diff --git a/src/main.ml b/src/main.ml
new file mode 100644
index 00000000..e69de29b
diff --git a/standalone/dune b/standalone/dune
index 4ec3915a..7eea4820 100644
--- a/standalone/dune
+++ b/standalone/dune
@@ -7,7 +7,9 @@
 
 (executable
   (name main)
-  (public_name caisar)
+  (public_name caisar-standalone)
   (modules_without_implementation property_types)
   (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime)
-  (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)))
+  (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
+  (package caisar)
+)
-- 
GitLab