diff --git a/.ocamlformat b/.ocamlformat
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/dune-project b/dune-project
index d165345231b0a3b858084fe81e6931a5723a10f4..7698493621f1e31384bcce0b36193e46243e4a55 100644
--- a/dune-project
+++ b/dune-project
@@ -22,3 +22,14 @@
   (ppx_deriving_yojson (>= 3.6.1))
  )
 )
+
+(package
+ (name nnet)
+ (synopsis "NNet parser")
+ (depends
+  (ocaml (>= 4.10))
+  (dune (>= 2.7.1))
+  (base (>= v0.14.0))
+  ppx_deriving
+ )
+)
diff --git a/lib/nnet/dune b/lib/nnet/dune
new file mode 100644
index 0000000000000000000000000000000000000000..c15dbf7657c947f6e8612b6eebc43618a7526576
--- /dev/null
+++ b/lib/nnet/dune
@@ -0,0 +1,6 @@
+(library
+ (name nnet)
+ (public_name nnet)
+ (libraries base str)
+ (preprocess (pps ppx_deriving.show))
+ (synopsis "NNet parser"))
diff --git a/lib/nnet/nnet.ml b/lib/nnet/nnet.ml
new file mode 100644
index 0000000000000000000000000000000000000000..be5bf793b3cabcd1df24a4c263ecb8cea5e02beb
--- /dev/null
+++ b/lib/nnet/nnet.ml
@@ -0,0 +1,168 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of Caisar.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Base
+module Format = Caml.Format
+module Sys = Caml.Sys
+module Filename = Caml.Filename
+
+type t = {
+  n_layers : int;
+  n_inputs : int;
+  n_outputs : int;
+  max_layer_size : int;
+  layer_sizes : int list;
+  min_input_values : float list;
+  max_input_values : float list;
+  mean_values : float list * float;
+  range_values : float list * float;
+}
+[@@deriving show { with_path = false }]
+
+(* NNet format handling. *)
+
+(* Build the error reported when the NNet format requirement named by [s]
+   (an ordinal such as "second") is not met. *)
+let nnet_format_error s =
+  Error (Format.sprintf "NNet format error: %s condition not satisfied." s)
+
+let nnet_delimiter = Str.regexp ","
+
+(* Parse a single NNet format line: split line using [nnet_delimiter] as
+   delimiter, and convert each string into a number by means of converter [f].
+   Tokens on which [f] raises [Failure] are dropped; any other exception is
+   left to propagate. *)
+let handle_nnet_line ~f line =
+  List.filter_map
+    ~f:(fun s -> try Some (f (String.strip s)) with Failure _ -> None)
+    (Str.split nnet_delimiter line)
+
+(* Skip the header part, ie comments, of the NNet format. On success, the
+   channel is positioned at the start of the first non-comment line. *)
+let handle_nnet_header filename in_channel =
+  (* [pos] is the channel position just before the line about to be read. *)
+  let rec skip_comments pos =
+    match Stdlib.input_line in_channel with
+    | line when String.is_prefix line ~prefix:"//" ->
+      skip_comments (Stdlib.pos_in in_channel)
+    | _ ->
+      (* We have read one line past the header part: seek back. *)
+      Stdlib.seek_in in_channel pos;
+      Ok ()
+    | exception End_of_file ->
+      Error (Format.sprintf "NNet model not found in file `%s'." filename)
+  in
+  skip_comments (Stdlib.pos_in in_channel)
+
+(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
+let handle_nnet_basic_info in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    match handle_nnet_line ~f:Stdlib.int_of_string line with
+    | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
+      Ok (n_layers, n_inputs, n_outputs, max_layer_size)
+    | _ -> nnet_format_error "second"
+  with End_of_file -> nnet_format_error "second"
+
+(* Retrieve size of each layer, including inputs and outputs. *)
+let handle_nnet_layer_sizes n_layers in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let layer_sizes = handle_nnet_line ~f:Stdlib.int_of_string line in
+    if List.length layer_sizes = n_layers + 1 then Ok layer_sizes
+    else nnet_format_error "third"
+  with End_of_file -> nnet_format_error "third"
+
+(* Skip unused flag. *)
+let handle_nnet_unused_flag in_channel =
+  try
+    let (_ : string) = Stdlib.input_line in_channel in
+    Ok ()
+  with End_of_file -> nnet_format_error "fourth"
+
+(* Retrieve minimum values of inputs. *)
+let handle_nnet_min_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let min_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length min_input_values = n_inputs then Ok min_input_values
+    else nnet_format_error "fifth"
+  with End_of_file -> nnet_format_error "fifth"
+
+(* Retrieve maximum values of inputs. *)
+let handle_nnet_max_input_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let max_input_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length max_input_values = n_inputs then Ok max_input_values
+    else nnet_format_error "sixth"
+  with End_of_file -> nnet_format_error "sixth"
+
+(* Retrieve mean values of inputs and one value for all outputs. *)
+let handle_nnet_mean_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let mean_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length mean_values = n_inputs + 1 then
+      let mean_input_values, mean_output_value =
+        List.split_n mean_values n_inputs
+      in
+      Ok (mean_input_values, List.hd_exn mean_output_value)
+    else nnet_format_error "seventh"
+  with End_of_file -> nnet_format_error "seventh"
+
+(* Retrieve range values of inputs and one value for all outputs. *)
+let handle_nnet_range_values n_inputs in_channel =
+  try
+    let line = Stdlib.input_line in_channel in
+    let range_values = handle_nnet_line ~f:Stdlib.float_of_string line in
+    if List.length range_values = n_inputs + 1 then
+      let range_input_values, range_output_value =
+        List.split_n range_values n_inputs
+      in
+      Ok (range_input_values, List.hd_exn range_output_value)
+    else nnet_format_error "eighth"
+  with End_of_file -> nnet_format_error "eighth"
+
+(* Parse, in order, every metadata line of the NNet model read from
+   [in_channel]; stop at the first [Error]. *)
+let parse_metadata_channel filename in_channel =
+  let open Result in
+  handle_nnet_header filename in_channel >>= fun () ->
+  handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
+  handle_nnet_layer_sizes n_ls in_channel >>= fun layer_sizes ->
+  handle_nnet_unused_flag in_channel >>= fun () ->
+  handle_nnet_min_input_values n_is in_channel >>= fun min_input_values ->
+  handle_nnet_max_input_values n_is in_channel >>= fun max_input_values ->
+  handle_nnet_mean_values n_is in_channel >>= fun mean_values ->
+  handle_nnet_range_values n_is in_channel >>= fun range_values ->
+  Ok
+    {
+      n_layers = n_ls;
+      n_inputs = n_is;
+      n_outputs = n_os;
+      max_layer_size = max_l_size;
+      layer_sizes;
+      min_input_values;
+      max_input_values;
+      mean_values;
+      range_values;
+    }
+
+(* Retrieves [filename] NNet model metadata wrt NNet format specification (see
+   https://github.com/sisl/NNet for details.) The channel is closed on every
+   exit path, and I/O errors are returned as [Error] rather than raised. *)
+let parse_metadata filename =
+  match Stdlib.open_in filename with
+  | exception Sys_error msg ->
+    Error (Format.sprintf "Unexpected error: %s." msg)
+  | in_channel -> (
+    try
+      Stdlib.Fun.protect
+        ~finally:(fun () -> Stdlib.close_in_noerr in_channel)
+        (fun () -> parse_metadata_channel filename in_channel)
+    with Failure msg | Sys_error msg ->
+      Error (Format.sprintf "Unexpected error: %s." msg))
diff --git a/lib/nnet/nnet.mli b/lib/nnet/nnet.mli
new file mode 100644
index 0000000000000000000000000000000000000000..b718de82b47e212ad65103995899a915c8a4be7d
--- /dev/null
+++ b/lib/nnet/nnet.mli
@@ -0,0 +1,25 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of Caisar.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+type t = private {
+  n_layers : int;  (** Number of layers. *)
+  n_inputs : int;  (** Number of inputs. *)
+  n_outputs : int;  (** Number of outputs. *)
+  max_layer_size : int;  (** Maximum layer size. *)
+  layer_sizes : int list;  (** Size of each layer. *)
+  min_input_values : float list;  (** Minimum values of inputs. *)
+  max_input_values : float list;  (** Maximum values of inputs. *)
+  mean_values : float list * float;
+      (** Mean values of inputs and one value for all outputs. *)
+  range_values : float list * float;
+      (** Range values of inputs and one value for all outputs. *)
+}
+[@@deriving show { with_path = false }]
+(** NNet model metadata. *)
+
+val parse_metadata : string -> (t, string) Result.t
+(** Parse an NNet file for metadata. Returns [Error] with a description when
+    the file cannot be read or does not follow the NNet format. *)
diff --git a/nnet.opam b/nnet.opam
new file mode 100644
index 0000000000000000000000000000000000000000..af0a8be51a7c11d9b69ee38af2423681143a4537
--- /dev/null
+++ b/nnet.opam
@@ -0,0 +1,25 @@
+# This file is generated by dune, edit dune-project instead
+opam-version: "2.0"
+version: "0.1"
+synopsis: "NNet parser"
+depends: [
+  "ocaml" {>= "4.10"}
+  "dune" {>= "2.7" & >= "2.7.1"}
+  "base" {>= "v0.14.0"}
+  "ppx_deriving"
+  "odoc" {with-doc}
+]
+build: [
+  ["dune" "subst"] {dev}
+  [
+    "dune"
+    "build"
+    "-p"
+    name
+    "-j"
+    jobs
+    "@install"
+    "@runtest" {with-test}
+    "@doc" {with-doc}
+  ]
+]
diff --git a/src/dune b/src/dune
new file mode 100644
index 0000000000000000000000000000000000000000..0c6fcbae294537ccc076196fef95ca4b981204b5
--- /dev/null
+++ b/src/dune
@@ -0,0 +1,7 @@
+(executable
+ (name main)
+ (public_name caisar)
+ (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
+ (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
+ (package caisar)
+)
diff --git a/src/main.ml b/src/main.ml
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/standalone/dune b/standalone/dune
index 4ec3915ae58ca54911ecd937558a56e94d8a84aa..7eea48208caf40bedaa33bff153a7a2118b9286f 100644
--- a/standalone/dune
+++ b/standalone/dune
@@ -7,7 +7,9 @@
 
 (executable
  (name main)
- (public_name caisar)
+ (public_name caisar-standalone)
  (modules_without_implementation property_types)
  (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime)
- (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)))
+ (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
+ (package caisar)
+)