Skip to content
Snippets Groups Projects
Commit 2fc02270 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

[SAVer] Added SVM-specific OVO format parser.

parent 63618e75
No related branches found
No related tags found
No related merge requests found
...@@ -50,6 +50,16 @@ ...@@ -50,6 +50,16 @@
) )
) )
(package
(name ovo)
(synopsis "OVO parser")
(depends
(ocaml (>= 4.10))
(dune (>= 2.9.1))
(base (>= v0.14.0))
)
)
(package (package
(name onnx) (name onnx)
(synopsis "ONNX parser") (synopsis "ONNX parser")
......
(library
(name ovo)
(public_name ovo)
(libraries base csv)
(synopsis "OVO SVM format parser"))
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
open Base
module Format = Caml.Format
module Sys = Caml.Sys
module Filename = Caml.Filename
module Fun = Caml.Fun
type t = { n_inputs : int; n_outputs : int }
(* OVO format handling. *)
let ovo_format_error s =
Error (Format.sprintf "OVO format error: %s condition not satisfied." s)
(* Parse a single OVO format line: split line wrt CSV format, and convert each
string into a number by means of converter [f]. *)
let handle_ovo_line ~f in_channel =
List.filter_map
~f:(fun s -> try Some (f (String.strip s)) with _ -> None)
(Csv.next in_channel)
(* Skip the header part, ie comments, of the OVO format. *)
let skip_ovo_header filename in_channel =
let exception End_of_header in
let pos_in = ref (Stdlib.pos_in in_channel) in
try
while true do
let line = Stdlib.input_line in_channel in
if not (Str.string_match (Str.regexp "//") line 0)
then raise End_of_header
else pos_in := Stdlib.pos_in in_channel
done;
assert false
with
| End_of_header ->
(* At this point we have read one line past the header part: seek back. *)
Stdlib.seek_in in_channel !pos_in;
Ok ()
| End_of_file ->
Error (Format.sprintf "OVO model not found in file `%s'." filename)
(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_ovo_basic_info in_channel =
match handle_ovo_line ~f:Int.of_string in_channel with
| [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
Ok (n_layers, n_inputs, n_outputs, max_layer_size)
| _ -> ovo_format_error "second"
| exception End_of_file -> ovo_format_error "second"
(* Skip unused flag. *)
let handle_ovo_unused_flag in_channel =
try
let _ = Csv.next in_channel in
Ok ()
with End_of_file -> ovo_format_error "forth"
(* Retrieves [filename] OVO model metadata and weights wrt OVO format
specification, which is described here:
https://github.com/abstract-machine-learning/saver#classifier-format. *)
let parse_in_channel filename in_channel =
let open Result in
try
skip_ovo_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in
handle_ovo_basic_info in_channel >>= fun (_, n_is, n_os, _) ->
handle_ovo_unused_flag in_channel >>= fun () ->
Csv.close_in in_channel;
Ok { n_inputs = n_is; n_outputs = n_os }
with
| Csv.Failure (_nrecord, _nfield, msg) -> Error msg
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
let parse filename =
let in_channel = Stdlib.open_in filename in
Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel filename in_channel)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
type t = private {
n_inputs : int; (** Number of inputs. *)
n_outputs : int; (** Number of outputs. *)
}
(** OVO model metadata. *)
val parse : string -> (t, string) Result.t
(** Parse an OVO file. *)
ovo.opam 0 → 100644
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "0.1"
synopsis: "OVO parser"
depends: [
"ocaml" {>= "4.10"}
"dune" {>= "2.9" & >= "2.9.1"}
"base" {>= "v0.14.0"}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"--promote-install-files=false"
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
["dune" "install" "-p" name "--create-install-files" name]
]
(executable (executable
(name main) (name main)
(public_name caisar) (public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet onnx why3 dune-site re) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet onnx ovo why3 dune-site re)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar) (package caisar)
) )
......
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
open Why3 open Why3
open Base open Base
(* -- Support for the NNet and ONNX neural network formats. *) (* -- Support for the NNet and ONNX neural network formats, as well as SVM under
the .ovo format *)
type nnshape = { type nnshape = {
nb_inputs : int; nb_inputs : int;
...@@ -42,6 +43,30 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env = ...@@ -42,6 +43,30 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env =
in in
Wstdlib.Mstr.singleton "NNasTuple" (Pmodule.close_module th_uc) Wstdlib.Mstr.singleton "NNasTuple" (Pmodule.close_module th_uc)
let register_asarray nb_inputs nb_outputs filename env =
let open Why3 in
let net = Pmodule.read_module env [ "caisar" ] "IOShape" in
let ioshape_input_type =
Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) []
in
let id_as_array = Ident.id_fresh "AsArray" in
let th_uc = Pmodule.create_module env id_as_array in
let th_uc = Pmodule.use_export th_uc net in
let ls_svm_apply =
let f _ = ioshape_input_type in
Term.create_fsymbol
(Ident.id_fresh "svm_apply")
(List.init nb_inputs ~f)
(Ty.ty_tuple (List.init nb_outputs ~f))
in
Why3.Term.Hls.add loaded_nets ls_svm_apply
{ filename; nb_inputs; nb_outputs; ty_data = ioshape_input_type };
let th_uc =
Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_svm_apply))
in
Wstdlib.Mstr.singleton "AsArray" (Pmodule.close_module th_uc)
let nnet_parser env _ filename _ = let nnet_parser env _ filename _ =
let model = Nnet.parse filename in let model = Nnet.parse filename in
match model with match model with
...@@ -54,6 +79,13 @@ let onnx_parser env _ filename _ = ...@@ -54,6 +79,13 @@ let onnx_parser env _ filename _ =
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env
let ovo_parser env _ filename _ =
let open Why3 in
let model = Ovo.parse filename in
match model with
| Error s -> Loc.errorm "%s" s
| Ok model -> register_asarray model.n_inputs model.n_outputs filename env
let register_nnet_support () = let register_nnet_support () =
Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
"NNet" [ "nnet" ] nnet_parser "NNet" [ "nnet" ] nnet_parser
...@@ -61,3 +93,7 @@ let register_nnet_support () = ...@@ -61,3 +93,7 @@ let register_nnet_support () =
let register_onnx_support () = let register_onnx_support () =
Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ] Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX" [ "onnx" ]
onnx_parser onnx_parser
let register_ovo_support () =
Env.register_format ~desc:"OVO format" Pmodule.mlw_language "OVO" [ "ovo" ]
ovo_parser
...@@ -21,3 +21,6 @@ val register_nnet_support : unit -> unit ...@@ -21,3 +21,6 @@ val register_nnet_support : unit -> unit
val register_onnx_support : unit -> unit val register_onnx_support : unit -> unit
(** Register ONNX parser. *) (** Register ONNX parser. *)
val register_ovo_support : unit -> unit
(** Register OVO parser. *)
...@@ -25,7 +25,8 @@ end ...@@ -25,7 +25,8 @@ end
let () = let () =
Language.register_nnet_support (); Language.register_nnet_support ();
Language.register_onnx_support () Language.register_onnx_support ();
Language.register_ovo_support ()
let create_env loadpath = let create_env loadpath =
let config = Autodetect.autodetection ~debug:true () in let config = Autodetect.autodetection ~debug:true () in
......
// Generated using the base example shown at https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html
ovo 2 2
rbf auto
2 2
1 2
-0.6592196433465484 -0.47786540839731506 0.6612929735049311 0.4757920782389323
-0.6324555320336759 -1.0
-1.2649110640673518 -1.0
0.6324555320336759 1.0
1.2649110640673518 1.0
-2.866041953212284e-05
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment