Skip to content
Snippets Groups Projects
Commit fcb78c15 authored by François Bobot's avatar François Bobot
Browse files

Add NNet input language

parent 0a6e75c7
No related branches found
No related tags found
No related merge requests found
all:
dune build @install caisar.opam nnet.opam
test:
dune runtest
......@@ -4,7 +4,7 @@ version: "0.1"
synopsis: "Framework for neural network verification"
depends: [
"ocaml" {>= "4.10"}
"dune" {>= "2.7" & >= "2.7.1"}
"dune" {>= "2.9" & >= "2.9.0"}
"base" {>= "v0.14.0"}
"cmdliner" {>= "1.0.4"}
"fmt" {>= "0.8.9"}
......@@ -17,7 +17,7 @@ depends: [
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
["dune" "subst" "--root" "."] {dev}
[
"dune"
"build"
......@@ -25,8 +25,11 @@ build: [
name
"-j"
jobs
"--promote-install-files"
"false"
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
["dune" "install" "-p" name "--create-install-files" name]
]
(lang dune 2.7)
(lang dune 2.9)
(name caisar)
(version 0.1)
(using dune_site 0.1)
(using menhir 2.1)
(cram enable)
(generate_opam_files true)
......@@ -11,7 +14,7 @@
(synopsis "Framework for neural network verification")
(depends
(ocaml (>= 4.10))
(dune (>= 2.7.1))
(dune (>= 2.9.0))
(base (>= v0.14.0))
(cmdliner (>= 1.0.4))
(fmt (>= 0.8.9))
......@@ -22,6 +25,7 @@
(ppx_deriving_yojson (>= 3.6.1))
(csv (>= 2.4))
)
(sites (lib stdlib))
)
(package
......
......@@ -20,3 +20,14 @@ theory ACAS_XU_P1
goal P1: forall y0:real. y0 <= 3.9911256459
end
theory ACAS_XU_P1_black_box
use ACASXU_experimental_v2a_1_1.As_tuple as NNet
constant x0
axiom ...
...
goal G: let y0,...,y4 = NNet.as_tuple(x0,...,x4) in ... <= y0 <= ... /\ ...
end
......@@ -132,10 +132,9 @@ let handle_nnet_weights_and_biases in_channel =
(* Retrieves [filename] NNet model metadata and weights wrt NNet format
specification (see https://github.com/sisl/NNet for details). *)
let parse filename =
let parse_cin filename in_channel =
let open Result in
try
let in_channel = Stdlib.open_in filename in
skip_nnet_header filename in_channel >>= fun () ->
let in_channel = Csv.of_channel in_channel in
handle_nnet_basic_info in_channel >>= fun (n_ls, n_is, n_os, max_l_size) ->
......@@ -164,3 +163,9 @@ let parse filename =
| Csv.Failure (_nrecord, _nfield, msg) -> Error msg
| Sys_error s -> Error s
| Failure msg -> Error (Format.sprintf "Unexpected error: %s." msg)
let parse filename =
let in_channel = Stdlib.open_in filename in
Caml.Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_cin filename in_channel)
......@@ -22,3 +22,6 @@ type t = private {
val parse : string -> (t, string) Result.t
(** Parse an NNet file. *)
val parse_cin : string -> in_channel -> (t, string) Result.t
(** Parse an NNet file. *)
......@@ -4,12 +4,12 @@ version: "0.1"
synopsis: "NNet parser"
depends: [
"ocaml" {>= "4.10"}
"dune" {>= "2.7" & >= "2.7.1"}
"dune" {>= "2.9" & >= "2.7.1"}
"base" {>= "v0.14.0"}
"odoc" {with-doc}
]
build: [
["dune" "subst"] {dev}
["dune" "subst" "--root" "."] {dev}
[
"dune"
"build"
......@@ -17,8 +17,11 @@ build: [
name
"-j"
jobs
"--promote-install-files"
"false"
"@install"
"@runtest" {with-test}
"@doc" {with-doc}
]
["dune" "install" "-p" name "--create-install-files" name]
]
(executable
(name main)
(public_name caisar)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3)
(libraries menhirLib yojson cmdliner logs logs.cli logs.fmt fmt.tty base unix str ppx_deriving_yojson.runtime nnet why3 dune-site)
(preprocess (pps ppx_deriving_yojson ppx_deriving.show ppx_deriving.ord ppx_deriving.eq))
(package caisar)
)
(generate_sites_module (module stdlib_path) (sites caisar))
open Base
(** Register Neural network languages *)
let nnet_parser env _ filename cin =
let nnet = Why3.Pmodule.read_module env [ "caisar" ] "NNet" in
let nnet_input_type =
Why3.Ty.ty_app
Why3.Theory.(ns_find_ts nnet.mod_theory.th_export [ "input_type" ])
[]
in
let header = Nnet.parse_cin filename cin in
match header with
| Error s -> Why3.Loc.errorm "%s" s
| Ok header ->
let id_as_tuple = Why3.Ident.id_fresh "As_tuple" in
let th_uc = Why3.Pmodule.create_module env id_as_tuple in
let th_uc = Why3.Pmodule.use_export th_uc nnet in
let open Why3 in
let ls_out =
Term.create_fsymbol (Ident.id_fresh "out")
(List.init header.n_inputs ~f:(fun _ -> nnet_input_type))
(Why3.Ty.ty_tuple
(List.init header.n_outputs ~f:(fun _ -> nnet_input_type)))
in
let th_uc =
Why3.Pmodule.add_pdecl ~vc:false th_uc
(Pdecl.create_pure_decl @@ Decl.create_param_decl ls_out)
in
Why3.Wstdlib.Mstr.singleton "AsTuple" (Pmodule.close_module th_uc)
let register () =
Why3.Env.(
register_format ~desc:"NNet format (only RLU)" Why3.Pmodule.mlw_language
"NNet" [ "nnet" ] nnet_parser)
......@@ -86,6 +86,35 @@ let config_cmd =
$ const cmdname $ detect $ setup_logs)),
Term.info cmdname ~sdocs:Manpage.s_common_options ~envs ~exits ~doc ~man )
let prove_cmd =
let cmdname = "prove" in
let files =
let doc = "Files to prove" in
Arg.(value & pos_all string [] & info [] ~doc)
in
let format =
let doc = "format" in
Arg.(value & opt (some string) None & info [ "format" ] ~doc)
in
let loadpath =
let doc = "additional loadpath" in
Arg.(value & opt_all string [] & info [ "L" ] ~doc)
in
let doc = Format.sprintf "%s configuration." caisar in
let exits = Term.default_exits in
let man =
[
`S Manpage.s_description;
`P (Format.sprintf "Handle the configuration of %s." caisar);
]
in
( Term.(
ret
(const (fun format loadpath files ->
`Ok (List.iter ~f:(Prove.prove format loadpath) files))
$ format $ loadpath $ files)),
Term.info cmdname ~sdocs:Manpage.s_common_options ~exits ~doc ~man )
let default_cmd =
let doc = "Framework for neural networks property verification and more." in
let sdocs = Manpage.s_common_options in
......@@ -104,6 +133,6 @@ let default_cmd =
Term.info caisar ~version ~doc ~sdocs ~exits:Term.default_exits ~man )
let () =
match Term.(eval_choice default_cmd [ config_cmd ]) with
match Term.(eval_choice default_cmd [ config_cmd; prove_cmd ]) with
| `Error _ -> Caml.exit 1
| _ -> Caml.exit (if Logs.err_count () > 0 then 1 else 0)
open Base
let () = Language.register ()
let create_env loadpath =
let stdlib = Stdlib_path.Sites.stdlib in
let conf = Why3.Whyconf.init_config None in
Why3.Env.create_env
(loadpath @ stdlib @ Why3.Whyconf.loadpath (Why3.Whyconf.get_main conf))
let prove format loadpath file =
let env = create_env loadpath in
let _, m =
match file with
| "-" ->
( "stdin",
Why3.Env.read_channel ?format Why3.Env.base_language env "stdin"
Caml.stdin )
| fname ->
let mlw_files, _ =
Why3.Env.read_file ?format Why3.Env.base_language env fname
in
(fname, mlw_files)
in
Why3.Wstdlib.Mstr.iter
(fun _ th ->
let l = Why3.Task.split_theory th None None in
List.iter l ~f:(Fmt.pr "%a" Why3.Pretty.print_task))
m
theory NNet
type input_type = int
end
(install
(section (site (caisar stdlib)))
(files caisar.mlw)
(package caisar))
This diff is collapsed.
(cram
(deps
(package caisar)
TestNetwork.nnet
))
Test help
$ caisar --version
0.0
Test help
$ caisar prove -L . --format whyml - <<EOF
> theory T
> use TestNetwork.AsTuple
> use int.Int
> use caisar.NNet
>
> goal G: forall x1 x2 x3 x4 x5.
> let (y1,_,_,_,_) = out x1 x2 x3 x4 x5 in
> 0 < y1 < 10
> end
> EOF
theory Task
type int
type real
type string
predicate (=) 'a 'a
(* use why3.BuiltIn.BuiltIn *)
type bool =
| True
| False
(* use why3.Bool.Bool *)
type tuple0 =
| Tuple0
(* use why3.Tuple0.Tuple01 *)
type unit = unit
(* use why3.Unit.Unit *)
type input_type = int
(* use caisar.NNet *)
type tuple5 'a 'a1 'a2 'a3 'a4 =
| Tuple5 'a 'a1 'a2 'a3 'a4
(* use why3.Tuple5.Tuple51 *)
function out int int int int int : (int, int, int, int, int)
(* use As_tuple *)
constant zero : int = 0
constant one : int = 1
function (-_) int : int
function (+) int int : int
function ( * ) int int : int
predicate (<) int int
function (-) (x:int) (y:int) : int = x + (- y)
predicate (>) (x:int) (y:int) = y < x
predicate (<=) (x:int) (y:int) = x < y \/ x = y
predicate (>=) (x:int) (y:int) = y <= x
Assoc : forall x:int, y:int, z:int. ((x + y) + z) = (x + (y + z))
(* clone algebra.Assoc with type t = int, function op = (+),
prop Assoc1 = Assoc, *)
Unit_def_l : forall x:int. (zero + x) = x
Unit_def_r : forall x:int. (x + zero) = x
(* clone algebra.Monoid with type t1 = int, constant unit = zero,
function op1 = (+), prop Unit_def_r1 = Unit_def_r,
prop Unit_def_l1 = Unit_def_l, prop Assoc2 = Assoc, *)
Inv_def_l : forall x:int. ((- x) + x) = zero
Inv_def_r : forall x:int. (x + (- x)) = zero
(* clone algebra.Group with type t2 = int, function inv = (-_),
constant unit1 = zero, function op2 = (+), prop Inv_def_r1 = Inv_def_r,
prop Inv_def_l1 = Inv_def_l, prop Unit_def_r2 = Unit_def_r,
prop Unit_def_l2 = Unit_def_l, prop Assoc3 = Assoc, *)
Comm : forall x:int, y:int. (x + y) = (y + x)
(* clone algebra.Comm with type t3 = int, function op3 = (+),
prop Comm1 = Comm, *)
(* meta AC function (+) *)
(* clone algebra.CommutativeGroup with type t4 = int, function inv1 = (-_),
constant unit2 = zero, function op4 = (+), prop Comm2 = Comm,
prop Inv_def_r2 = Inv_def_r, prop Inv_def_l2 = Inv_def_l,
prop Unit_def_r3 = Unit_def_r, prop Unit_def_l3 = Unit_def_l,
prop Assoc4 = Assoc, *)
Assoc5 : forall x:int, y:int, z:int. ((x * y) * z) = (x * (y * z))
(* clone algebra.Assoc with type t = int, function op = ( * ),
prop Assoc1 = Assoc5, *)
Mul_distr_l :
forall x:int, y:int, z:int. (x * (y + z)) = ((x * y) + (x * z))
Mul_distr_r :
forall x:int, y:int, z:int. ((y + z) * x) = ((y * x) + (z * x))
(* clone algebra.Ring with type t5 = int, function ( *') = ( * ),
function (-'_) = (-_), function (+') = (+), constant zero1 = zero,
prop Mul_distr_r1 = Mul_distr_r, prop Mul_distr_l1 = Mul_distr_l,
prop Assoc6 = Assoc5, prop Comm3 = Comm, prop Inv_def_r3 = Inv_def_r,
prop Inv_def_l3 = Inv_def_l, prop Unit_def_r4 = Unit_def_r,
prop Unit_def_l4 = Unit_def_l, prop Assoc7 = Assoc, *)
Comm4 : forall x:int, y:int. (x * y) = (y * x)
(* clone algebra.Comm with type t3 = int, function op3 = ( * ),
prop Comm1 = Comm4, *)
(* meta AC function ( * ) *)
(* clone algebra.CommutativeRing with type t6 = int,
function ( *'') = ( * ), function (-''_) = (-_), function (+'') = (+),
constant zero2 = zero, prop Comm5 = Comm4,
prop Mul_distr_r2 = Mul_distr_r, prop Mul_distr_l2 = Mul_distr_l,
prop Assoc8 = Assoc5, prop Comm6 = Comm, prop Inv_def_r4 = Inv_def_r,
prop Inv_def_l4 = Inv_def_l, prop Unit_def_r5 = Unit_def_r,
prop Unit_def_l5 = Unit_def_l, prop Assoc9 = Assoc, *)
Unitary : forall x:int. (one * x) = x
NonTrivialRing : not zero = one
(* clone algebra.UnitaryCommutativeRing with type t7 = int,
constant one1 = one, function ( *''') = ( * ), function (-'''_) = (-_),
function (+''') = (+), constant zero3 = zero,
prop NonTrivialRing1 = NonTrivialRing, prop Unitary1 = Unitary,
prop Comm7 = Comm4, prop Mul_distr_r3 = Mul_distr_r,
prop Mul_distr_l3 = Mul_distr_l, prop Assoc10 = Assoc5,
prop Comm8 = Comm, prop Inv_def_r5 = Inv_def_r,
prop Inv_def_l5 = Inv_def_l, prop Unit_def_r6 = Unit_def_r,
prop Unit_def_l6 = Unit_def_l, prop Assoc11 = Assoc, *)
(* clone relations.EndoRelation with type t8 = int, predicate rel = (<=),
*)
Refl : forall x:int. x <= x
(* clone relations.Reflexive with type t9 = int, predicate rel1 = (<=),
prop Refl1 = Refl, *)
(* clone relations.EndoRelation with type t8 = int, predicate rel = (<=),
*)
Trans : forall x:int, y:int, z:int. x <= y -> y <= z -> x <= z
(* clone relations.Transitive with type t10 = int, predicate rel2 = (<=),
prop Trans1 = Trans, *)
(* clone relations.PreOrder with type t11 = int, predicate rel3 = (<=),
prop Trans2 = Trans, prop Refl2 = Refl, *)
(* clone relations.EndoRelation with type t8 = int, predicate rel = (<=),
*)
Antisymm : forall x:int, y:int. x <= y -> y <= x -> x = y
(* clone relations.Antisymmetric with type t12 = int,
predicate rel4 = (<=), prop Antisymm1 = Antisymm, *)
(* clone relations.PartialOrder with type t13 = int, predicate rel5 = (<=),
prop Antisymm2 = Antisymm, prop Trans3 = Trans, prop Refl3 = Refl, *)
(* clone relations.EndoRelation with type t8 = int, predicate rel = (<=),
*)
Total : forall x:int, y:int. x <= y \/ y <= x
(* clone relations.Total with type t14 = int, predicate rel6 = (<=),
prop Total1 = Total, *)
(* clone relations.TotalOrder with type t15 = int, predicate rel7 = (<=),
prop Total2 = Total, prop Antisymm3 = Antisymm, prop Trans4 = Trans,
prop Refl4 = Refl, *)
ZeroLessOne : zero <= one
CompatOrderAdd : forall x:int, y:int, z:int. x <= y -> (x + z) <= (y + z)
CompatOrderMult :
forall x:int, y:int, z:int. x <= y -> zero <= z -> (x * z) <= (y * z)
(* clone algebra.OrderedUnitaryCommutativeRing with type t16 = int,
predicate (<=') = (<=), constant one2 = one, function ( *'''') = ( * ),
function (-''''_) = (-_), function (+'''') = (+), constant zero4 = zero,
prop CompatOrderMult1 = CompatOrderMult,
prop CompatOrderAdd1 = CompatOrderAdd, prop ZeroLessOne1 = ZeroLessOne,
prop Total3 = Total, prop Antisymm4 = Antisymm, prop Trans5 = Trans,
prop Refl5 = Refl, prop NonTrivialRing2 = NonTrivialRing,
prop Unitary2 = Unitary, prop Comm9 = Comm4,
prop Mul_distr_r4 = Mul_distr_r, prop Mul_distr_l4 = Mul_distr_l,
prop Assoc12 = Assoc5, prop Comm10 = Comm, prop Inv_def_r6 = Inv_def_r,
prop Inv_def_l6 = Inv_def_l, prop Unit_def_r7 = Unit_def_r,
prop Unit_def_l7 = Unit_def_l, prop Assoc13 = Assoc, *)
(* use int.Int *)
goal G :
forall x1:int, x2:int, x3:int, x4:int, x5:int.
match out x1 x2 x3 x4 x5 with
| y1, _, _, _, _ -> 0 < y1 /\ y1 < 10
end
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment