Skip to content
Snippets Groups Projects
Commit 39291caf authored by Julien Girard-Satabin's avatar Julien Girard-Satabin
Browse files

Merge branch 'onnx_parser' into 'feature/why3'

Onnx parser

See merge request laiser/caisar!8
parents 09f3df5c 07df3fc5
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ tests: ...@@ -12,6 +12,7 @@ tests:
- if [ ! -d _opam ]; then echo "no local switch in the CI cache, we setup a new switch"; opam switch create --yes --no-install . ocaml-base-compiler.4.11.1; fi - if [ ! -d _opam ]; then echo "no local switch in the CI cache, we setup a new switch"; opam switch create --yes --no-install . ocaml-base-compiler.4.11.1; fi
- eval $(opam env) - eval $(opam env)
- sudo apt-get update - sudo apt-get update
- sudo apt install -y protobuf-compiler
- opam repository add remote https://opam.ocaml.org - opam repository add remote https://opam.ocaml.org
- opam depext --yes ocplib-endian base fmt alt-ergo.2.4.0 - opam depext --yes ocplib-endian base fmt alt-ergo.2.4.0
- opam install . --deps-only --with-test --yes - opam install . --deps-only --with-test --yes
......
all: all:
dune build --root=. @install caisar.opam nnet.opam dune build --root=. @install caisar.opam nnet.opam onnx.opam
test: test:
dune runtest --root=. dune runtest --root=.
......
...@@ -4,10 +4,14 @@ version: "0.1" ...@@ -4,10 +4,14 @@ version: "0.1"
synopsis: "Framework for neural network verification" synopsis: "Framework for neural network verification"
depends: [ depends: [
"ocaml" {>= "4.10"} "ocaml" {>= "4.10"}
"dune" {>= "2.9" & >= "2.9.0"}
"dune-site" {>= "2.9.0"} "dune-site" {>= "2.9.0"}
"why3" "piqi" {>= "0.7.6"}
"piqilib" {>= "0.6.14"}
"zarith" {>= "1.7"}
"ocplib-endian" {>= "1.0"}
"dune" {>= "2.9" & >= "2.7.1"}
"base" {>= "v0.14.0"} "base" {>= "v0.14.0"}
"stdio" {>= "v0.14.0"}
"cmdliner" {>= "1.0.4"} "cmdliner" {>= "1.0.4"}
"fmt" {>= "0.8.9"} "fmt" {>= "0.8.9"}
"logs" {>= "0.7.0"} "logs" {>= "0.7.0"}
...@@ -18,10 +22,11 @@ depends: [ ...@@ -18,10 +22,11 @@ depends: [
"csv" {>= "2.4"} "csv" {>= "2.4"}
"why3" {>= "1.4"} "why3" {>= "1.4"}
"re" "re"
"onnx"
"odoc" {with-doc} "odoc" {with-doc}
] ]
build: [ build: [
["dune" "subst" "--root" "."] {dev} ["dune" "subst"] {dev}
[ [
"dune" "dune"
"build" "build"
...@@ -29,8 +34,7 @@ build: [ ...@@ -29,8 +34,7 @@ build: [
name name
"-j" "-j"
jobs jobs
"--promote-install-files" "--promote-install-files=false"
"false"
"@install" "@install"
"@runtest" {with-test} "@runtest" {with-test}
"@doc" {with-doc} "@doc" {with-doc}
......
...@@ -14,10 +14,14 @@ ...@@ -14,10 +14,14 @@
(synopsis "Framework for neural network verification") (synopsis "Framework for neural network verification")
(depends (depends
(ocaml (>= 4.10)) (ocaml (>= 4.10))
(dune (>= 2.9.0))
(dune-site (>= 2.9.0)) (dune-site (>= 2.9.0))
why3 (piqi (>= 0.7.6))
(piqilib (>= 0.6.14))
(zarith (>= 1.7))
(ocplib-endian (>= 1.0))
(dune (>= 2.7.1))
(base (>= v0.14.0)) (base (>= v0.14.0))
(stdio (>= v0.14.0))
(cmdliner (>= 1.0.4)) (cmdliner (>= 1.0.4))
(fmt (>= 0.8.9)) (fmt (>= 0.8.9))
(logs (>= 0.7.0)) (logs (>= 0.7.0))
...@@ -28,6 +32,7 @@ ...@@ -28,6 +32,7 @@
(csv (>= 2.4)) (csv (>= 2.4))
(why3 (>= 1.4)) (why3 (>= 1.4))
re re
onnx
) )
(sites (sites
(share stdlib) (share stdlib)
...@@ -40,7 +45,18 @@ ...@@ -40,7 +45,18 @@
(synopsis "NNet parser") (synopsis "NNet parser")
(depends (depends
(ocaml (>= 4.10)) (ocaml (>= 4.10))
(dune (>= 2.7.1)) (dune (>= 2.9.1))
(base (>= v0.14.0))
)
)
(package
(name onnx)
(synopsis "ONNX parser")
(depends
(ocaml (>= 4.10))
(dune (>= 2.9.1))
(base (>= v0.14.0)) (base (>= v0.14.0))
(ocaml-protoc-plugin (= 4.2.0))
) )
) )
(library
(name onnx)
(public_name onnx)
(libraries base stdio piqirun.pb ocaml-protoc-plugin)
(synopsis "ONNX parser"))
(rule
(targets onnx_protoc.ml)
(action
(run ./generate_onnx_interface.sh)
)
(deps
onnx_protoc.proto
generate_onnx_interface.sh)
)
#!/bin/sh
protoc --ocaml_out=. onnx_protoc.proto
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
open Base
module Format = Caml.Format
module Fun = Caml.Fun
module Oproto = Onnx_protoc (* Autogenerated during compilation *)
module Oprotom = Oproto.Onnx.ModelProto
type t = {
n_inputs : int; (* Number of inputs. *)
n_outputs : int; (* Number of outputs. *)
}
(* ONNX format handling. *)
let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) =
match List.nth s 0 with
| Some { type' = Some { value = `Tensor_type { shape = Some v; _ }; _ }; _ }
->
v
| _ -> []
let flattened_dim (dim : Oproto.Onnx.TensorShapeProto.Dimension.t list) =
List.fold ~init:1 dim ~f:(fun acc x ->
match x.value with
| `Dim_value v -> acc * v
| `Dim_param _ -> acc
| `not_set -> acc)
let get_input_output_dim (model : Oprotom.t) =
let ins, outs =
match model.graph with
| Some g -> (Some g.input, Some g.output)
| None -> (None, None)
in
let input_shape, output_shape =
match (ins, outs) with
| Some i, Some o -> (get_nested_dims i, get_nested_dims o)
| _ -> ([], [])
in
(* TODO: here we only get the flattened dimension of inputs and outputs, but
more interesting parsing could be done later on. *)
let input_flat_dim = flattened_dim input_shape in
let output_flat_dim = flattened_dim output_shape in
(input_flat_dim, output_flat_dim)
let parse_in_channel in_channel =
let open Result in
try
let buf = Stdio.In_channel.input_all in_channel in
let reader = Ocaml_protoc_plugin.Reader.create buf in
match Oprotom.from_proto reader with
| Ok r ->
let n_inputs, n_outputs = get_input_output_dim r in
Ok { n_inputs; n_outputs }
| _ -> Error "Error parsing protobuf"
with
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
let parse filename =
let in_channel = Stdlib.open_in filename in
Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel in_channel)
(**************************************************************************)
(* *)
(* This file is part of CAISAR. *)
(* *)
(**************************************************************************)
type t = private {
n_inputs : int; (** Number of inputs. *)
n_outputs : int; (** Number of outputs. *)
}
(** ONNX model metadata. *)
val parse : string -> (t, string) Result.t
(** Parse an ONNX file. *)
This diff is collapsed.
...@@ -4,12 +4,12 @@ version: "0.1" ...@@ -4,12 +4,12 @@ version: "0.1"
synopsis: "NNet parser" synopsis: "NNet parser"
depends: [ depends: [
"ocaml" {>= "4.10"} "ocaml" {>= "4.10"}
"dune" {>= "2.9" & >= "2.7.1"} "dune" {>= "2.9" & >= "2.9.1"}
"base" {>= "v0.14.0"} "base" {>= "v0.14.0"}
"odoc" {with-doc} "odoc" {with-doc}
] ]
build: [ build: [
["dune" "subst" "--root" "."] {dev} ["dune" "subst"] {dev}
[ [
"dune" "dune"
"build" "build"
...@@ -17,8 +17,7 @@ build: [ ...@@ -17,8 +17,7 @@ build: [
name name
"-j" "-j"
jobs jobs
"--promote-install-files" "--promote-install-files=false"
"false"
"@install" "@install"
"@runtest" {with-test} "@runtest" {with-test}
"@doc" {with-doc} "@doc" {with-doc}
......
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "0.1"
synopsis: "ONNX parser"
depends: [
"ocaml" {>= "4.10"}
"dune" {>= "2.9" & >= "2.9.1"}
"base" {>= "v0.14.0"}
"ocaml-protoc-plugin" {= "4.2.0"}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
[
"dune"
"build"
"-p"
name
"-j"
jobs
"--promote-install-files=false"
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
["dune" "install" "-p" name "--create-install-files" name]
]
(executable (executable
(name main) (name main)
(public_name caisar) (public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3 dune-site re) (libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet onnx why3 dune-site re)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq)) (preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar) (package caisar)
) )
......
...@@ -6,55 +6,63 @@ ...@@ -6,55 +6,63 @@
open Base open Base
(* -- Support for the NNet neural network format. *) (* -- Support for the NNet and ONNX neural network formats. *)
type nnet = { type ioshape = {
nb_inputs : int; nb_inputs : int;
nb_outputs : int; nb_outputs : int;
ty_data : Why3.Ty.ty; ty_data : Why3.Ty.ty;
filename : string; filename : string;
} }
let loaded_nnets = Why3.Term.Hls.create 10 let loaded_nets = Why3.Term.Hls.create 10
let lookup_loaded_nnets = Why3.Term.Hls.find_opt loaded_nnets let lookup_loaded_nets = Why3.Term.Hls.find_opt loaded_nets
let register_astuple nb_inputs nb_outputs filename env =
let open Why3 in
let net = Pmodule.read_module env [ "caisar" ] "IOShape" in
let ioshape_input_type =
Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) []
in
let id_as_tuple = Ident.id_fresh "AsTuple" in
let th_uc = Pmodule.create_module env id_as_tuple in
let th_uc = Pmodule.use_export th_uc net in
let ls_net_apply =
let f _ = ioshape_input_type in
Term.create_fsymbol
(Ident.id_fresh "net_apply")
(List.init nb_inputs ~f)
(Ty.ty_tuple (List.init nb_outputs ~f))
in
Why3.Term.Hls.add loaded_nets ls_net_apply
{ filename; nb_inputs; nb_outputs; ty_data = ioshape_input_type };
let th_uc =
Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply))
in
Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc)
let nnet_parser env _ filename _ = let nnet_parser env _ filename _ =
let open Why3 in let open Why3 in
let header = Nnet.parse filename in let model = Nnet.parse filename in
match header with match model with
| Error s -> Loc.errorm "%s" s | Error s -> Loc.errorm "%s" s
| Ok header -> | Ok model -> register_astuple model.n_inputs model.n_outputs filename env
let nnet = Pmodule.read_module env [ "caisar" ] "NNet" in
let nnet_input_type = let onnx_parser env _ filename _ =
Ty.ty_app let open Why3 in
Theory.(ns_find_ts nnet.mod_theory.th_export [ "input_type" ]) let model = Onnx.parse filename in
[] match model with
in | Error s -> Loc.errorm "%s" s
let id_as_tuple = Ident.id_fresh "AsTuple" in | Ok model -> register_astuple model.n_inputs model.n_outputs filename env
let th_uc = Pmodule.create_module env id_as_tuple in
let th_uc = Pmodule.use_export th_uc nnet in
let ls_nnet_apply =
let f _ = nnet_input_type in
Term.create_fsymbol
(Ident.id_fresh "nnet_apply")
(List.init header.n_inputs ~f)
(Ty.ty_tuple (List.init header.n_outputs ~f))
in
Why3.Term.Hls.add loaded_nnets ls_nnet_apply
{
filename;
nb_inputs = header.n_inputs;
nb_outputs = header.n_outputs;
ty_data = nnet_input_type;
};
let th_uc =
Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl (Decl.create_param_decl ls_nnet_apply))
in
Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc)
let register_nnet_support () = let register_nnet_support () =
Why3.( Why3.(
Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language Env.register_format ~desc:"NNet format (ReLU only)" Pmodule.mlw_language
"NNet" [ "nnet" ] nnet_parser) "NNet" [ "nnet" ] nnet_parser)
let register_onnx_support () =
Why3.(
Env.register_format ~desc:"ONNX format" Pmodule.mlw_language "ONNX"
[ "onnx" ] onnx_parser)
...@@ -4,15 +4,18 @@ ...@@ -4,15 +4,18 @@
(* *) (* *)
(**************************************************************************) (**************************************************************************)
type nnet = { type ioshape = {
nb_inputs : int; nb_inputs : int;
nb_outputs : int; nb_outputs : int;
ty_data : Why3.Ty.ty; ty_data : Why3.Ty.ty;
filename : string; filename : string;
} }
val lookup_loaded_nnets : Why3.Term.lsymbol -> nnet option val lookup_loaded_nets : Why3.Term.lsymbol -> ioshape option
(** @return the filename of a nnet Why3 representation. *) (** @return the filename of a nnet Why3 representation. *)
val register_nnet_support : unit -> unit val register_nnet_support : unit -> unit
(** Register nnet parser. *) (** Register NNet parser. *)
val register_onnx_support : unit -> unit
(** Register ONNX parser. *)
...@@ -30,7 +30,7 @@ let get_input_variables = ...@@ -30,7 +30,7 @@ let get_input_variables =
let rec aux acc (term : Term.term) = let rec aux acc (term : Term.term) =
match term.t_node with match term.t_node with
| Term.Tapp (ls, args) -> ( | Term.Tapp (ls, args) -> (
match Language.lookup_loaded_nnets ls with match Language.lookup_loaded_nets ls with
| None -> acc | None -> acc
| Some _ -> | Some _ ->
let add i acc = function let add i acc = function
...@@ -52,7 +52,7 @@ let simplify_goal env input_variables = ...@@ -52,7 +52,7 @@ let simplify_goal env input_variables =
let rec aux meta hls (term : Term.term) = let rec aux meta hls (term : Term.term) =
match term.t_node with match term.t_node with
| Term.Tapp (ls, _) -> ( | Term.Tapp (ls, _) -> (
match Language.lookup_loaded_nnets ls with match Language.lookup_loaded_nets ls with
| None -> Term.t_map (aux meta hls) term | None -> Term.t_map (aux meta hls) term
| Some nnet -> | Some nnet ->
meta := nnet.filename :: !meta; meta := nnet.filename :: !meta;
......
...@@ -8,6 +8,7 @@ open Base ...@@ -8,6 +8,7 @@ open Base
module Filename = Caml.Filename module Filename = Caml.Filename
let () = Language.register_nnet_support () let () = Language.register_nnet_support ()
let () = Language.register_onnx_support ()
let create_env loadpath = let create_env loadpath =
let config = Autodetection.autodetect ~debug:true () in let config = Autodetection.autodetect ~debug:true () in
......
theory NNet theory IOShape
use ieee_float.Float64 use ieee_float.Float64
type input_type = t type input_type = t
end end
pytorch1.8:
'
actual_input
126MatMul_0"MatMul

6
fc1.bias7Add_1"Add

78Relu_2"Relu

8
1310MatMul_3"MatMul
)
10
fc2.bias actual_outputAdd_4"Addtorch-jit-export*Bfc1.biasJE> *Bfc2.biasJ3?*$B12JcF>9
>v"9A*B13Jtyԕ>D##?aj>Z&
actual_input




b'
actual_output




B
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
(deps (deps
(package caisar) (package caisar)
TestNetwork.nnet TestNetwork.nnet
TestNetworkONNX.onnx
bin/pyrat.py bin/pyrat.py
bin/Marabou bin/Marabou
)) ))
...@@ -21,11 +21,11 @@ Test verify ...@@ -21,11 +21,11 @@ Test verify
> theory T > theory T
> use TestNetwork.AsTuple > use TestNetwork.AsTuple
> use ieee_float.Float64 > use ieee_float.Float64
> use caisar.NNet > use caisar.IOShape
> >
> goal G: forall x1 x2 x3 x4 x5. > goal G: forall x1 x2 x3 x4 x5.
> (0.0:t) .< x1 .< (0.5:t) -> > (0.0:t) .< x1 .< (0.5:t) ->
> let (y1,_,_,_,_) = nnet_apply x1 x2 x3 x4 x5 in > let (y1,_,_,_,_) = net_apply x1 x2 x3 x4 x5 in
> (0.0:t) .< y1 .< (0.5:t) > (0.0:t) .< y1 .< (0.5:t)
> >
> goal H: forall x1 x2 x3 x4 x5. > goal H: forall x1 x2 x3 x4 x5.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment