From 377c194efefe1982f37e9e3a5a1ffd93f7a47b90 Mon Sep 17 00:00:00 2001
From: Michele Alberti <michele.alberti@cea.fr>
Date: Wed, 30 Jun 2021 16:43:37 +0200
Subject: [PATCH] Use csv library to parse nnet model wrt CSV format.

---
 lib/nnet/dune     |  2 +-
 lib/nnet/nnet.ml  | 76 +++++++++++++++++++++++++++--------------------
 lib/nnet/nnet.mli |  6 ++--
 src/dune          |  4 +--
 4 files changed, 49 insertions(+), 39 deletions(-)

diff --git a/lib/nnet/dune b/lib/nnet/dune
index c15dbf76..fb2c5ec7 100644
--- a/lib/nnet/dune
+++ b/lib/nnet/dune
@@ -1,5 +1,5 @@
 (library
  (name        nnet)
  (public_name nnet)
- (libraries base)
+ (libraries base csv)
  (synopsis "NNet parser"))
diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml
index be5bf793..7429c0a5 100644
--- a/lib/nnet/nnet.ml
+++ b/lib/nnet/nnet.ml
@@ -19,25 +19,23 @@ type t = {
   max_input_values : float list;
   mean_values : float list * float;
   range_values : float list * float;
+  weights_biases : float list list;
 }
-[@@deriving show { with_path = false }]
 
 (* NNet format handling. *)
 
 let nnet_format_error s =
   Error (Format.sprintf "NNet format error: %s condition not satisfied." s)
 
-let nnet_delimiter = Str.regexp ","
-
-(* Parse a single NNet format line: split line using [nnet_delimiter] as
-   delimiter, and convert each string into a number by means of converter [f]. *)
-let handle_nnet_line ~f line =
+(* Parse a single NNet format line: split line wrt CSV format, and convert each
+   string into a number by means of converter [f]. *)
+let handle_nnet_line ~f in_channel =
   List.filter_map
     ~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
-    (Str.split nnet_delimiter line)
+    (Csv.next in_channel)
 
 (* Skip the header part, ie comments, of the NNet format. *)
-let handle_nnet_header filename in_channel =
+let skip_nnet_header filename in_channel =
   let exception End_of_header in
   let pos_in = ref (Stdlib.pos_in in_channel) in
   try
@@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel =
 
 (* Retrieve number of layers, inputs, outputs and maximum layer size. *)
 let handle_nnet_basic_info in_channel =
-  try
-    let line = Stdlib.input_line in_channel in
-    match handle_nnet_line ~f:Stdlib.int_of_string line with
-    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
-        Ok (n_layers, n_inputs, n_outputs, max_layer_size)
-    | _ -> nnet_format_error "second"
-  with End_of_file -> nnet_format_error "second"
+  match handle_nnet_line ~f:Int.of_string in_channel with
+  | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+      Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+  | _ -> nnet_format_error "second"
+  | exception End_of_file -> nnet_format_error "second"
 
 (* Retrieve size of each layer, including inputs and outputs. *)
 let handle_nnet_layer_sizes n_layers in_channel =
   try
-    let line = Stdlib.input_line in_channel in
-    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    let layer_sizes = handle_nnet_line ~f:Int.of_string in_channel in
     if List.length layer_sizes = n_layers + 1 then Ok layer_sizes
     else nnet_format_error "third"
   with End_of_file -> nnet_format_error "third"
@@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel =
 (* Skip unused flag. *)
 let handle_nnet_unused_flag in_channel =
   try
-    let _ = Stdlib.input_line in_channel in
+    let _ = Csv.next in_channel in
     Ok ()
   with End_of_file -> nnet_format_error "forth"
 
 (* Retrive minimum values of inputs. *)
 let handle_nnet_min_input_values n_inputs in_channel =
   try
-    let line = Stdlib.input_line in_channel in
-    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    let min_input_values = handle_nnet_line ~f:Float.of_string in_channel in
     if List.length min_input_values = n_inputs then Ok min_input_values
     else nnet_format_error "fifth"
   with End_of_file -> nnet_format_error "fifth"
@@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel =
 (* Retrive maximum values of inputs. *)
 let handle_nnet_max_input_values n_inputs in_channel =
   try
-    let line = Stdlib.input_line in_channel in
-    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    let max_input_values = handle_nnet_line ~f:Float.of_string in_channel in
     if List.length max_input_values = n_inputs then Ok max_input_values
     else nnet_format_error "sixth"
   with End_of_file -> nnet_format_error "sixth"
@@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel =
 (* Retrieve mean values of inputs and one value for all outputs. *)
 let handle_nnet_mean_values n_inputs in_channel =
   try
-    let line = Stdlib.input_line in_channel in
-    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    let mean_values = handle_nnet_line ~f:Float.of_string in_channel in
     if List.length mean_values = n_inputs + 1 then
       let mean_input_values, mean_output_value =
         List.split_n mean_values n_inputs
@@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel =
 (* Retrieve range values of inputs and one value for all outputs. *)
 let handle_nnet_range_values n_inputs in_channel =
   try
-    let line = Stdlib.input_line in_channel in
-    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    let range_values = handle_nnet_line ~f:Float.of_string in_channel in
     if List.length range_values = n_inputs + 1 then
       let range_input_values, range_output_value =
         List.split_n range_values n_inputs
@@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel =
     else nnet_format_error "eighth"
   with End_of_file -> nnet_format_error "eighth"
 
-(* Retrieves [filename] NNet model metadata wrt NNet format specification (see
-   https://github.com/sisl/NNet for details.) *)
-let parse_metadata filename =
+(* Retrieve all layer weights and biases as appearing in the model. No special
+   treatment is performed. *)
+let handle_nnet_weights_and_biases in_channel =
+  List.rev
+    (Csv.fold_left ~init:[]
+       ~f:(fun fll sl ->
+         List.filter_map
+           ~f:(fun s ->
+             try Some (Float.of_string (String.strip s)) with _ -> None)
+           sl
+         :: fll)
+       in_channel)
+
+(* Retrieves [filename] NNet model metadata and weights wrt NNet format
+   specification (see https://github.com/sisl/NNet for details). *)
+let parse filename =
   let open Result in
-  let in_channel = Stdlib.open_in filename in
   try
-    handle_nnet_header filename in_channel >>= fun () ->
+    let in_channel = Stdlib.open_in filename in
+    skip_nnet_header filename in_channel >>= fun () ->
+    let in_channel = Csv.of_channel in_channel in
     handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
     handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
     handle_nnet_unused_flag in_channel >>= fun () ->
@@ -140,7 +145,8 @@ let parse_metadata filename =
     handle_nnet_max_input_values n_is in_channel >>= fun max_input_values ->
     handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
     handle_nnet_range_values n_is in_channel >>= fun range_values ->
-    Stdlib.close_in in_channel;
+    let weights_biases = handle_nnet_weights_and_biases in_channel in
+    Csv.close_in in_channel;
     Ok
       {
         n_layers = n_ls;
@@ -152,5 +158,9 @@ let parse_metadata filename =
         max_input_values;
         mean_values;
         range_values;
+        weights_biases;
       }
-  with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
+  with
+  | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
+  | Sys_error s -> Error s
+  | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli
index b718de82..a1b5076f 100644
--- a/lib/nnet/nnet.mli
+++ b/lib/nnet/nnet.mli
@@ -16,9 +16,9 @@ type t = private {
       (** Mean values of inputs and one value for all outputs. *)
   range_values : float list * float;
       (** Range values of inputs and one value for all outputs. *)
+  weights_biases : float list list;  (** All weights and biases of NNet model. *)
 }
-[@@deriving show { with_path = false }]
 (** NNet model metadata. *)
 
-val parse_metadata : string -> (t, string) Result.t
-(** Parse an NNet file for metadata. *)
+val parse : string -> (t, string) Result.t
+(** Parse an NNet file. *)
diff --git a/src/dune b/src/dune
index 0c6fcbae..14cf47c9 100644
--- a/src/dune
+++ b/src/dune
@@ -2,6 +2,6 @@
   (name main)
   (public_name caisar)
   (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
- (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
- (package caisar)
+  (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
+  (package caisar)
 )
-- 
GitLab