diff --git a/lib/nnet/dune b/lib/nnet/dune index c15dbf7657c947f6e8612b6eebc43618a7526576..fb2c5ec74be4294d311033bf92eed89bf913e5da 100644 --- a/lib/nnet/dune +++ b/lib/nnet/dune @@ -1,5 +1,5 @@ (library (name nnet) (public_name nnet) - (libraries base) + (libraries base csv) (synopsis "NNet parser")) diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml index be5bf793b3cabcd1df24a4c263ecb8cea5e02beb..7429c0a53ba17734469d23e2939a1dcb580a8b61 100644 --- a/lib/nnet/nnet.ml +++ b/lib/nnet/nnet.ml @@ -19,25 +19,23 @@ type t = { max_input_values : float list; mean_values : float list * float; range_values : float list * float; + weights_biases : float list list; } -[@@deriving show { with_path = false }] (* NNet format handling. *) let nnet_format_error s = Error (Format.sprintf "NNet format error: %s condition not satisfied." s) -let nnet_delimiter = Str.regexp "," - -(* Parse a single NNet format line: split line using [nnet_delimiter] as - delimiter, and convert each string into a number by means of converter [f]. *) -let handle_nnet_line ~f line = +(* Parse a single NNet format line: split line wrt CSV format, and convert each + string into a number by means of converter [f]. *) +let handle_nnet_line ~f in_channel = List.filter_map ~f:(fun s -> try Some (f (String.strip s)) with _ -> None) - (Str.split nnet_delimiter line) + (Csv.next in_channel) (* Skip the header part, ie comments, of the NNet format. *) -let handle_nnet_header filename in_channel = +let skip_nnet_header filename in_channel = let exception End_of_header in let pos_in = ref (Stdlib.pos_in in_channel) in try @@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel = (* Retrieve number of layers, inputs, outputs and maximum layer size. *) let handle_nnet_basic_info in_channel = - try - let line = Stdlib.input_line in_channel in - match handle_nnet_line ~f:Stdlib.int_of_string line with - | [ n_layers; n_inputs; n_outputs; max_layer_size ] -> - Ok (n_layers, n_inputs, n_outputs, max_layer_size) - | _ -> nnet_format_error "second" - with End_of_file -> nnet_format_error "second" + match handle_nnet_line ~f:Int.of_string in_channel with + | [ n_layers; n_inputs; n_outputs; max_layer_size ] -> + Ok (n_layers, n_inputs, n_outputs, max_layer_size) + | _ -> nnet_format_error "second" + | exception End_of_file -> nnet_format_error "second" (* Retrieve size of each layer, including inputs and outputs. *) let handle_nnet_layer_sizes n_layers in_channel = try - let line = Stdlib.input_line in_channel in - let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in + let layer_sizes = handle_nnet_line ~f:Int.of_string in_channel in if List.length layer_sizes = n_layers + 1 then Ok layer_sizes else nnet_format_error "third" with End_of_file -> nnet_format_error "third" @@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel = (* Skip unused flag. *) let handle_nnet_unused_flag in_channel = try - let _ = Stdlib.input_line in_channel in + let _ = Csv.next in_channel in Ok () with End_of_file -> nnet_format_error "forth" (* Retrive minimum values of inputs. *) let handle_nnet_min_input_values n_inputs in_channel = try - let line = Stdlib.input_line in_channel in - let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in + let min_input_values = handle_nnet_line ~f:Float.of_string in_channel in if List.length min_input_values = n_inputs then Ok min_input_values else nnet_format_error "fifth" with End_of_file -> nnet_format_error "fifth" @@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel = (* Retrive maximum values of inputs. *) let handle_nnet_max_input_values n_inputs in_channel = try - let line = Stdlib.input_line in_channel in - let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in + let max_input_values = handle_nnet_line ~f:Float.of_string in_channel in if List.length max_input_values = n_inputs then Ok max_input_values else nnet_format_error "sixth" with End_of_file -> nnet_format_error "sixth" @@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel = (* Retrieve mean values of inputs and one value for all outputs. *) let handle_nnet_mean_values n_inputs in_channel = try - let line = Stdlib.input_line in_channel in - let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in + let mean_values = handle_nnet_line ~f:Float.of_string in_channel in if List.length mean_values = n_inputs + 1 then let mean_input_values, mean_output_value = List.split_n mean_values n_inputs @@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel = (* Retrieve range values of inputs and one value for all outputs. *) let handle_nnet_range_values n_inputs in_channel = try - let line = Stdlib.input_line in_channel in - let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in + let range_values = handle_nnet_line ~f:Float.of_string in_channel in if List.length range_values = n_inputs + 1 then let range_input_values, range_output_value = List.split_n range_values n_inputs @@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel = else nnet_format_error "eighth" with End_of_file -> nnet_format_error "eighth" -(* Retrieves [filename] NNet model metadata wrt NNet format specification (see - https://github.com/sisl/NNet for details.) *) -let parse_metadata filename = +(* Retrieve all layer weights and biases as appearing in the model. No special + treatment is performed. *) +let handle_nnet_weights_and_biases in_channel = + List.rev + (Csv.fold_left ~init:[] + ~f:(fun fll sl -> + List.filter_map + ~f:(fun s -> + try Some (Float.of_string (String.strip s)) with _ -> None) + sl + :: fll) + in_channel) + +(* Retrieves [filename] NNet model metadata and weights wrt NNet format + specification (see https://github.com/sisl/NNet for details). *) +let parse filename = let open Result in - let in_channel = Stdlib.open_in filename in try - handle_nnet_header filename in_channel >>= fun () -> + let in_channel = Stdlib.open_in filename in + skip_nnet_header filename in_channel >>= fun () -> + let in_channel = Csv.of_channel in_channel in handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) -> handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes -> handle_nnet_unused_flag in_channel >>= fun () -> @@ -140,7 +145,8 @@ let parse_metadata filename = handle_nnet_max_input_values n_is in_channel >>= fun max_input_values -> handle_nnet_mean_values n_is in_channel >>= fun mean_values -> handle_nnet_range_values n_is in_channel >>= fun range_values -> - Stdlib.close_in in_channel; + let weights_biases = handle_nnet_weights_and_biases in_channel in + Csv.close_in in_channel; Ok { n_layers = n_ls; @@ -152,5 +158,9 @@ let parse_metadata filename = max_input_values; mean_values; range_values; + weights_biases; } - with Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) + with + | Csv.Failure (_nrecord, _nfield, msg) -> Error msg + | Sys_error s -> Error s + | Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg) diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli index b718de82b47e212ad65103995899a915c8a4be7d..a1b5076fefaccfcea806f896708b3d48260d0eeb 100644 --- a/lib/nnet/nnet.mli +++ b/lib/nnet/nnet.mli @@ -16,9 +16,9 @@ type t = private { (** Mean values of inputs and one value for all outputs. *) range_values : float list * float; (** Range values of inputs and one value for all outputs. *) + weights_biases : float list list; (** All weights and biases of NNet model. *) } -[@@deriving show { with_path = false }] (** NNet model metadata. *) -val parse_metadata : string -> (t, string) Result.t -(** Parse an NNet file for metadata. *) +val parse : string -> (t, string) Result.t +(** Parse an NNet file. *) diff --git a/src/dune b/src/dune index 0c6fcbae294537ccc076196fef95ca4b981204b5..14cf47c969863f82832fa57606230984451721b6 100644 --- a/src/dune +++ b/src/dune @@ -2,6 +2,6 @@ (name main) (public_name caisar) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3) - (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) - (package caisar) + (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) + (package caisar) )