Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
caisar
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pub
caisar
Commits
377c194e
Commit
377c194e
authored
3 years ago
by
Michele Alberti
Browse files
Options
Downloads
Patches
Plain Diff
Use csv library to parse nnet model wrt CSV format.
parent
9aceca1a
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
lib/nnet/dune
+1
-1
1 addition, 1 deletion
lib/nnet/dune
lib/nnet/nnet.ml
+43
-33
43 additions, 33 deletions
lib/nnet/nnet.ml
lib/nnet/nnet.mli
+3
-3
3 additions, 3 deletions
lib/nnet/nnet.mli
src/dune
+2
-2
2 additions, 2 deletions
src/dune
with
49 additions
and
39 deletions
lib/nnet/dune
+
1
−
1
View file @
377c194e
(library
(library
(name nnet)
(name nnet)
(public_name nnet)
(public_name nnet)
(libraries base)
(libraries base
csv
)
(synopsis "NNet parser"))
(synopsis "NNet parser"))
This diff is collapsed.
Click to expand it.
lib/nnet/nnet.ml
+
43
−
33
View file @
377c194e
...
@@ -19,25 +19,23 @@ type t = {
...
@@ -19,25 +19,23 @@ type t = {
max_input_values
:
float
list
;
max_input_values
:
float
list
;
mean_values
:
float
list
*
float
;
mean_values
:
float
list
*
float
;
range_values
:
float
list
*
float
;
range_values
:
float
list
*
float
;
weights_biases
:
float
list
list
;
}
}
[
@@
deriving
show
{
with_path
=
false
}]
(* NNet format handling. *)
(* NNet format handling. *)
let
nnet_format_error
s
=
let
nnet_format_error
s
=
Error
(
Format
.
sprintf
"NNet format error: %s condition not satisfied."
s
)
Error
(
Format
.
sprintf
"NNet format error: %s condition not satisfied."
s
)
let
nnet_delimiter
=
Str
.
regexp
","
(* Parse a single NNet format line: split line wrt CSV format, and convert each
string into a number by means of converter [f]. *)
(* Parse a single NNet format line: split line using [nnet_delimiter] as
let
handle_nnet_line
~
f
in_channel
=
delimiter, and convert each string into a number by means of converter [f]. *)
let
handle_nnet_line
~
f
line
=
List
.
filter_map
List
.
filter_map
~
f
:
(
fun
s
->
try
Some
(
f
(
String
.
strip
s
))
with
_
->
None
)
~
f
:
(
fun
s
->
try
Some
(
f
(
String
.
strip
s
))
with
_
->
None
)
(
Str
.
split
nnet_delimiter
line
)
(
Csv
.
next
in_channel
)
(* Skip the header part, ie comments, of the NNet format. *)
(* Skip the header part, ie comments, of the NNet format. *)
let
handle
_nnet_header
filename
in_channel
=
let
skip
_nnet_header
filename
in_channel
=
let
exception
End_of_header
in
let
exception
End_of_header
in
let
pos_in
=
ref
(
Stdlib
.
pos_in
in_channel
)
in
let
pos_in
=
ref
(
Stdlib
.
pos_in
in_channel
)
in
try
try
...
@@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel =
...
@@ -58,19 +56,16 @@ let handle_nnet_header filename in_channel =
(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let
handle_nnet_basic_info
in_channel
=
let
handle_nnet_basic_info
in_channel
=
try
match
handle_nnet_line
~
f
:
Int
.
of_string
in_channel
with
let
line
=
Stdlib
.
input_line
in_channel
in
|
[
n_layers
;
n_inputs
;
n_outputs
;
max_layer_size
]
->
match
handle_nnet_line
~
f
:
Stdlib
.
int_of_string
line
with
Ok
(
n_layers
,
n_inputs
,
n_outputs
,
max_layer_size
)
|
[
n_layers
;
n_inputs
;
n_outputs
;
max_layer_size
]
->
|
_
->
nnet_format_error
"second"
Ok
(
n_layers
,
n_inputs
,
n_outputs
,
max_layer_size
)
|
exception
End_of_file
->
nnet_format_error
"second"
|
_
->
nnet_format_error
"second"
with
End_of_file
->
nnet_format_error
"second"
(* Retrieve size of each layer, including inputs and outputs. *)
(* Retrieve size of each layer, including inputs and outputs. *)
let
handle_nnet_layer_sizes
n_layers
in_channel
=
let
handle_nnet_layer_sizes
n_layers
in_channel
=
try
try
let
line
=
Stdlib
.
input_line
in_channel
in
let
layer_sizes
=
handle_nnet_line
~
f
:
Int
.
of_string
in_channel
in
let
layer_sizes
=
handle_nnet_line
~
f
:
Stdlib
.
int_of_string
line
in
if
List
.
length
layer_sizes
=
n_layers
+
1
then
Ok
layer_sizes
if
List
.
length
layer_sizes
=
n_layers
+
1
then
Ok
layer_sizes
else
nnet_format_error
"third"
else
nnet_format_error
"third"
with
End_of_file
->
nnet_format_error
"third"
with
End_of_file
->
nnet_format_error
"third"
...
@@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel =
...
@@ -78,15 +73,14 @@ let handle_nnet_layer_sizes n_layers in_channel =
(* Skip unused flag. *)
(* Skip unused flag. *)
let
handle_nnet_unused_flag
in_channel
=
let
handle_nnet_unused_flag
in_channel
=
try
try
let
_
=
Stdlib
.
input_line
in_channel
in
let
_
=
Csv
.
next
in_channel
in
Ok
()
Ok
()
with
End_of_file
->
nnet_format_error
"forth"
with
End_of_file
->
nnet_format_error
"forth"
(* Retrive minimum values of inputs. *)
(* Retrive minimum values of inputs. *)
let
handle_nnet_min_input_values
n_inputs
in_channel
=
let
handle_nnet_min_input_values
n_inputs
in_channel
=
try
try
let
line
=
Stdlib
.
input_line
in_channel
in
let
min_input_values
=
handle_nnet_line
~
f
:
Float
.
of_string
in_channel
in
let
min_input_values
=
handle_nnet_line
~
f
:
Stdlib
.
float_of_string
line
in
if
List
.
length
min_input_values
=
n_inputs
then
Ok
min_input_values
if
List
.
length
min_input_values
=
n_inputs
then
Ok
min_input_values
else
nnet_format_error
"fifth"
else
nnet_format_error
"fifth"
with
End_of_file
->
nnet_format_error
"fifth"
with
End_of_file
->
nnet_format_error
"fifth"
...
@@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel =
...
@@ -94,8 +88,7 @@ let handle_nnet_min_input_values n_inputs in_channel =
(* Retrive maximum values of inputs. *)
(* Retrive maximum values of inputs. *)
let
handle_nnet_max_input_values
n_inputs
in_channel
=
let
handle_nnet_max_input_values
n_inputs
in_channel
=
try
try
let
line
=
Stdlib
.
input_line
in_channel
in
let
max_input_values
=
handle_nnet_line
~
f
:
Float
.
of_string
in_channel
in
let
max_input_values
=
handle_nnet_line
~
f
:
Stdlib
.
float_of_string
line
in
if
List
.
length
max_input_values
=
n_inputs
then
Ok
max_input_values
if
List
.
length
max_input_values
=
n_inputs
then
Ok
max_input_values
else
nnet_format_error
"sixth"
else
nnet_format_error
"sixth"
with
End_of_file
->
nnet_format_error
"sixth"
with
End_of_file
->
nnet_format_error
"sixth"
...
@@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel =
...
@@ -103,8 +96,7 @@ let handle_nnet_max_input_values n_inputs in_channel =
(* Retrieve mean values of inputs and one value for all outputs. *)
(* Retrieve mean values of inputs and one value for all outputs. *)
let
handle_nnet_mean_values
n_inputs
in_channel
=
let
handle_nnet_mean_values
n_inputs
in_channel
=
try
try
let
line
=
Stdlib
.
input_line
in_channel
in
let
mean_values
=
handle_nnet_line
~
f
:
Float
.
of_string
in_channel
in
let
mean_values
=
handle_nnet_line
~
f
:
Stdlib
.
float_of_string
line
in
if
List
.
length
mean_values
=
n_inputs
+
1
then
if
List
.
length
mean_values
=
n_inputs
+
1
then
let
mean_input_values
,
mean_output_value
=
let
mean_input_values
,
mean_output_value
=
List
.
split_n
mean_values
n_inputs
List
.
split_n
mean_values
n_inputs
...
@@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel =
...
@@ -116,8 +108,7 @@ let handle_nnet_mean_values n_inputs in_channel =
(* Retrieve range values of inputs and one value for all outputs. *)
(* Retrieve range values of inputs and one value for all outputs. *)
let
handle_nnet_range_values
n_inputs
in_channel
=
let
handle_nnet_range_values
n_inputs
in_channel
=
try
try
let
line
=
Stdlib
.
input_line
in_channel
in
let
range_values
=
handle_nnet_line
~
f
:
Float
.
of_string
in_channel
in
let
range_values
=
handle_nnet_line
~
f
:
Stdlib
.
float_of_string
line
in
if
List
.
length
range_values
=
n_inputs
+
1
then
if
List
.
length
range_values
=
n_inputs
+
1
then
let
range_input_values
,
range_output_value
=
let
range_input_values
,
range_output_value
=
List
.
split_n
range_values
n_inputs
List
.
split_n
range_values
n_inputs
...
@@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel =
...
@@ -126,13 +117,27 @@ let handle_nnet_range_values n_inputs in_channel =
else
nnet_format_error
"eighth"
else
nnet_format_error
"eighth"
with
End_of_file
->
nnet_format_error
"eighth"
with
End_of_file
->
nnet_format_error
"eighth"
(* Retrieves [filename] NNet model metadata wrt NNet format specification (see
(* Retrieve all layer weights and biases as appearing in the model. No special
https://github.com/sisl/NNet for details.) *)
treatment is performed. *)
let
parse_metadata
filename
=
let
handle_nnet_weights_and_biases
in_channel
=
List
.
rev
(
Csv
.
fold_left
~
init
:
[]
~
f
:
(
fun
fll
sl
->
List
.
filter_map
~
f
:
(
fun
s
->
try
Some
(
Float
.
of_string
(
String
.
strip
s
))
with
_
->
None
)
sl
::
fll
)
in_channel
)
(* Retrieves [filename] NNet model metadata and weights wrt NNet format
specification (see https://github.com/sisl/NNet for details). *)
let
parse
filename
=
let
open
Result
in
let
open
Result
in
let
in_channel
=
Stdlib
.
open_in
filename
in
try
try
handle_nnet_header
filename
in_channel
>>=
fun
()
->
let
in_channel
=
Stdlib
.
open_in
filename
in
skip_nnet_header
filename
in_channel
>>=
fun
()
->
let
in_channel
=
Csv
.
of_channel
in_channel
in
handle_nnet_basic_info
in_channel
>>=
fun
(
n_ls
,
n_is
,
n_os
,
max_l_size
)
->
handle_nnet_basic_info
in_channel
>>=
fun
(
n_ls
,
n_is
,
n_os
,
max_l_size
)
->
handle_nnet_layer_sizes
n_ls
in_channel
>>=
fun
layer_sizes
->
handle_nnet_layer_sizes
n_ls
in_channel
>>=
fun
layer_sizes
->
handle_nnet_unused_flag
in_channel
>>=
fun
()
->
handle_nnet_unused_flag
in_channel
>>=
fun
()
->
...
@@ -140,7 +145,8 @@ let parse_metadata filename =
...
@@ -140,7 +145,8 @@ let parse_metadata filename =
handle_nnet_max_input_values
n_is
in_channel
>>=
fun
max_input_values
->
handle_nnet_max_input_values
n_is
in_channel
>>=
fun
max_input_values
->
handle_nnet_mean_values
n_is
in_channel
>>=
fun
mean_values
->
handle_nnet_mean_values
n_is
in_channel
>>=
fun
mean_values
->
handle_nnet_range_values
n_is
in_channel
>>=
fun
range_values
->
handle_nnet_range_values
n_is
in_channel
>>=
fun
range_values
->
Stdlib
.
close_in
in_channel
;
let
weights_biases
=
handle_nnet_weights_and_biases
in_channel
in
Csv
.
close_in
in_channel
;
Ok
Ok
{
{
n_layers
=
n_ls
;
n_layers
=
n_ls
;
...
@@ -152,5 +158,9 @@ let parse_metadata filename =
...
@@ -152,5 +158,9 @@ let parse_metadata filename =
max_input_values
;
max_input_values
;
mean_values
;
mean_values
;
range_values
;
range_values
;
weights_biases
;
}
}
with
Failure
msg
->
Error
(
Format
.
sprintf
"Unexpected error: %s."
msg
)
with
|
Csv
.
Failure
(
_nrecord
,
_nfield
,
msg
)
->
Error
msg
|
Sys_error
s
->
Error
s
|
Failure
msg
->
Error
(
Format
.
sprintf
"Unexpected error: %s."
msg
)
This diff is collapsed.
Click to expand it.
lib/nnet/nnet.mli
+
3
−
3
View file @
377c194e
...
@@ -16,9 +16,9 @@ type t = private {
...
@@ -16,9 +16,9 @@ type t = private {
(** Mean values of inputs and one value for all outputs. *)
(** Mean values of inputs and one value for all outputs. *)
range_values
:
float
list
*
float
;
range_values
:
float
list
*
float
;
(** Range values of inputs and one value for all outputs. *)
(** Range values of inputs and one value for all outputs. *)
weights_biases
:
float
list
list
;
(** All weights and biases of NNet model. *)
}
}
[
@@
deriving
show
{
with_path
=
false
}]
(** NNet model metadata. *)
(** NNet model metadata. *)
val
parse
_metadata
:
string
->
(
t
,
string
)
Result
.
t
val
parse
:
string
->
(
t
,
string
)
Result
.
t
(** Parse an NNet file
for metadata
. *)
(** Parse an NNet file. *)
This diff is collapsed.
Click to expand it.
src/dune
+
2
−
2
View file @
377c194e
...
@@ -2,6 +2,6 @@
...
@@ -2,6 +2,6 @@
(name main)
(name main)
(public_name caisar)
(public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar)
(package caisar)
)
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment