Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
caisar
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
pub
caisar
Commits
84025c4b
Commit
84025c4b
authored
2 years ago
by
Michele Alberti
Browse files
Options
Downloads
Patches
Plain Diff
[nnet] Admit parsing failure on non-mandatory information.
parent
3c965d10
No related branches found
No related tags found
No related merge requests found
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
lib/nnet/nnet.ml
+19
-12
19 additions, 12 deletions
lib/nnet/nnet.ml
lib/nnet/nnet.mli
+10
-6
10 additions, 6 deletions
lib/nnet/nnet.mli
lib/onnx/onnx.mli
+1
-2
1 addition, 2 deletions
lib/onnx/onnx.mli
src/language.ml
+1
-1
1 addition, 1 deletion
src/language.ml
with
31 additions
and
21 deletions
lib/nnet/nnet.ml
+
19
−
12
View file @
84025c4b
...
@@ -32,10 +32,10 @@ type t = {
...
@@ -32,10 +32,10 @@ type t = {
n_outputs
:
int
;
n_outputs
:
int
;
max_layer_size
:
int
;
max_layer_size
:
int
;
layer_sizes
:
int
list
;
layer_sizes
:
int
list
;
min_input_values
:
float
list
;
min_input_values
:
float
list
option
;
max_input_values
:
float
list
;
max_input_values
:
float
list
option
;
mean_values
:
float
list
*
float
;
mean_values
:
(
float
list
*
float
)
option
;
range_values
:
float
list
*
float
;
range_values
:
(
float
list
*
float
)
option
;
weights_biases
:
float
list
list
;
weights_biases
:
float
list
list
;
}
}
...
@@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel =
...
@@ -154,18 +154,25 @@ let handle_nnet_weights_and_biases in_channel =
(* Retrieves [filename] NNet model metadata and weights wrt NNet format
(* Retrieves [filename] NNet model metadata and weights wrt NNet format
specification (see https://github.com/sisl/NNet for details). *)
specification (see https://github.com/sisl/NNet for details). *)
let
parse_in_channel
filename
in_channel
=
let
parse_in_channel
?
(
permissive
=
false
)
filename
in_channel
=
let
open
Result
in
let
open
Result
in
let
ok_opt
r
=
match
r
with
|
Ok
x
->
Ok
(
Some
x
)
|
Error
_
as
error
->
if
not
permissive
then
error
else
Ok
None
in
try
try
skip_nnet_header
filename
in_channel
>>=
fun
()
->
skip_nnet_header
filename
in_channel
>>=
fun
()
->
let
in_channel
=
Csv
.
of_channel
in_channel
in
let
in_channel
=
Csv
.
of_channel
in_channel
in
handle_nnet_basic_info
in_channel
>>=
fun
(
n_ls
,
n_is
,
n_os
,
max_l_size
)
->
handle_nnet_basic_info
in_channel
>>=
fun
(
n_ls
,
n_is
,
n_os
,
max_l_size
)
->
handle_nnet_layer_sizes
n_ls
in_channel
>>=
fun
layer_sizes
->
handle_nnet_layer_sizes
n_ls
in_channel
>>=
fun
layer_sizes
->
handle_nnet_unused_flag
in_channel
>>=
fun
()
->
handle_nnet_unused_flag
in_channel
>>=
fun
()
->
handle_nnet_min_input_values
n_is
in_channel
>>=
fun
min_input_values
->
ok_opt
(
handle_nnet_min_input_values
n_is
in_channel
)
handle_nnet_max_input_values
n_is
in_channel
>>=
fun
max_input_values
->
>>=
fun
min_input_values
->
handle_nnet_mean_values
n_is
in_channel
>>=
fun
mean_values
->
ok_opt
(
handle_nnet_max_input_values
n_is
in_channel
)
handle_nnet_range_values
n_is
in_channel
>>=
fun
range_values
->
>>=
fun
max_input_values
->
ok_opt
(
handle_nnet_mean_values
n_is
in_channel
)
>>=
fun
mean_values
->
ok_opt
(
handle_nnet_range_values
n_is
in_channel
)
>>=
fun
range_values
->
let
weights_biases
=
handle_nnet_weights_and_biases
in_channel
in
let
weights_biases
=
handle_nnet_weights_and_biases
in_channel
in
Csv
.
close_in
in_channel
;
Csv
.
close_in
in_channel
;
Ok
Ok
...
@@ -184,10 +191,10 @@ let parse_in_channel filename in_channel =
...
@@ -184,10 +191,10 @@ let parse_in_channel filename in_channel =
with
with
|
Csv
.
Failure
(
_nrecord
,
_nfield
,
msg
)
->
Error
msg
|
Csv
.
Failure
(
_nrecord
,
_nfield
,
msg
)
->
Error
msg
|
Sys_error
s
->
Error
s
|
Sys_error
s
->
Error
s
|
Failure
msg
->
Error
(
Format
.
sprintf
"Unexpected error: %s
.
"
msg
)
|
Failure
msg
->
Error
(
Format
.
sprintf
"Unexpected error: %s"
msg
)
let
parse
filename
=
let
parse
?
(
permissive
=
false
)
filename
=
let
in_channel
=
Stdlib
.
open_in
filename
in
let
in_channel
=
Stdlib
.
open_in
filename
in
Fun
.
protect
Fun
.
protect
~
finally
:
(
fun
()
->
Stdlib
.
close_in
in_channel
)
~
finally
:
(
fun
()
->
Stdlib
.
close_in
in_channel
)
(
fun
()
->
parse_in_channel
filename
in_channel
)
(
fun
()
->
parse_in_channel
~
permissive
filename
in_channel
)
This diff is collapsed.
Click to expand it.
lib/nnet/nnet.mli
+
10
−
6
View file @
84025c4b
...
@@ -26,15 +26,19 @@ type t = private {
...
@@ -26,15 +26,19 @@ type t = private {
n_outputs
:
int
;
(** Number of outputs. *)
n_outputs
:
int
;
(** Number of outputs. *)
max_layer_size
:
int
;
(** Maximum layer size. *)
max_layer_size
:
int
;
(** Maximum layer size. *)
layer_sizes
:
int
list
;
(** Size of each layer. *)
layer_sizes
:
int
list
;
(** Size of each layer. *)
min_input_values
:
float
list
;
(** Minimum values of inputs. *)
min_input_values
:
float
list
option
;
(** Minimum values of inputs. *)
max_input_values
:
float
list
;
(** Maximum values of inputs. *)
max_input_values
:
float
list
option
;
(** Maximum values of inputs. *)
mean_values
:
float
list
*
float
;
mean_values
:
(
float
list
*
float
)
option
;
(** Mean values of inputs and one value for all outputs. *)
(** Mean values of inputs and one value for all outputs. *)
range_values
:
float
list
*
float
;
range_values
:
(
float
list
*
float
)
option
;
(** Range values of inputs and one value for all outputs. *)
(** Range values of inputs and one value for all outputs. *)
weights_biases
:
float
list
list
;
(** All weights and biases of NNet model. *)
weights_biases
:
float
list
list
;
(** All weights and biases of NNet model. *)
}
}
(** NNet model metadata. *)
(** NNet model metadata. *)
val
parse
:
string
->
(
t
,
string
)
Result
.
t
val
parse
:
?
permissive
:
bool
->
string
->
(
t
,
string
)
Result
.
t
(** Parse an NNet file. *)
(** Parse an NNet file.
@param permissive
[false] by default. When set, parsing does not fail on non available
information, which are set to [None] instead. *)
This diff is collapsed.
Click to expand it.
lib/onnx/onnx.mli
+
1
−
2
View file @
84025c4b
...
@@ -28,7 +28,6 @@ type t = private {
...
@@ -28,7 +28,6 @@ type t = private {
}
}
(** ONNX model metadata. *)
(** ONNX model metadata. *)
val
parse
:
string
->
(
t
*
G
.
t
,
string
)
Result
.
t
(** Parse an ONNX file to get metadata for CAISAR as well as its inner
(** Parse an ONNX file to get metadata for CAISAR as well as its inner
intermediate representation for the network. *)
intermediate representation for the network. *)
val
parse
:
string
->
(
t
*
G
.
t
,
string
)
Result
.
t
This diff is collapsed.
Click to expand it.
src/language.ml
+
1
−
1
View file @
84025c4b
...
@@ -84,7 +84,7 @@ let register_svm_as_array nb_inputs nb_classes filename env =
...
@@ -84,7 +84,7 @@ let register_svm_as_array nb_inputs nb_classes filename env =
Wstdlib
.
Mstr
.
singleton
"SVMasArray"
(
Pmodule
.
close_module
th_uc
)
Wstdlib
.
Mstr
.
singleton
"SVMasArray"
(
Pmodule
.
close_module
th_uc
)
let
nnet_parser
env
_
filename
_
=
let
nnet_parser
env
_
filename
_
=
let
model
=
Nnet
.
parse
filename
in
let
model
=
Nnet
.
parse
~
permissive
:
true
filename
in
match
model
with
match
model
with
|
Error
s
->
Loc
.
errorm
"%s"
s
|
Error
s
->
Loc
.
errorm
"%s"
s
|
Ok
model
->
|
Ok
model
->
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment