diff --git a/Makefile b/Makefile index de002d3287462d7ba94b0a04550b67137987647a..bb056159a9c48afe25209861e42249d533732c19 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ all: - dune build --root=. @install caisar.opam caisar-nnet.opam caisar-onnx.opam caisar-ovo.opam + dune build --root=. @install caisar.opam caisar-nnet.opam caisar-onnx.opam caisar-ovo.opam caisar-ir.opam install: dune install diff --git a/caisar-ir.opam b/caisar-ir.opam new file mode 100644 index 0000000000000000000000000000000000000000..6841ff21dcc2a4a9b8763fa0ac51e1992ff66a87 --- /dev/null +++ b/caisar-ir.opam @@ -0,0 +1,39 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +version: "0.1" +synopsis: "CAISAR's intermediate representation" +maintainer: [ + "LAISER team, Software Safety and Security Laboratory, CEA-List" +] +authors: ["LAISER team, Software Safety and Security Laboratory, CEA-List"] +license: "LGPL-2.1-only" +homepage: "https://git.frama-c.com/pub/caisar" +doc: "https://git.frama-c.com/pub/caisar" +bug-reports: "https://git.frama-c.com/pub/caisar/issues" +depends: [ + "ocaml" {>= "4.13"} + "dune" {>= "2.9" & >= "2.9.3"} + "base" {>= "v0.14.0"} + "ocaml-protoc-plugin" {= "4.2.0"} + "ocamlgraph" {>= "1.8.8"} + "ppx_inline_test" {>= "0.12.0"} + "ppx_deriving" {>= "4.4.1"} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "--promote-install-files=false" + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] + ["dune" "install" "-p" name "--create-install-files" name] +] +dev-repo: "git+https://git.frama-c.com/pub/caisar.git" diff --git a/caisar.opam b/caisar.opam index 6bf5ad2fd63d1614b91fcdf1549690b2da6eb413..60e13a6be66cd8713e6761e7332afea6635a3806 100644 --- a/caisar.opam +++ b/caisar.opam @@ -34,6 +34,7 @@ depends: [ "caisar-nnet" {= version} "caisar-ovo" {= version} "caisar-onnx" {= version} + "caisar-ir" {= version} "odoc" {with-doc} ] build: [ diff --git a/config/caisar-detection-data.conf b/config/caisar-detection-data.conf index 387167999f7e26fb68286e13cb48cfe3fd771f30..b3f1987360108c7c700bb9769d531bf3ee71713b 100644 --- a/config/caisar-detection-data.conf +++ b/config/caisar-detection-data.conf @@ -32,6 +32,17 @@ driver = "alt_ergo" editor = "altgr-ergo" use_at_auto_level = 1 +[ATP cvc5] +name = "CVC5" +exec = "cvc5" +version_switch = "--version 2>&1 | head -1" +version_regexp = "This is cvc5 version \\([0-9.]+\\)" +version_ok = "1.0.2" +command = "%e --stats-internal --tlimit=%t000 %f" +command_steps = "%e --stats-internal --steps-bound %S %f" +driver = "caisar_drivers/cvc5.drv" +use_at_auto_level = 1 + [ATP marabou] name = "Marabou" exec = "Marabou" diff --git a/config/drivers/cvc4_16.gen b/config/drivers/cvc4_16.gen new file mode 100644 index 0000000000000000000000000000000000000000..b5f83680ae25b2ba02baad0c4708468f7ae4e22c --- /dev/null +++ b/config/drivers/cvc4_16.gen @@ -0,0 +1,67 @@ +(** Why3 driver for CVC4 >= 1.6 (with floating point support) *) + +prelude "(set-info :smt-lib-version 2.6)" + +import "smt-libv2.gen" +printer "smtv2.6" +import "smt-libv2-bv.gen" +import "cvc4_bv.gen" +import "smt-libv2-floats.gen" +import "discrimination.gen" + +transformation "inline_trivial" +transformation "eliminate_builtin" +transformation "detect_polymorphism" +transformation "eliminate_definition_conditionally" +transformation "eliminate_inductive" +transformation "eliminate_algebraic_if_poly" +transformation "eliminate_literal" +transformation "eliminate_epsilon" +transformation 
"simplify_formula" + +(* Prepare for counter-example query: get rid of some quantifiers + (makes it possible to query model values of the variables in + premises) and introduce counter-example projections. Note: does + nothing if meta get_counterexmp is not set *) +transformation "prepare_for_counterexmp" + +transformation "discriminate_if_poly" +transformation "encoding_smt_if_poly" + +(** Error messages specific to CVC4 *) + +outofmemory "(error \".*out of memory\")" +outofmemory "CVC4 suffered a segfault" +outofmemory "CVC4::BVMinisat::OutOfMemoryException" +outofmemory "std::bad_alloc" +outofmemory "Cannot allocate memory" +timeout "interrupted by timeout" +steps "smt::SmtEngine::resourceUnitsUsed, \\([0-9]+.?[0-9]*\\)" 1 +(* +specific output message when CVC4 reaches its resource limit +*) +steplimitexceeded "unknown (RESOURCEOUT)" + + +(** Extra theories supported by CVC4 *) + +(* CVC4 division seems to be the Euclidean one, not the Computer one *) +theory int.EuclideanDivision + syntax function div "(div %1 %2)" + syntax function mod "(mod %1 %2)" + remove prop Mod_bound + remove prop Div_mod + remove prop Mod_1 + remove prop Div_1 +end + +(* +theory int.ComputerDivision + syntax function div "(div %1 %2)" + syntax function mod "(mod %1 %2)" + remove prop Mod_bound + remove prop Div_mod + remove prop Mod_1 + remove prop Div_1 +end +*) diff --git a/config/drivers/cvc4_bv.gen b/config/drivers/cvc4_bv.gen new file mode 100644 index 0000000000000000000000000000000000000000..028e8306173fca15bea41619018c05d6337e4651 --- /dev/null +++ b/config/drivers/cvc4_bv.gen @@ -0,0 +1,41 @@ +(* bitvector modules, is not in smt-libv2.gen since cvc4 and z3 don't + have the same name for the function to_int *) + +theory bv.BV_Gen + syntax function to_uint "(bv2nat %1)" +end + +theory bv.BV256 + (* mapping of_int to int2bv is disabled because it breaks proofs + in examples/bitcount, examples/esterel, + examples/isqrt_von_neumann, examples/rightmostbittrick, + examples/bitwalker *) + + (* syntax function of_int "((_ int2bv 256) %1)" *) + syntax function t'int "(bv2nat %1)" +end + +theory bv.BV128 + (* syntax function of_int "((_ int2bv 128) %1)" *) + syntax function t'int "(bv2nat %1)" +end + +theory bv.BV64 + (* syntax function of_int "((_ int2bv 64) %1)" *) + syntax function t'int "(bv2nat %1)" +end + +theory bv.BV32 + (* syntax function of_int "((_ int2bv 32) %1)" *) + syntax function t'int "(bv2nat %1)" +end + +theory bv.BV16 + (* syntax function of_int "((_ int2bv 16) %1)" *) + syntax function t'int "(bv2nat %1)" +end + +theory bv.BV8 + (* syntax function of_int "((_ int2bv 8) %1)" *) + syntax function t'int "(bv2nat %1)" +end diff --git a/config/drivers/cvc5.drv b/config/drivers/cvc5.drv new file mode 100644 index 0000000000000000000000000000000000000000..53b26dbb1055e155acff135004a00cea8beb3fb2 --- /dev/null +++ b/config/drivers/cvc5.drv @@ -0,0 +1,22 @@ +(** Why3 driver for CVC5 1.0.0 *) + +prelude ";; produced by cvc5.drv ;;" + +transformation "actual_net_apply" + +prelude "(set-logic ALL)" + +unknown "^(error \"Can't get-info :reason-unknown when the last result wasn't unknown!\")$" "(not unknown!)" + +outofmemory "cvc5 suffered a segfault" +outofmemory "cvc5::internal::Minisat::OutOfMemoryException" + +steps "resource::resourceUnitsUsed = \\([0-9]+\\)" 1 + +import "cvc4_16.gen" + +theory BuiltIn + + meta "supports_smt_get_info_unknown_reason" "" + +end diff --git a/config/drivers/discrimination.gen b/config/drivers/discrimination.gen new file mode 100644 index 
0000000000000000000000000000000000000000..519eab685308821fc21591040a239183d3840624 --- /dev/null +++ b/config/drivers/discrimination.gen @@ -0,0 +1,7 @@ +theory BuiltIn + meta "select_inst_default" "local" + meta "select_lskept_default" "local" + meta "select_lsinst_default" "local" + meta "select_kept_default" "all" +end + diff --git a/config/drivers/smt-libv2-bv.gen b/config/drivers/smt-libv2-bv.gen new file mode 100644 index 0000000000000000000000000000000000000000..1d9b4a66ce1fb29f9d4fc1bf6a734865de434605 --- /dev/null +++ b/config/drivers/smt-libv2-bv.gen @@ -0,0 +1,297 @@ +(* Why3 driver for SMT-LIB2, common part of bit-vector theories *) + +prelude ";;; SMT-LIB2 driver: bit-vectors, common part" + +theory bv.BV_Gen + remove prop size_pos + remove prop nth_out_of_bound + remove prop Nth_zeros + remove prop Nth_ones + + syntax function bw_and "(bvand %1 %2)" + syntax function bw_or "(bvor %1 %2)" + syntax function bw_xor "(bvxor %1 %2)" + syntax function bw_not "(bvnot %1)" + + (* Warning: we should NOT remove all the axioms using "allprops" *) + + remove prop Nth_bw_and + remove prop Nth_bw_or + remove prop Nth_bw_xor + remove prop Nth_bw_not + + (** Shift operators *) + + remove prop Lsr_nth_low + remove prop Lsr_nth_high + remove prop lsr_zeros + remove prop Asr_nth_low + remove prop Asr_nth_high + remove prop asr_zeros + remove prop Lsl_nth_low + remove prop Lsl_nth_high + remove prop lsl_zeros + remove prop Nth_rotate_left + remove prop Nth_rotate_right + + (* Conversions from/to integers *) + + remove prop two_power_size_val + remove prop max_int_val + + (* function to_int - solver specific *) + (* function to_uint - solver specific *) + (* function of_int - solver specific *) + + remove prop to_uint_extensionality + remove prop to_int_extensionality + + remove prop to_uint_bounds + (*remove prop to_uint_of_int*) + remove prop to_uint_size_bv + remove prop to_uint_zeros + remove prop to_uint_ones + remove prop to_uint_one + + (* comparison operators *) + + syntax predicate ult "(bvult %1 %2)" + syntax predicate ule "(bvule %1 %2)" + syntax predicate ugt "(bvugt %1 %2)" + syntax predicate uge "(bvuge %1 %2)" + syntax predicate slt "(bvslt %1 %2)" + syntax predicate sle "(bvsle %1 %2)" + syntax predicate sgt "(bvsgt %1 %2)" + syntax predicate sge "(bvsge %1 %2)" + + (** Arithmetic operators *) + + syntax function add "(bvadd %1 %2)" + remove prop to_uint_add + remove prop to_uint_add_bounded + + syntax function sub "(bvsub %1 %2)" + remove prop to_uint_sub + remove prop to_uint_sub_bounded + + syntax function neg "(bvneg %1)" + remove prop to_uint_neg + + syntax function mul "(bvmul %1 %2)" + remove prop to_uint_mul + remove prop to_uint_mul_bounded + + syntax function udiv "(bvudiv %1 %2)" + remove prop to_uint_udiv + + syntax function urem "(bvurem %1 %2)" + remove prop to_uint_urem + + syntax function sdiv "(bvsdiv %1 %2)" + (*remove prop to_int_sdiv*) + + syntax function srem "(bvsrem %1 %2)" + (*remove prop to_int_srem*) + + (** Bitvector alternatives for shifts, rotations and nth *) + + syntax function lsr_bv "(bvlshr %1 %2)" + (* remove prop lsr_bv_is_lsr *) + remove prop to_uint_lsr + + syntax function asr_bv "(bvashr %1 %2)" + (* remove prop asr_bv_is_asr *) + + syntax function lsl_bv "(bvshl %1 %2)" + (* remove prop lsl_bv_is_lsl *) + + remove prop to_uint_lsl + + (** rotations *) + (* remove prop rotate_left_bv_is_rotate_left *) + (* remove prop rotate_right_bv_is_rotate_right *) + + (** nth_bv *) + + (* remove prop nth_bv_def *) + (* remove prop Nth_bv_is_nth *) + (* 
remove prop Nth_bv_is_nth2 *) + + remove prop Extensionality +end + +theory bv.BV256 + meta "literal:keep" type t + + syntax literal t "#x%64x" + syntax type t "(_ BitVec 256)" + + syntax function zeros "#x0000000000000000000000000000000000000000000000000000000000000000" + syntax function one "#x0000000000000000000000000000000000000000000000000000000000000001" + syntax function ones "#xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" + syntax function size_bv "(_ bv256 256)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 256))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv256 256))) (bvlshr %1 (bvsub (_ bv256 256) (bvurem %2 (_ bv256 256)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv256 256))) (bvshl %1 (bvsub (_ bv256 256) (bvurem %2 (_ bv256 256)))))" +end + +theory bv.BV128 + meta "literal:keep" type t + + syntax literal t "#x%32x" + syntax type t "(_ BitVec 128)" + + syntax function zeros "#x00000000000000000000000000000000" + syntax function one "#x00000000000000000000000000000001" + syntax function ones "#xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" + syntax function size_bv "(_ bv128 128)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 128))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv128 128))) (bvlshr %1 (bvsub (_ bv128 128) (bvurem %2 (_ bv128 128)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv128 128))) (bvshl %1 (bvsub (_ bv128 128) (bvurem %2 (_ bv128 128)))))" +end + +theory bv.BV64 + meta "literal:keep" type t + + syntax literal t "#x%16x" + syntax type t "(_ BitVec 64)" + + syntax function zeros "#x0000000000000000" + syntax function one "#x0000000000000001" + syntax function ones "#xFFFFFFFFFFFFFFFF" + syntax function size_bv "(_ bv64 64)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 64))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv64 64))) (bvlshr %1 (bvsub (_ bv64 64) (bvurem %2 (_ bv64 64)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv64 64))) (bvshl %1 (bvsub (_ bv64 64) (bvurem %2 (_ bv64 64)))))" +end + +theory bv.BV32 + meta "literal:keep" type t + + syntax literal t "#x%8x" + syntax type t "(_ BitVec 32)" + + syntax function zeros "#x00000000" + syntax function one "#x00000001" + syntax function ones "#xFFFFFFFF" + syntax function size_bv "(_ bv32 32)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 32))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv32 32))) (bvlshr %1 (bvsub (_ bv32 32) (bvurem %2 (_ bv32 32)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv32 32))) (bvshl %1 (bvsub (_ bv32 32) (bvurem %2 (_ bv32 32)))))" +end + +theory bv.BV16 + meta "literal:keep" type t + + syntax literal t "#x%4x" + syntax type t "(_ BitVec 16)" + + syntax function zeros "#x0000" + syntax function one "#x0001" + syntax function ones "#xFFFF" + syntax function size_bv "(_ bv16 16)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 16))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv16 16))) (bvlshr %1 (bvsub (_ bv16 16) (bvurem %2 (_ bv16 16)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv16 16))) (bvshl %1 (bvsub (_ bv16 16) (bvurem %2 (_ bv16 16)))))" +end + +theory bv.BV8 + meta "literal:keep" type t + + syntax literal t (* "#b%8b" *) "#x%2x" + syntax type t "(_ BitVec 8)" + + syntax function zeros "#x00" + syntax function one "#x01" + syntax function ones "#xFF" + 
syntax function size_bv "(_ bv8 8)" + + syntax predicate is_signed_positive "(bvsge %1 (_ bv0 8))" + + syntax function rotate_left_bv "(bvor (bvshl %1 (bvurem %2 (_ bv8 8))) (bvlshr %1 (bvsub (_ bv8 8) (bvurem %2 (_ bv8 8)))))" + syntax function rotate_right_bv "(bvor (bvlshr %1 (bvurem %2 (_ bv8 8))) (bvshl %1 (bvsub (_ bv8 8) (bvurem %2 (_ bv8 8)))))" +end + +theory bv.BVConverter_Gen + remove allprops +end + +theory bv.BVConverter_128_256 + syntax function toBig "((_ zero_extend 128) %1)" + syntax function stoBig "((_ sign_extend 128) %1)" + syntax function toSmall "((_ extract 127 0) %1)" +end + +theory bv.BVConverter_64_128 + syntax function toBig "((_ zero_extend 64) %1)" + syntax function stoBig "((_ sign_extend 64) %1)" + syntax function toSmall "((_ extract 63 0) %1)" +end + +theory bv.BVConverter_32_128 + syntax function toBig "((_ zero_extend 96) %1)" + syntax function stoBig "((_ sign_extend 96) %1)" + syntax function toSmall "((_ extract 31 0) %1)" +end + +theory bv.BVConverter_16_128 + syntax function toBig "((_ zero_extend 112) %1)" + syntax function stoBig "((_ sign_extend 112) %1)" + syntax function toSmall "((_ extract 15 0) %1)" +end + +theory bv.BVConverter_8_128 + syntax function toBig "((_ zero_extend 120) %1)" + syntax function stoBig "((_ sign_extend 120) %1)" + syntax function toSmall "((_ extract 7 0) %1)" +end + +theory bv.BVConverter_32_64 + syntax function toBig "((_ zero_extend 32) %1)" + syntax function stoBig "((_ sign_extend 32) %1)" + syntax function toSmall "((_ extract 31 0) %1)" +end + +theory bv.BVConverter_16_64 + syntax function toBig "((_ zero_extend 48) %1)" + syntax function stoBig "((_ sign_extend 48) %1)" + syntax function toSmall "((_ extract 15 0) %1)" +end + +theory bv.BVConverter_8_64 + syntax function toBig "((_ zero_extend 56) %1)" + syntax function stoBig "((_ sign_extend 56) %1)" + syntax function toSmall "((_ extract 7 0) %1)" +end + +theory bv.BVConverter_16_32 + syntax function toBig "((_ zero_extend 16) %1)" + syntax function stoBig "((_ sign_extend 16) %1)" + syntax function toSmall "((_ extract 15 0) %1)" +end + +theory bv.BVConverter_8_32 + syntax function toBig "((_ zero_extend 24) %1)" + syntax function stoBig "((_ sign_extend 24) %1)" + syntax function toSmall "((_ extract 7 0) %1)" +end + +theory bv.BVConverter_8_16 + syntax function toBig "((_ zero_extend 8) %1)" + syntax function stoBig "((_ sign_extend 8) %1)" + syntax function toSmall "((_ extract 7 0) %1)" +end + +theory bv.Pow2int + remove allprops +end diff --git a/config/drivers/smt-libv2-floats.gen b/config/drivers/smt-libv2-floats.gen new file mode 100644 index 0000000000000000000000000000000000000000..bb8d95642a8068256dc9b6374052a2cfc4834597 --- /dev/null +++ b/config/drivers/smt-libv2-floats.gen @@ -0,0 +1,170 @@ +theory ieee_float.RoundingMode + syntax type mode "RoundingMode" + syntax function RNE "RNE" + syntax function RNA "RNA" + syntax function RTP "RTP" + syntax function RTN "RTN" + syntax function RTZ "RTZ" + + syntax predicate to_nearest "(or (= %1 RNE) (= %1 RNA))" +end + +theory ieee_float.GenericFloat + (* Part I *) + syntax function abs "(fp.abs %1)" + syntax function neg "(fp.neg %1)" + + syntax function add "(fp.add %1 %2 %3)" + syntax function sub "(fp.sub %1 %2 %3)" + syntax function mul "(fp.mul %1 %2 %3)" + syntax function div "(fp.div %1 %2 %3)" + + syntax function fma "(fp.fma %1 %2 %3 %4)" + + syntax function sqrt "(fp.sqrt %1 %2)" + + syntax function roundToIntegral "(fp.roundToIntegral %1 %2)" + + syntax function min "(fp.min %1 %2)" + 
syntax function max "(fp.max %1 %2)" + + syntax predicate le "(fp.leq %1 %2)" + syntax predicate lt "(fp.lt %1 %2)" + syntax predicate ge "(fp.geq %1 %2)" + syntax predicate gt "(fp.gt %1 %2)" + + syntax predicate eq "(fp.eq %1 %2)" + + syntax predicate is_normal "(fp.isNormal %1)" + syntax predicate is_subnormal "(fp.isSubnormal %1)" + syntax predicate is_zero "(fp.isZero %1)" + syntax predicate is_infinite "(fp.isInfinite %1)" + syntax predicate is_nan "(fp.isNaN %1)" + syntax predicate is_positive "(fp.isPositive %1)" + syntax predicate is_negative "(fp.isNegative %1)" + + (* We could do this here, but we get slightly slimmer VCs and avoid + issues with Z3 if we do specialised versions of this for Float32 and + Float64 *) + (* The proposed fp.isFinite would fix all this. *) + (* syntax predicate is_finite "(not (or (fp.isInfinite %1) (fp.isNaN %1)))" *) + syntax predicate is_not_nan "(not (fp.isNaN %1))" + + syntax function to_real "(fp.to_real %1)" + + syntax predicate overflow_value "true" + + syntax predicate sign_zero_result "true" + + (* Part II *) + + remove allprops +end + +theory ieee_float.Float32 + (* Part I *) + syntax type t "Float32" + meta "literal:keep" type t + syntax literal t "(fp #b%s1b #b%e8b #b%m23b)" + + syntax function zeroF "((_ to_fp 8 24) #x00000000)" + + prelude "(define-fun fp.isFinite32 ((x Float32)) Bool (not (or (fp.isInfinite x) (fp.isNaN x))))" + prelude "(define-fun fp.isIntegral32 ((x Float32)) Bool (or (fp.isZero x) (and (fp.isNormal x) (= x (fp.roundToIntegral RNE x)))))" + + + syntax predicate t'isFinite "(fp.isFinite32 %1)" + syntax predicate is_int "(fp.isIntegral32 %1)" + + (* Faithful translations of the axiomatisation, mainly to remove this crud + from the smtlib output of SPARK. *) + syntax function round "(fp.to_real ((_ to_fp 8 24) %1 %2))" + + syntax predicate in_range "\ + (and (<= (fp.to_real (fp #b1 #b11111110 #b11111111111111111111111)) %1) \ + (<= %1 (fp.to_real (fp #b0 #b11111110 #b11111111111111111111111))))" + + syntax predicate no_overflow "\ + (and (<= (fp.to_real (fp #b1 #b11111110 #b11111111111111111111111)) \ + (fp.to_real ((_ to_fp 8 24) %1 %2))) \ + (<= (fp.to_real ((_ to_fp 8 24) %1 %2)) \ + (fp.to_real (fp #b0 #b11111110 #b11111111111111111111111))))" + + remove allprops +end + +theory ieee_float.Float64 + (* Part I *) + syntax type t "Float64" + meta "literal:keep" type t + syntax literal t "(fp #b%s1b #b%e11b #b%m52b)" + + syntax function zeroF "((_ to_fp 11 53) #x0000000000000000)" + + prelude "(define-fun fp.isFinite64 ((x Float64)) Bool (not (or (fp.isInfinite x) (fp.isNaN x))))" + prelude "(define-fun fp.isIntegral64 ((x Float64)) Bool (or (fp.isZero x) (and (fp.isNormal x) (= x (fp.roundToIntegral RNE x)))))" + + syntax predicate t'isFinite "(fp.isFinite64 %1)" + syntax predicate is_int "(fp.isIntegral64 %1)" + + (* Faithful translations of the axiomatisation, mainly to remove this crud + from the smtlib output of SPARK. 
*) + syntax function round "(fp.to_real ((_ to_fp 11 53) %1 %2))" + + syntax predicate in_range "\ + (and \ + (<= \ + (fp.to_real (fp #b1 #b11111111110 #b1111111111111111111111111111111111111111111111111111)) \ + %1) \ + (<= %1 \ + (fp.to_real (fp #b0 #b11111111110 #b1111111111111111111111111111111111111111111111111111))))" + + syntax predicate no_overflow "\ + (and \ + (<= \ + (fp.to_real (fp #b1 #b11111111110 #b1111111111111111111111111111111111111111111111111111)) \ + (fp.to_real ((_ to_fp 11 53) %1 %2))) \ + (<= \ + (fp.to_real ((_ to_fp 11 53) %1 %2)) \ + (fp.to_real (fp #b0 #b11111111110 #b1111111111111111111111111111111111111111111111111111))))" + + remove allprops +end + +theory ieee_float.FloatConverter + (* Part I *) + syntax function to_float32 "((_ to_fp 8 24) %1 %2)" + syntax function to_float64 "((_ to_fp 11 53) %1 %2)" + + remove allprops +end + +theory ieee_float.Float_BV_Converter + (* Part I *) + syntax function to_ubv8 "((_ fp.to_ubv 8) %1 %2)" + syntax function to_ubv16 "((_ fp.to_ubv 16) %1 %2)" + syntax function to_ubv32 "((_ fp.to_ubv 32) %1 %2)" + syntax function to_ubv64 "((_ fp.to_ubv 64) %1 %2)" + + remove allprops +end + +theory ieee_float.Float32_BV_Converter + (* Part I *) + syntax function of_ubv8 "((_ to_fp_unsigned 8 24) %1 %2)" + syntax function of_ubv16 "((_ to_fp_unsigned 8 24) %1 %2)" + syntax function of_ubv32 "((_ to_fp_unsigned 8 24) %1 %2)" + syntax function of_ubv64 "((_ to_fp_unsigned 8 24) %1 %2)" + + remove allprops +end + +theory ieee_float.Float64_BV_Converter + (* Part I *) + syntax function of_ubv8 "((_ to_fp_unsigned 11 53) %1 %2)" + syntax function of_ubv16 "((_ to_fp_unsigned 11 53) %1 %2)" + syntax function of_ubv32 "((_ to_fp_unsigned 11 53) %1 %2)" + syntax function of_ubv64 "((_ to_fp_unsigned 11 53) %1 %2)" + + remove allprops +end diff --git a/config/drivers/smt-libv2.gen b/config/drivers/smt-libv2.gen new file mode 100644 index 0000000000000000000000000000000000000000..5e58efa3e858a7cdbd258e08b419cb654c40481e --- /dev/null +++ b/config/drivers/smt-libv2.gen @@ -0,0 +1,196 @@ +(* Why3 driver for SMT-LIB2 syntax, excluding bit-vectors *) + +prelude ";;; generated by SMT-LIB2 driver" + +(* + +Note: we do not insert any command "set-logic" because its +interpretation is specific to provers + +prelude "(set-logic AUFNIRA)" + + A : Array + UF : Uninterpreted Function + DT : Datatypes (not needed at the end ...) 
+ NIRA : NonLinear Integer Reals Arithmetic + +*) + +filename "%f-%t-%g.smt2" +unknown "^\\(unknown\\|sat\\|Fail\\)$" "\\1" +unknown "^(:reason-unknown \\([^)]*\\))$" "\\1" +time "why3cpulimit time : %s s" +valid "^unsat$" + +theory BuiltIn + syntax type int "Int" + syntax type real "Real" + syntax predicate (=) "(= %1 %2)" + + meta "encoding:kept" type int + meta "encoding:ignore_polymorphism_ls" predicate (=) +end + +theory int.Int + + prelude ";;; SMT-LIB2: integer arithmetic" + + syntax function zero "0" + syntax function one "1" + + syntax function (+) "(+ %1 %2)" + syntax function (-) "(- %1 %2)" + syntax function (*) "(* %1 %2)" + syntax function (-_) "(- %1)" + + syntax predicate (<=) "(<= %1 %2)" + syntax predicate (<) "(< %1 %2)" + syntax predicate (>=) "(>= %1 %2)" + syntax predicate (>) "(> %1 %2)" + + remove prop MulComm.Comm + remove prop MulAssoc.Assoc + remove prop Unit_def_l + remove prop Unit_def_r + remove prop Inv_def_l + remove prop Inv_def_r + remove prop Assoc + remove prop Mul_distr_l + remove prop Mul_distr_r + remove prop Comm + remove prop Unitary + remove prop Refl + remove prop Trans + remove prop Antisymm + remove prop Total + remove prop NonTrivialRing + remove prop CompatOrderAdd + remove prop ZeroLessOne + +end + +theory real.Real + + prelude ";;; SMT-LIB2: real arithmetic" + + syntax function zero "0.0" + syntax function one "1.0" + + syntax function (+) "(+ %1 %2)" + syntax function (-) "(- %1 %2)" + syntax function (*) "(* %1 %2)" + syntax function (/) "(/ %1 %2)" + syntax function (-_) "(- %1)" + syntax function inv "(/ 1.0 %1)" + + syntax predicate (<=) "(<= %1 %2)" + syntax predicate (<) "(< %1 %2)" + syntax predicate (>=) "(>= %1 %2)" + syntax predicate (>) "(> %1 %2)" + + remove prop MulComm.Comm + remove prop MulAssoc.Assoc + remove prop Unit_def_l + remove prop Unit_def_r + remove prop Inv_def_l + remove prop Inv_def_r + remove prop Assoc + remove prop Mul_distr_l + remove prop Mul_distr_r + remove prop Comm + remove prop Unitary + remove prop Inverse + remove prop Refl + remove prop Trans + remove prop Antisymm + remove prop Total + remove prop NonTrivialRing + remove prop CompatOrderAdd + remove prop ZeroLessOne + + meta "encoding:kept" type real + +end + +theory real.Abs + syntax function abs "(ite (>= %1 0.0) %1 (- %1))" + + remove allprops +end + +theory real.MinMax + + remove allprops +end + +theory real.FromInt + syntax function from_int "(to_real %1)" + + remove allprops +end + +theory real.Truncate + syntax function truncate "(ite (>= %1 0.0) \ + (to_int %1) \ + (- (to_int (- %1))))" + syntax function floor "(to_int %1)" + syntax function ceil "(- (to_int (- %1)))" + + remove allprops +end + +theory Bool + syntax type bool "Bool" + syntax function True "true" + syntax function False "false" + + meta "encoding:kept" type bool +end + +theory bool.Bool + syntax function andb "(and %1 %2)" + syntax function orb "(or %1 %2)" + syntax function xorb "(xor %1 %2)" + syntax function notb "(not %1)" + syntax function implb "(=> %1 %2)" +end + +theory bool.Ite + syntax function ite "(ite %1 %2 %3)" + meta "encoding:lskept" function ite + meta "encoding:ignore_polymorphism_ls" function ite +end + +(* not uniformly interpreted by provers +theory real.Truncate + syntax function floor "(to_int %1)" + remove prop Floor_down + remove prop Floor_monotonic +end +*) + +theory HighOrd + syntax type (->) "(Array %1 %2)" + syntax function (@) "(select %1 %2)" + + meta "encoding:lskept" function (@) + meta "encoding:ignore_polymorphism_ts" type (->) + meta 
"encoding:ignore_polymorphism_ls" function (@) +end + +theory map.Map + syntax function get "(select %1 %2)" + syntax function set "(store %1 %2 %3)" + + meta "encoding:lskept" function get + meta "encoding:lskept" function set + meta "encoding:ignore_polymorphism_ls" function get + meta "encoding:ignore_polymorphism_ls" function ([]) + meta "encoding:ignore_polymorphism_ls" function set + meta "encoding:ignore_polymorphism_ls" function ([<-]) +end + +theory map.Const + meta "encoding:lskept" function const +(* syntax function const "(const[%t0] %1)" *) +end diff --git a/config/dune b/config/dune index 4398e3dca7dcf0af0c46d0b0923bce5182889c8c..2716e01b5387d6b98f167c9443284fbd7fca2baa 100644 --- a/config/dune +++ b/config/dune @@ -6,5 +6,13 @@ caisar-detection-data.conf (drivers/pyrat.drv as drivers/pyrat.drv) (drivers/marabou.drv as drivers/marabou.drv) - (drivers/saver.drv as drivers/saver.drv)) + (drivers/saver.drv as drivers/saver.drv) + (drivers/cvc5.drv as drivers/cvc5.drv) + (drivers/cvc4_16.gen as drivers/cvc4_16.gen) + (drivers/cvc4_bv.gen as drivers/cvc4_bv.gen) + (drivers/smt-libv2-bv.gen as drivers/smt-libv2-bv.gen) + (drivers/smt-libv2-floats.gen as drivers/smt-libv2-floats.gen) + (drivers/smt-libv2.gen as drivers/smt-libv2.gen) + (drivers/discrimination.gen as drivers/discrimination.gen) + ) (package caisar)) diff --git a/dune-project b/dune-project index e2b7cc99e4891e854084ef877e19bc32d6fc8dd9..8d1d7d4e4b5e30813740dba7765442460b82774c 100644 --- a/dune-project +++ b/dune-project @@ -48,6 +48,20 @@ ) ) +(package + (name caisar-ir) + (synopsis "CAISAR's intermediate representation") + (depends + (ocaml (>= 4.13)) + (dune (>= 2.9.3)) + (base (>= v0.14.0)) + (ocaml-protoc-plugin (= 4.2.0)) + (ocamlgraph (>= 1.8.8)) + (ppx_inline_test (>= 0.12.0)) + (ppx_deriving (>= 4.4.1)) + ) +) + (package (name caisar) (synopsis "A platform for characterizing the safety and robustness of artificial intelligence based software") @@ -73,6 +87,7 @@ (caisar-nnet (= :version)) (caisar-ovo (= :version)) (caisar-onnx (= :version)) + (caisar-ir (= :version)) ) (sites (share stdlib) diff --git a/lib/ir/dune b/lib/ir/dune new file mode 100644 index 0000000000000000000000000000000000000000..e7d5f582a639f094be25b03fa2d47e57ba1628c0 --- /dev/null +++ b/lib/ir/dune @@ -0,0 +1,17 @@ +(library + (name ir) + (public_name caisar-ir) + (preprocess + (pps + ppx_inline_test + ppx_deriving.map + ppx_deriving.show + ppx_deriving.iter + ppx_deriving.fold)) + (inline_tests) + (libraries base ocplib-endian piqirun.pb zarith ocamlgraph stdio)) + +(env + (dev + (flags (:standard -warn-error -A)) + )) diff --git a/lib/ir/nier_cfg.ml b/lib/ir/nier_cfg.ml new file mode 100644 index 0000000000000000000000000000000000000000..a90a16848e545c6b6ca2276ca35121ee5bae5d12 --- /dev/null +++ b/lib/ir/nier_cfg.ml @@ -0,0 +1,481 @@ +open Base +open Stdio +open Bigarray + +module Tensor = struct + type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t + type shape = int array [@@deriving show] + + type ('a, 'b) t_kind = + | K_int : (int64, int64_elt) t_kind + | K_float : (float, float64_elt) t_kind + + let create : type a b. 
shape -> (a, b) t_kind -> (a, b) t = + fun shape -> function + | K_float -> Genarray.create float64 c_layout shape + | K_int -> Genarray.create int64 c_layout shape + + let unsqueeze ~sh1 ~sh2 = + let sh1, sh2 = (Array.to_list sh1, Array.to_list sh2) in + let longest, shortest = + match List.length sh1 > List.length sh2 with + | true -> (sh1, sh2) + | false -> (sh2, sh1) + in + (*find the index of the potential additional dimension*) + let where_zero = + match List.nth_exn longest 0 with + | 0 -> Some 0 + | _ -> ( + match List.last_exn longest with + | 0 -> Some (List.length longest - 1) + | _ -> None) + in + match where_zero with + | Some idx -> ( + match List.sub longest ~pos:idx ~len:(List.length shortest) with + | [] -> None + | _ -> Some (Array.of_list longest)) + | None -> None + + let get t idx = Genarray.get t idx + let set t idx v = Genarray.set t idx v + + let all_coords sh = + let sh = Array.to_list sh in + let rec ranges acc shape = + match shape with + | x :: y -> ranges (List.init x ~f:(fun i -> i) :: acc) y + | [] -> acc + (* a list containing a list of all possible indexes, for each dimension *) + in + let xxs = ranges [] sh in + (* add to each element of the list of all possible coordinates all*) + (* * possible indexes ... *) + let aux acc xs = + List.concat + @@ List.map xs ~f:(fun x -> List.map ~f:(fun lt -> x :: lt) acc) + (* ... for each dimension, starting from an empty list of*) + (* * possible coordinates *) + in + List.fold xxs ~init:[ [] ] ~f:aux + + let flatten t = + let shape = Genarray.dims t in + let all_idxs = all_coords shape in + List.init (List.length all_idxs) ~f:(fun i -> + get t (Array.of_list @@ List.nth_exn all_idxs i)) + + let get_shape t = Genarray.dims t + + let equal f t1 t2 = + let t1_sh = get_shape t1 and t2_sh = get_shape t2 in + if Array.equal ( = ) t1_sh t2_sh + then + let all_idxs = all_coords (get_shape t1) in + List.fold + ~f:(fun acc x -> + if acc + then f (get t1 (Array.of_list x)) (get t2 (Array.of_list x)) + else false) + all_idxs ~init:true + else false + + let num_neurons sh = Array.fold ~init:1 ~f:(fun x y -> x * y) sh + + let get_flatnd_idx ~idx ~sh flt = + let sh = Array.to_list sh in + let pop_sh = List.tl_exn sh @ [ 1 ] in + let rec get_factors_from_sh sh_f l = + match sh_f with + | [] -> List.rev l + | _ -> + get_factors_from_sh (List.tl_exn sh_f) + (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l) + in + let factors = get_factors_from_sh pop_sh [] in + let coord_in_data = + match + List.fold2 + ~f:(fun x y z -> x + (y * z)) + ~init:0 (Array.to_list idx) factors + with + | List.Or_unequal_lengths.Ok i -> i + | List.Or_unequal_lengths.Unequal_lengths -> + failwith "Unequal lengths in get_flatnd_idx" + in + List.nth_exn flt coord_in_data + + let transpose_2d _t = assert false +end + +(* TODO: maybe add markers for special nodes, to reflect they are the inputs and + outputs of the neural network? 
*) +module Node = struct + type shape = int array + + let show_shape sh = + let sh = Array.to_list sh in + match sh with + | [] -> "[]" + | x :: y -> + "[" ^ Int.to_string x + ^ List.fold ~init:"" ~f:(fun str e -> str ^ ";" ^ Int.to_string e) y + ^ "]" + + type operator = + | Add + | Mul + | Matmul + | LogSoftmax + | ReLu + | Transpose + | Squeeze + | MaxPool + | Conv + | Identity + | NO_OP + | RW_Linearized_ReLu + + let str_op o = + match o with + | Add -> "Add" + | Mul -> "Mul" + | Matmul -> "Matmul" + | LogSoftmax -> "LogSoftmax" + | ReLu -> "ReLu" + | Transpose -> "Transpose" + | Squeeze -> "Squeeze" + | MaxPool -> "MaxPool" + | Conv -> "Conv" + | Identity -> "Identity" + | NO_OP -> "NO_OP" + | RW_Linearized_ReLu -> "RW_Linearized_ReLu" + + type ksize = Ksize of shape + type stride = Stride of shape + type pads = Pads of shape + type dilations = Dilations of shape + + type operator_parameters = + | Pool_params of (ksize * stride option * pads option * dilations option) + | Conv_params of (ksize * stride option * pads option * dilations option) + | Transpose_params of shape + | RW_Linearized_ReLu_params of + (bool list list * ((string, float) Base.Hashtbl.t list * int)) + + let str_op_params p = + match p with + | Transpose_params s -> + let str_sh = show_shape s in + "Transpose params: " ^ str_sh + | Pool_params (Ksize k, s, p, d) | Conv_params (Ksize k, s, p, d) -> + let str_k = show_shape k + and str_s = match s with None -> "" | Some (Stride ss) -> show_shape ss + and str_p = match p with None -> "" | Some (Pads pp) -> show_shape pp + and str_d = + match d with None -> "" | Some (Dilations dd) -> show_shape dd + in + "Pool params: KSIZE: " ^ str_k ^ ", Pads: " ^ str_p ^ ", Stride: " ^ str_s + ^ ", Dilations: " ^ str_d + | RW_Linearized_ReLu_params l -> + (* Only displays the activation scheme on the ReLU node *) + let activations = fst l in + let act' = + List.map + ~f:(fun l1 -> + List.map + ~f:(fun b -> match b with true -> "true" | false -> "false") + l1) + activations + in + let act'' = + List.map ~f:(fun l -> "[" ^ String.concat ~sep:";" l ^ "]") act' + in + let act''' = "[" ^ String.concat ~sep:";" act'' ^ "]" in + "RW_Linearized_ReLu_params: " ^ act''' + + type ('a, 'b) t = { + id : int; + name : string option; + shape : shape; + operator : operator; + operator_parameters : operator_parameters option; + pred : string list; + succ : string list; + tensor : ('a, 'b) Tensor.t option; + } + + let compare v1 v2 = Stdlib.compare v1.id v2.id + let hash (v : ('a, 'b) t) = v.id + let equal v1 v2 = v1.id = v2.id + + let create ~id ~name ~sh ~op ~op_p ~pred ~succ ~tensor = + { + id; + name; + shape = sh; + operator = op; + operator_parameters = op_p; + pred; + succ; + tensor; + } + + let get_name t = match t.name with Some n -> n | None -> "C_NODE" + let get_shape t = t.shape + let get_op t = t.operator + let get_tensor t = t.tensor + let get_pred_list t = t.pred + let get_succ_list t = t.succ + let is_data_node t = match get_tensor t with None -> false | Some _ -> true + + (* TODO: some flags on the node would be cleaner than this*) + let is_input_node t = List.equal String.equal t.pred [ "NO_INPUT" ] + let is_output_node t = List.equal String.equal t.succ [ "NO_OUTPUT" ] + + let num_neurons t = + match get_shape t with + | [||] -> 0 + | l -> Array.fold ~init:1 ~f:(fun x acc -> x * acc) l + + let show n f = + let id = Int.to_string n.id in + let name = get_name n + and operator = str_op n.operator + and operator_parameters = + match n.operator_parameters with + | Some p -> 
str_op_params p + | None -> "no parameters" + and shape = show_shape n.shape + and prevs = + List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_pred_list n) + and nexts = + List.fold_left ~f:(fun x y -> x ^ "," ^ y) ~init:"" (get_succ_list n) + and tensor = + match n.tensor with + (*limit of size for tensor strings, complying with + * dot string size limit of 16Ko *) + | Some t -> + let display_indices = + let all_indices = Tensor.all_coords (Tensor.get_shape t) in + if List.length all_indices > 10 + then + let rec firstk k xs = + match xs with + | [] -> failwith "firstk" + | x :: xs -> if k = 1 then [ x ] else x :: firstk (k - 1) xs + in + firstk 10 all_indices + else all_indices + in + let t_value_string f = + List.fold_left + ~f:(fun acc l -> + acc + ^ show_shape (Array.of_list l) + ^ ": " + ^ f (Tensor.get t (Array.of_list l)) + ^ "\n") + ~init:"" display_indices + in + "Tensor value\n: " ^ t_value_string f ^ "\nShape: " + ^ show_shape (Tensor.get_shape t) + | None -> "No tensor in node" + in + "ID :" ^ id ^ "\nNAME: " ^ name ^ "\nOP: " ^ operator ^ "\nOP PARAMS:" + ^ operator_parameters ^ "\nSHAPE: " ^ shape ^ "\nPREVS: " ^ prevs + ^ "\nNEXTS: " ^ nexts ^ "\nTENSORS INFOS:" ^ tensor +end + +module type VInput = sig + type l + type r + + val convert_f : l -> string +end + +module MakeVertex (I : VInput) = struct + type t = (I.l, I.r) Node.t + + let compare = Node.compare + let hash = Node.hash + let equal = Node.equal + let convert_f = I.convert_f + + type label = string + + let label (n : t) = match n.Node.name with Some n -> n | None -> "" + let create _name = assert false +end + +module Edge = struct + type t = string + + let compare = Stdlib.compare + let equal = phys_equal + let default = "" +end + +module NierCFG (I : VInput) = struct + module Vertex = MakeVertex (I) + include Graph.Imperative.Digraph.ConcreteBidirectionalLabeled (Vertex) (Edge) + + let convert_f = Vertex.convert_f + let vertex_list g = fold_vertex (fun x l -> x :: l) g [] + + let input_nodes g = + let input_criterion (v : ('a, 'b) Node.t) acc = + match v.id with 0 -> Some v | _ -> acc + in + match fold_vertex (fun v acc -> input_criterion v acc) g None with + | Some r -> [ r ] + | None -> failwith "Something strange, no node for describing inputs found" + + let preds g v = pred g v + + let preds_names g v = + let preds_list = pred_e g v in + List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) preds_list + + let succs_names g v = + let succs_list = succ_e g v in + List.fold ~init:[] ~f:(fun acc (_, n, _) -> n :: acc) succs_list + + let succs g v = succ g v + let init_cfg = create () + + let find_vertices g f = + fold_vertex (fun x l -> if f x then x :: l else l) g [] + + let data_node_of n g = + fold_pred (fun v _ -> if Node.is_data_node v then Some v else None) g n None + + let infer_shape g n in_shape ~on_backward = + let op = Node.get_op n in + match op with + | Node.Add -> ( + match data_node_of n g with + | Some d_n -> Node.get_shape d_n + | None -> failwith "Error, Add operator lacks a data node") + | Node.ReLu -> in_shape + | Node.Matmul -> + let pad_left = function + | [] -> failwith "Impossible to pad empty shape" + | [ a ] -> [ 1; a ] + | x -> x + in + let pad_right = function + | [] -> failwith "Impossible to pad empty shape" + | [ a ] -> [ a; 1 ] + | x -> x + in + let rec one_padding l i = + if i <= 0 then l else one_padding (1 :: l) (i - 1) + in + let dn_shape = + match data_node_of n g with + | Some dn -> Node.get_shape dn + | None -> failwith "Error, MatMul operator lacks a data node" + in 
+      (* Expected semantics:
+       * Matrix multiplication C = AB
+       * A (shape [n;m]); B (shape [m;p]); C (shape [n;p])
+       * shape of b: b_sh
+       * shape of a: a_sh
+       * shape of c: c_sh
+       * It is expected here that B is the shape of the node
+       * yielding the data tensor in the NIER
+       *)
+      let check_matmul_size_ba ~b_sh ~a_sh =
+        let bdim2 = pad_left b_sh in
+        let adim2 = pad_right a_sh in
+        let bdim = one_padding bdim2 (List.length adim2 - List.length bdim2) in
+        let adim = one_padding adim2 (List.length bdim2 - List.length adim2) in
+        let rec infer_csize acc ad bd =
+          match (ad, bd) with
+          | [ m; n ], [ nn; p ] ->
+            if nn = n
+            then (n, List.append (List.rev acc) [ m; p ])
+            else failwith "size of matrices not adequate"
+          | a :: la, b :: lb ->
+            if a = b
+            then infer_csize (a :: acc) la lb
+            else if a = 1
+            then infer_csize (b :: acc) la lb
+            else if b = 1
+            then infer_csize (a :: acc) la lb
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        infer_csize [] bdim adim
+      in
+      let check_matmul_size_bc ~b_sh ~c_sh =
+        let bdim2 = pad_left b_sh in
+        let cdim2 = pad_right c_sh in
+        let bdim = one_padding bdim2 (List.length cdim2 - List.length bdim2) in
+        let cdim = one_padding cdim2 (List.length bdim2 - List.length cdim2) in
+        let rec infer_asize acc bd cd =
+          match (bd, cd) with
+          | [ m; p ], [ n; pp ] ->
+            if pp = p
+            then (n, List.append (List.rev acc) [ n; m ])
+            else failwith "size of matrices not adequate"
+          | b :: lb, c :: lc ->
+            if b = c
+            then infer_asize (b :: acc) lb lc
+            else if b = 1
+            then infer_asize (b :: acc) lb lc
+            else if c = 1
+            then infer_asize (c :: acc) lb lc
+            else failwith "Checking matmul_size failed: one discordance"
+          | _, _ -> failwith "Checking matmul_size failed"
+        in
+        infer_asize [] bdim cdim
+      in
+      if on_backward
+      then
+        Array.of_list
+        @@ snd
+             (check_matmul_size_bc ~b_sh:(Array.to_list dn_shape)
+                ~c_sh:(Array.to_list in_shape))
+      else
+        Array.of_list
+        @@ snd
+             (check_matmul_size_ba ~b_sh:(Array.to_list in_shape)
+                ~a_sh:(Array.to_list dn_shape))
+    | a -> failwith (Printf.sprintf "operator %s not supported" (Node.str_op a))
+end
+
+module NierCFGInt = NierCFG (struct
+  type l = int64
+  type r = int64_elt
+
+  let convert_f = Int64.to_string
+end)
+
+module NierCFGFloat = NierCFG (struct
+  type l = float
+  type r = float64_elt
+
+  let convert_f = Float.to_string
+end)
+
+module NierCFGDot = Graph.Graphviz.Dot (struct
+  include NierCFGFloat (* use the graph module from above *)
+
+  let node_label (v : vertex) = Node.show v convert_f
+  let edge_attributes (_, e, _) = [ `Label e; `Color 4711 ]
+  let default_edge_attributes _ = []
+  let get_subgraph _ = None
+  let vertex_attributes v = [ `Shape `Box; `Label (node_label v) ]
+  let vertex_name (v : vertex) = Int.to_string v.id
+  let default_vertex_attributes _ = []
+  let graph_attributes _ = []
+end)
+
+let print_cfg_graph g = NierCFGDot.fprint_graph Caml.Format.std_formatter g
+
+let out_cfg_graph g =
+  let file = Out_channel.create "cfg.dot" in
+  NierCFGDot.output_graph file g
diff --git a/lib/ir/nier_cfg.mli b/lib/ir/nier_cfg.mli
new file mode 100644
index 0000000000000000000000000000000000000000..fd8394a50dd1d05ca39faa9ff99b4689b267bdd5
--- /dev/null
+++ b/lib/ir/nier_cfg.mli
@@ -0,0 +1,268 @@
+(** This module defines the structure and interfaces for a Neural IntermediatE
+    Representation (NIER).
+
+    It is primarily designed as an intermediate step in producing verifiable
+    terms from an ONNX model. *)
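The MatMul case of infer_shape above pads the shorter shape with leading 1s and broadcasts batch dimensions before combining the two trailing matrix dimensions. A standalone sketch of that rule, with a worked example (the helper matmul_shape is hypothetical and not part of the patch):

```ocaml
(* Sketch of the MatMul shape rule used by infer_shape: pad the shorter
   shape with leading 1s, broadcast the batch dimensions, then combine the
   two trailing dimensions [n; m] x [m; p] -> [n; p]. *)
let matmul_shape a_sh b_sh =
  let rec pad l n = if n <= 0 then l else pad (1 :: l) (n - 1) in
  let a = pad a_sh (List.length b_sh - List.length a_sh) in
  let b = pad b_sh (List.length a_sh - List.length b_sh) in
  let rec go acc = function
    | [ n; m ], [ m'; p ] when m = m' -> List.rev acc @ [ n; p ]
    | x :: xs, y :: ys when x = y || x = 1 || y = 1 ->
      go (max x y :: acc) (xs, ys)
    | _ -> failwith "incompatible shapes"
  in
  go [] (a, b)

(* [1; 784] x [784; 10] -> [1; 10], e.g. a dense layer on a flattened 28x28
   input; with a batch dimension, [5; 1; 784] x [784; 10] broadcasts to
   [5; 1; 10]. *)
let () = assert (matmul_shape [ 1; 784 ] [ 784; 10 ] = [ 1; 10 ])
```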
+
+open Base
+open Bigarray
+
+(** {1 Tensor module} *)
+
+(** Tensors are multidimensional arrays, used to represent numerical values
+    such as a neural network's weights. *)
+
+module Tensor : sig
+  type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t
+  type shape = int array [@@deriving show]
+
+  val all_coords : shape -> int list list
+
+  (** [create sh k] initializes a tensor of shape [sh] with a default value
+      that depends on the kind [k] of the tensor. *)
+
+  type ('a, 'b) t_kind =
+    | K_int : (int64, int64_elt) t_kind
+    | K_float : (float, float64_elt) t_kind
+
+  val create : shape -> ('a, 'b) t_kind -> ('a, 'b) t
+
+  (** [get t idx] returns the value stored in tensor [t] at coordinates
+      [idx]. Raises an error if the coordinates are invalid. *)
+
+  val get : ('a, 'b) t -> shape -> 'a
+
+  (** [set t idx v] sets value [v] in tensor [t] at coordinates [idx]. Raises
+      an error if the coordinates are invalid. *)
+
+  val set : ('a, 'b) t -> shape -> 'a -> unit
+
+  (** [equal f t1 t2] applies [f] to all values of [t1] and [t2], and returns
+      true if all applications of [f] returned true. *)
+
+  val equal : ('a -> 'a -> bool) -> ('a, 'b) t -> ('a, 'b) t -> bool
+
+  (** [get_shape t] returns the shape of [t]. *)
+
+  val get_shape : ('a, 'b) t -> shape
+
+  (** [flatten t] returns a flattened version of [t]. *)
+
+  val flatten : ('a, 'b) t -> 'a list
+
+  (** [num_neurons sh] returns the total number of neurons for a given shape
+      [sh]. *)
+
+  val num_neurons : shape -> int
+
+  (** [get_flatnd_idx ~idx ~sh flt] returns the value that would be stored at
+      index [idx] in a tensor of shape [sh], given the flattened version
+      [flt] of this tensor. *)
+
+  val get_flatnd_idx : idx:shape -> sh:shape -> 'a list -> 'a
+
+  (** [transpose_2d t] returns a copy of the tensor [t] with its last two
+      dimensions exchanged. *)
+
+  val transpose_2d : ('a, 'b) t -> ('a, 'b) t
+
+  (** [unsqueeze ~sh1 ~sh2] returns the lowest common shape between [sh1] and
+      [sh2], or [None] if there is no common shape. A common shape exists
+      when the shape of higher dimension has only 1s on the dimensions it
+      does not share with the other. *)
+
+  val unsqueeze : sh1:shape -> sh2:shape -> shape option
+end
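A minimal usage sketch of the Tensor interface above (illustrative values only). The layout is row-major: in a 2x3 tensor, index [|1; 2|] corresponds to cell 1*3 + 2 = 5 of the flattened list, which is the convention get_flatnd_idx implements:

```ocaml
open Ir.Nier_cfg

let () =
  let t = Tensor.create [| 2; 3 |] Tensor.K_float in
  Tensor.set t [| 1; 2 |] 0.5;
  assert (Float.equal (Tensor.get t [| 1; 2 |]) 0.5);
  assert (Tensor.num_neurons (Tensor.get_shape t) = 6);
  (* the same cell, recovered from the flattened representation *)
  let flat = Tensor.flatten t in
  assert (
    Float.equal (Tensor.get_flatnd_idx ~idx:[| 1; 2 |] ~sh:[| 2; 3 |] flat) 0.5)
```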
+
+(** {1 Modules for graph generation} *)
+
+module Node : sig
+  type shape = int array
+
+  type operator =
+    | Add
+    | Mul
+    | Matmul
+    | LogSoftmax
+    | ReLu
+    | Transpose
+    | Squeeze
+    | MaxPool
+    | Conv
+    | Identity
+    | NO_OP
+    | RW_Linearized_ReLu
+
+  (** Type describing the different operations handled. Those operations are
+      inspired by those defined in the ONNX documentation.
+
+      @see <https://github.com/onnx/onnx/blob/master/docs/Operators.md>
+        for more information. They are to be coupled with the relevant
+        operator parameters. *)
+
+  val str_op : operator -> string
+  val show_shape : shape -> string
+
+  type ksize = Ksize of shape
+  type stride = Stride of shape
+  type pads = Pads of shape
+  type dilations = Dilations of shape
+
+  type operator_parameters =
+    | Pool_params of (ksize * stride option * pads option * dilations option)
+    | Conv_params of (ksize * stride option * pads option * dilations option)
+    | Transpose_params of shape
+    | RW_Linearized_ReLu_params of
+        (bool list list * ((string, float) Base.Hashtbl.t list * int))
+
+  val str_op_params : operator_parameters -> string
+
+  type ('a, 'b) t = {
+    id : int;
+    name : string option;
+    shape : shape;
+    operator : operator;
+    operator_parameters : operator_parameters option;
+    pred : string list;
+    succ : string list;
+    tensor : ('a, 'b) Tensor.t option;
+  }
+  (** Type encapsulating a node and the parameters of its operation: for
+      Convolution and Pooling, kernel size, padding and strides; for
+      Transpose, the shape permutation. *)
+
+  val compare : ('a, 'b) t -> ('a, 'b) t -> int
+  val hash : ('a, 'b) t -> int
+  val equal : ('a, 'b) t -> ('a, 'b) t -> bool
+
+  val create :
+    id:int ->
+    name:string option ->
+    sh:shape ->
+    op:operator ->
+    op_p:operator_parameters option ->
+    pred:string list ->
+    succ:string list ->
+    tensor:('a, 'b) Tensor.t option ->
+    ('a, 'b) t
+
+  val get_name : ('a, 'b) t -> string
+  val get_shape : ('a, 'b) t -> shape
+  val get_op : ('a, 'b) t -> operator
+  val get_pred_list : ('a, 'b) t -> string list
+  val get_succ_list : ('a, 'b) t -> string list
+  val get_tensor : ('a, 'b) t -> ('a, 'b) Tensor.t option
+  val is_data_node : ('a, 'b) t -> bool
+  val is_input_node : ('a, 'b) t -> bool
+  val is_output_node : ('a, 'b) t -> bool
+  val num_neurons : ('a, 'b) t -> int
+  val show : ('a, 'b) t -> ('a -> string) -> string
+end
+
+module type VInput = sig
+  type l
+  type r
+
+  val convert_f : l -> string
+end
+
+module MakeVertex (I : VInput) : sig
+  include Graph.Sig.VERTEX with type t = (I.l, I.r) Node.t
+end
+
+module Edge : sig
+  type t = string
+
+  val compare : 'a -> 'a -> int
+  val equal : 'a -> 'a -> bool
+  val default : t
+end
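Ahead of the graph signatures below, a minimal construction sketch using the NierCFGFloat instance defined in nier_cfg.ml (node contents and names are made up):

```ocaml
open Ir.Nier_cfg

let () =
  let mk id name op pred succ =
    Node.create ~id ~name:(Some name) ~sh:[| 1; 2 |] ~op ~op_p:None ~pred
      ~succ ~tensor:None
  in
  let x = mk 0 "x" Node.NO_OP [ "NO_INPUT" ] [ "w" ] in
  let y = mk 1 "y" Node.Matmul [ "w" ] [ "NO_OUTPUT" ] in
  let g = NierCFGFloat.init_cfg in
  NierCFGFloat.add_vertex g x;
  NierCFGFloat.add_vertex g y;
  (* edges are labeled with the name of the value flowing along them *)
  NierCFGFloat.add_edge_e g (x, "w", y);
  assert (NierCFGFloat.preds_names g y = [ "w" ]);
  assert (Node.is_input_node x && Node.is_output_node y)
```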
+
+(** NIER is a graph {b (V,E)} where {b V} is the set of vertices (nodes) and
+    {b E} is the set of edges (connections between nodes). Nodes contain the
+    following information:
+
+    - a unique id
+    - the name coming from the original model, if it exists
+    - the shape of the tensor resulting from the application of the node
+      operation, if it exists
+    - the operation performed
+    - the parameters of the operation
+    - an optional tensor storing the data
+
+    Note that tensors have their own shape; it must however be equal to the
+    shape recorded in the NIER node. *)
+
+module NierCFG (I : VInput) : sig
+  include
+    Graph.Sig.I
+      with type V.t = MakeVertex(I).t
+       and type V.label = MakeVertex(I).t
+       and type E.t = MakeVertex(I).t * Edge.t * MakeVertex(I).t
+       and type E.label = Edge.t
+
+  val init_cfg : t
+  val vertex_list : t -> vertex list
+  val preds : t -> vertex -> vertex list
+
+  (** [preds_names g v] returns the list of names of the predecessor nodes of
+      [v]. *)
+
+  val preds_names : t -> vertex -> string list
+  val succs : t -> vertex -> vertex list
+
+  (** [succs_names g v] returns the list of names of the successor nodes of
+      [v]. *)
+
+  val succs_names : t -> vertex -> string list
+
+  (** [input_nodes g] returns the nodes considered as describing the inputs
+      of the neural network. *)
+
+  val input_nodes : t -> vertex list
+  val find_vertices : t -> (vertex -> bool) -> vertex list
+end
+
+module NierCFGFloat : sig
+  include
+    Graph.Sig.I
+      with type V.t = (float, Bigarray.float64_elt) Node.t
+       and type V.label = (float, Bigarray.float64_elt) Node.t
+       and type E.t =
+        (float, Bigarray.float64_elt) Node.t
+        * Edge.t
+        * (float, Bigarray.float64_elt) Node.t
+       and type E.label = Edge.t
+
+  val init_cfg : t
+  val vertex_list : t -> vertex list
+  val preds : t -> vertex -> vertex list
+
+  (** [preds_names g v] returns the list of names of the predecessor nodes of
+      [v]. *)
+
+  val preds_names : t -> vertex -> string list
+  val succs : t -> vertex -> vertex list
+
+  (** [succs_names g v] returns the list of names of the successor nodes of
+      [v]. *)
+
+  val succs_names : t -> vertex -> string list
+
+  (** [input_nodes g] returns the nodes considered as describing the inputs
+      of the neural network. *)
+
+  val input_nodes : t -> vertex list
+  val find_vertices : t -> (vertex -> bool) -> vertex list
+
+  (** [data_node_of n g] returns one node containing tensor data among the
+      predecessors of [n]. *)
+
+  val data_node_of : vertex -> t -> vertex option
+
+  (** [infer_shape g n sh o_b] returns the inferred shape of the output of
+      node [n] in NIER [g] with input shape [sh]. Shape inference is made
+      using the node operator and its predecessors' shapes. [o_b] is true
+      when performing backward propagation, to choose which matrix size to
+      consider. *)
+
+  val infer_shape : t -> vertex -> Node.shape -> on_backward:bool -> Node.shape
+end
+
+(** {1 Pretty printers} *)
+
+val print_cfg_graph : NierCFGFloat.t -> unit
+val out_cfg_graph : NierCFGFloat.t -> unit
diff --git a/lib/onnx/dune b/lib/onnx/dune
index 1008baf43c6b2f5790fbed38985ca6e9e364c3ca..aff7ac603d9995096fa6fdbf265cab8eec9f6475 100644
--- a/lib/onnx/dune
+++ b/lib/onnx/dune
@@ -1,7 +1,7 @@
 (library
  (name onnx)
  (public_name caisar-onnx)
- (libraries base stdio ocaml-protoc-plugin)
+ (libraries base stdio ocaml-protoc-plugin caisar-ir)
  (synopsis "ONNX parser for CAISAR"))
 
 (rule
diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml
index 2ee9221d2a1eba14f78c7e8629034192ff234879..3a8a8794799647e06fa37d4c5da919a52221c023 100644
--- a/lib/onnx/onnx.ml
+++ b/lib/onnx/onnx.ml
@@ -25,6 +25,10 @@ module Format = Caml.Format
 module Fun = Caml.Fun
 module Oproto = Onnx_protoc (* Autogenerated during compilation *)
 module Oprotom = Oproto.Onnx.ModelProto
+module NCFG = Ir.Nier_cfg
+module G = NCFG.NierCFGFloat
+
+exception ParseError of string
 
 type t = {
   n_inputs : int; (* Number of inputs. *)
@@ -32,6 +36,30 @@
 }
 
 (* ONNX format handling.
*) +type op_attribute = Oproto.Onnx.AttributeProto.t +type tensordata = Raw of bytes | Float of float list + +let (no_attr : op_attribute) = + { + name = None; + ref_attr_name = None; + doc_string = None; + type' = None; + f = None; + i = None; + s = None; + t = None; + g = None; + floats = []; + ints = []; + strings = []; + tensors = []; + graphs = []; + sparse_tensor = None; + tp = None; + sparse_tensors = []; + type_protos = []; + } let get_nested_dims (s : Oproto.Onnx.ValueInfoProto.t list) = match List.nth s 0 with @@ -64,6 +92,416 @@ let get_input_output_dim (model : Oprotom.t) = let output_flat_dim = flattened_dim output_shape in (input_flat_dim, output_flat_dim) +let produce_cfg (g : Oproto.Onnx.GraphProto.t) = + let open Oproto.Onnx in + let nodes = g.node + and inputs = g.input + and outputs = g.output + and initi = g.initializer' in + let fold_vip_names acc n = + match n.ValueInfoProto.name with + | Some str -> Some str :: acc + | None -> None :: acc + in + let i_nodes, o_nodes = + ( List.fold inputs ~init:[] ~f:fold_vip_names, + List.fold outputs ~init:[] ~f:fold_vip_names ) + and c_nodes = List.init (List.length nodes) ~f:(fun _ -> None) in + let fold_nodes_ops_cfg ns = + let get_node_operator_cfg x = + match x.NodeProto.op_type with + | None -> NCFG.Node.NO_OP + | Some o -> ( + match o with + | "Add" -> NCFG.Node.Add + | "Mul" -> NCFG.Node.Mul + | "Relu" -> NCFG.Node.ReLu + | "MatMul" -> NCFG.Node.Matmul + | "LogSoftmax" -> NCFG.Node.LogSoftmax + | "Transpose" -> NCFG.Node.Transpose + | "Squeeze" -> NCFG.Node.Squeeze + | "MaxPool" -> NCFG.Node.MaxPool + | "Conv" -> NCFG.Node.Conv + | "Identity" -> NCFG.Node.Identity + | _ -> + raise (ParseError ("Unsupported ONNX Operator in\n Parser: " ^ o))) + in + List.fold ~f:(fun acc n -> get_node_operator_cfg n :: acc) ~init:[] ns + in + let c_ops = List.rev @@ fold_nodes_ops_cfg nodes + and i_ops, o_ops = + ( List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length i_nodes), + List.init ~f:(fun _ -> NCFG.Node.NO_OP) (List.length o_nodes) ) + in + let fold_nodes_attr ns = + let get_node_attr n = n.NodeProto.attribute in + List.fold ~f:(fun acc n -> get_node_attr n :: acc) ~init:[] ns + in + + let c_attr = List.rev @@ fold_nodes_attr nodes + and i_attr, o_attr = + ( List.init ~f:(fun _ -> [ no_attr ]) (List.length i_nodes), + List.init ~f:(fun _ -> [ no_attr ]) (List.length o_nodes) ) + in + let c_nodes_inputs, c_nodes_outputs = + List.unzip + @@ List.fold + ~f:(fun acc n -> (n.NodeProto.input, n.NodeProto.output) :: acc) + ~init:[] (List.rev nodes) + and i_nodes_inputs, i_nodes_outputs, o_nodes_inputs, o_nodes_outputs = + ( List.init ~f:(fun _ -> [ "NO_INPUT" ]) (List.length i_nodes), + List.init ~f:(fun _ -> [ "" ]) (List.length i_nodes), + List.init ~f:(fun _ -> [ "" ]) (List.length o_nodes), + List.init ~f:(fun _ -> [ "NO_OUTPUT" ]) (List.length o_nodes) ) + in + let data_dict = + let dict_tensors_cfg ts = + let get_float_from_index index data sh = + let index = Array.to_list index and sh = Array.to_list sh in + let pop_sh = List.tl_exn sh @ [ 1 ] in + (* Returns the factors by which multiply each coordinate *) + let rec get_factors_from_sh sh_f l = + match sh_f with + | [] -> List.rev l + | _ -> + get_factors_from_sh (List.tl_exn sh_f) + (List.fold ~f:(fun x y -> x * y) ~init:1 sh_f :: l) + in + let factors = get_factors_from_sh pop_sh [] in + let coord_in_data = + List.fold2_exn ~f:(fun x y z -> x + (y * z)) ~init:0 index factors + in + match data with + | Raw raw -> + let offset = 4 * coord_in_data in + (* Each float is coded on 4 
bytes*) + let res = EndianBytes.LittleEndian.get_float raw offset in + res + | Float f -> List.nth_exn f coord_in_data + in + let build_tensor_from_data sh data = + let open NCFG.Tensor in + let sh = Array.of_list @@ sh in + let tensor = create sh K_float in + let coords = all_coords (get_shape tensor) in + let rec init_tensor t idx r = + match idx with + | x :: y -> + let value = + get_float_from_index (Array.of_list x) r (get_shape t) + in + set t (Array.of_list x) value; + init_tensor t y r + | [] -> t + in + init_tensor tensor coords data + in + let t_name x = + match x.TensorProto.name with Some n -> n | None -> "C_NODE" + in + let t_dim x = x.TensorProto.dims in + let t_data x = + match x.TensorProto.raw_data with + | Some rd -> Some (build_tensor_from_data (t_dim x) (Raw rd)) + | None -> ( + match x.TensorProto.float_data with + | [] -> None + | f -> Some (build_tensor_from_data (t_dim x) (Float f))) + in + List.fold + ~f:(fun m x -> Map.add_exn m ~key:(t_name x) ~data:(t_data x)) + ~init:(Map.empty (module String)) + ts + in + dict_tensors_cfg initi + in + let unpack v = + match v with + | Some v -> v + | None -> failwith "error, unpack found an unexpected None" + in + let tensor_list = + List.init + ~f:(fun i -> + match Map.find data_dict (unpack (List.nth_exn i_nodes i)) with + | Some v -> v + | None -> None) + (List.length i_nodes) + in + let tensor_list_full = Map.to_alist data_dict in + let tensor_list_rev = List.rev tensor_list in + let vip_dims v = + let val_t = + match v.ValueInfoProto.type' with + | Some t -> t + | None -> failwith "No type in value info" + in + let tns_t = + match val_t.TypeProto.value with + | `Tensor_type t -> t + | `not_set -> + failwith "No tensor type in value info" + (* TODO: support more tensor types *) + | _ -> raise (ParseError "Unknown tensor type.") + in + let tns_s = + match tns_t.shape with + | Some s -> s + | None -> failwith "No tensor shape in value info" + in + List.rev + @@ List.fold tns_s ~init:[] ~f:(fun acc d -> + match d.value with + | `Dim_value d -> d :: acc + | `not_set | _ -> 0 :: acc) + in + + let c_tensordim_list = List.init (List.length nodes) ~f:(fun _ -> []) + and c_tensorraw_list = List.init (List.length nodes) ~f:(fun _ -> None) + and o_tensordim_list = + List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] outputs + and o_tensorraw_list = List.init (List.length o_nodes) ~f:(fun _ -> None) + and i_tensordim_list = + List.fold ~f:(fun acc n -> vip_dims n :: acc) ~init:[] inputs + and i_tensorraw_list = tensor_list_rev in + let nodes_names = i_nodes @ c_nodes @ o_nodes in + let ops = i_ops @ c_ops @ o_ops in + let attrs = i_attr @ c_attr @ o_attr in + let prevs_list = i_nodes_inputs @ c_nodes_inputs @ o_nodes_inputs in + let nexts_list = i_nodes_outputs @ c_nodes_outputs @ o_nodes_outputs in + let tensor_dims = i_tensordim_list @ c_tensordim_list @ o_tensordim_list in + let tensors = i_tensorraw_list @ c_tensorraw_list @ o_tensorraw_list in + let operator_parameters (attr : AttributeProto.t list) op = + match op with + | NCFG.Node.Transpose -> + let ints_params = Array.of_list (List.nth_exn attr 0).ints in + Some (NCFG.Node.Transpose_params ints_params) + (*TODO: maxpool and conv operators: match attr.name in attributes to + * create the correct value for each attribute*) + (* | NCFG.Vertex.MaxPool -> *) + (* | NCFG.Vertex.Conv -> *) + | _ -> None + in + let rec build_op_param_list attrs ops l = + match (attrs, ops) with + | a :: b, c :: d -> build_op_param_list b d (operator_parameters a c :: l) + | [], [] -> + List.rev l 
+      (* All the other list constructions fold right, so we need a final
+       * reverse here. *)
+    | _ ->
+      raise
+        (ParseError
+           "Error, operators and attributes lists do not have the same size")
+  in
+  let op_params_cfg = build_op_param_list attrs ops [] in
+  let cfg = G.init_cfg in
+  (* adding inputs, outputs and cnodes to the cfg *)
+  let unkerasize l = List.map ~f:(fun x -> if x = 0 then 1 else x) l in
+  for i = 0 to List.length nodes_names - 1 do
+    let (v : G.V.t) =
+      NCFG.Node.create ~id:i
+        ~name:(List.nth_exn nodes_names i)
+        ~sh:(Array.of_list @@ unkerasize (List.nth_exn tensor_dims i))
+        ~op:(List.nth_exn ops i)
+        ~op_p:(List.nth_exn op_params_cfg i)
+        ~pred:(List.nth_exn prevs_list i)
+        ~succ:(List.nth_exn nexts_list i)
+        ~tensor:(List.nth_exn tensors i)
+    in
+    G.add_vertex cfg v
+  done;
+  (* Adding edges between vertices. *)
+  (* For each unnamed vertex (= a calculation node) in the cfg ... *)
+  (* return true if l1 has at least one common element with l2 *)
+  let rec shared_elm l1 l2 =
+    match l1 with
+    | x :: y -> List.mem l2 x ~equal:String.equal || shared_elm y l2
+    | [] -> false
+  in
+  List.iter
+    ~f:(fun (v : G.V.t) ->
+      match v.name with
+      | None ->
+        let pred = v.pred and succ = v.succ in
+        let prev_v =
+          (* ... get all vertices in cfg that have the current vertex preds
+           * in their succ (at least one of their succs is among our
+           * preds) ... *)
+          G.find_vertices cfg (fun (x : G.V.t) -> shared_elm pred x.succ)
+        (* ... get all vertices in cfg that have the current vertex preds
+         * in their name (they are named the same as one of our preds) ... *)
+        and named_pred =
+          G.find_vertices cfg (fun (x : G.V.t) ->
+              match x.name with
+              | Some name -> shared_elm pred [ name ]
+              | None -> false)
+        (* ... get all vertices in cfg that have the current vertex succs
+         * in their name (they are named the same as one of our succs) ... *)
+        and named_succ =
+          G.find_vertices cfg (fun (x : G.V.t) ->
+              match x.name with
+              | Some name -> shared_elm succ [ name ]
+              | None -> false)
+        (* ... get all vertices in cfg that have the current vertex succs
+         * in their preds (at least one of their preds is among our
+         * succs). *)
+        and next_v =
+          G.find_vertices cfg (fun (x : G.V.t) -> shared_elm succ x.pred)
+        in
+        (* add edges between the current vertex and the identified preds and
+         * succs *)
+        let v_predecessors = prev_v @ named_pred
+        and v_successors = next_v @ named_succ in
+        let unpack_tname (x : G.V.t) =
+          match x.NCFG.Node.name with Some n -> n | None -> ""
+        in
+        List.iter
+          ~f:(fun (x : G.V.t) ->
+            let label =
+              match List.nth x.succ 0 with
+              | Some "NO_OUTPUT" ->
+                let pred_name = unpack_tname x in
+                if List.mem ~equal:String.equal v.NCFG.Node.pred pred_name
+                then pred_name
+                else ""
+              | Some l -> l
+              | None -> ""
+            in
+            G.add_edge_e cfg (x, label, v))
+          v_predecessors;
+        (* add successor edges, filtering out those that already exist as
+         * edges *)
+        List.iter
+          ~f:(fun (x : G.V.t) ->
+            let all_preds = G.preds cfg x and all_succs = G.succs cfg x in
+            if List.mem ~equal:NCFG.Node.equal all_preds v
+               || List.mem ~equal:NCFG.Node.equal all_succs v
+            then ()
+            else
+              let label =
+                match List.nth_exn x.pred 0 with
+                | "NO_INPUT" ->
+                  let succ_name = unpack_tname x in
+                  if List.mem ~equal:String.equal v.NCFG.Node.succ succ_name
+                  then succ_name
+                  else ""
+                | l -> l
+              in
+              G.add_edge_e cfg (v, label, x))
+          v_successors
+      | _ -> ())
+    (G.vertex_list cfg);
+
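+  (* Worked example (tensor names assumed): for a calculation vertex [v]
+   * with pred = ["t0"], a vertex [w] with succ = ["t0"] is collected in
+   * [prev_v] above and connected by
+   *   G.add_edge_e cfg (w, "t0", v)
+   * i.e. the edge label is the tensor name both vertices share. *)
+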
"inputs" of + * the ONNX graph g, and network parameters as a list of tensors + * in the ONNX initializer_. + * To make the two correspond, elements of g.inputs and g.initializer_ + * share the same value in the field "name". + * In Keras, elements of g.initializer_ have a name, but they do not + * correspond to any name in g.inputs. + * What we did before was then to create the actual nier cfg following the + * PyTorch way. + * Below, we complete the cfg with keras data by doing the following: + * * create a node for NIER for each tensor in onnx initializer_ + * * for each NIER node, check if there is a node sharing the same name + * pred + * * if yes, remove the one with highest ID (those are initi nodes, but since + * there is already a node in CFG with this name we do not + * need those) + * * if not, for each NIER node, chck if there is a node + * which name is contained in prevs. add it to the prev + * *) + + (* adding initi vertices to the cfg *) + for i = 0 to List.length tensor_list_full - 1 do + let shape = + match snd (List.nth_exn tensor_list_full i) with + | Some t -> unkerasize (Array.to_list @@ NCFG.Tensor.get_shape t) + | None -> [] + in + let (v : G.V.t) = + NCFG.Node.create + ~id:(i + List.length nodes_names) + ~name:(Some (fst (List.nth_exn tensor_list_full i))) + ~sh:(Array.of_list @@ unkerasize shape) + ~op:NO_OP ~op_p:None ~pred:[] ~succ:[] + ~tensor:(snd (List.nth_exn tensor_list_full i)) + in + G.add_vertex cfg v + done; + (* build a list of nodes + * sharing name but with different ids *) + let same_name_diff_ids = + let aux (x : G.V.t) = + G.fold_vertex + (fun y acc -> + match (x.name, y.name) with + | Some xa, Some ya -> + if (not (y.id = x.id)) && String.equal xa ya + then (x, y) :: acc + else acc + | _ -> acc) + cfg [] + in + G.fold_vertex (fun x l -> aux x :: l) cfg [] + in + let highest_ids = + List.fold + ~f:(fun acc x -> + match x with + | a :: _ -> + let maxval = max (fst a).NCFG.Node.id (snd a).NCFG.Node.id in + maxval :: acc + | [] -> acc) + ~init:[] same_name_diff_ids + in + (* (* removing nodes with highest id, those are the*) (* * ones we just added + *)*) + List.iter + ~f:(fun x -> + match x with + | l :: _ -> + let v1 = fst l in + if List.mem ~equal:( = ) highest_ids v1.NCFG.Node.id + then + (* Printf.printf "Removing id %d \n%!" *) + (* v1.NCFG.Vertex.id; *) + G.remove_vertex cfg v1 + else () + | [] -> ()) + same_name_diff_ids; + (* Now it is Keras time. 
+  (* Now it is Keras time:
+   * look for nodes sharing a name and predecessors,
+   * then create the corresponding edges. *)
+  let shared_name_preds =
+    let aux (x : G.V.t) =
+      match x.name with
+      (* look in the other vertices whether the name is among their
+       * predecessors *)
+      | Some n -> G.find_vertices cfg (fun x -> shared_elm [ n ] x.pred)
+      | None -> []
+    in
+    G.fold_vertex (fun x l -> (x, aux x) :: l) cfg []
+  in
+  List.iter
+    ~f:(fun x ->
+      let orgn = fst x and to_edge = snd x in
+      List.iter
+        ~f:(fun t ->
+          if not (G.mem_edge cfg orgn t)
+          then G.add_edge_e cfg (orgn, unpack orgn.NCFG.Node.name, t))
+        to_edge)
+    shared_name_preds;
+  cfg
+
+let nier_of_onnx_protoc (model : Oprotom.t) =
+  match model.graph with
+  | Some g -> produce_cfg g
+  | None -> raise (ParseError "No graph in ONNX input file!")
+
 let parse_in_channel in_channel =
   let open Result in
   try
@@ -72,7 +510,8 @@ let parse_in_channel in_channel =
     match Oprotom.from_proto reader with
     | Ok r ->
       let n_inputs, n_outputs = get_input_output_dim r in
-      Ok { n_inputs; n_outputs }
+      let nier = nier_of_onnx_protoc r in
+      Ok ({ n_inputs; n_outputs }, nier)
     | _ -> Error "Error parsing protobuf"
   with
   | Sys_error s -> Error s
diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli
index 776e3124dfe291e614ce0d09fa127cd195f111bd..7eb5500ccd950a59218ed5e7f3c875bbc1f923cf 100644
--- a/lib/onnx/onnx.mli
+++ b/lib/onnx/onnx.mli
@@ -20,11 +20,15 @@
 (*                                                                        *)
 (**************************************************************************)
 
+module G = Ir.Nier_cfg.NierCFGFloat
+
 type t = private {
   n_inputs : int;  (** Number of inputs. *)
   n_outputs : int;  (** Number of outputs. *)
 }
 (** ONNX model metadata. *)
 
-val parse : string -> (t, string) Result.t
-(** Parse an ONNX file. *)
+(** Parse an ONNX file to get metadata for CAISAR, as well as its inner
+    intermediate representation for the network.
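+
+    A sketch of the intended use (the [consume] function is hypothetical):
+    {[
+      match Onnx.parse "model.onnx" with
+      | Ok (metadata, nier) -> consume metadata.n_inputs nier
+      | Error msg -> failwith msg
+    ]}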
*) + +val parse : string -> (t * G.t, string) Result.t diff --git a/src/dune b/src/dune index 16397dcccc24f73f2946c29373a9cfeae4f9ff68..58539f690189404eac7ee90665a3beb13508d1fb 100644 --- a/src/dune +++ b/src/dune @@ -18,7 +18,8 @@ ovo why3 dune-site - re) + re + zarith) (preprocess (pps ppx_deriving_yojson diff --git a/src/language.ml b/src/language.ml index bbca3a69daad7882cc0352d155f8f66629ce7abe..6cae58d56283c06ad088a1d7f5ea71a416270b44 100644 --- a/src/language.ml +++ b/src/language.ml @@ -32,6 +32,7 @@ type nn_shape = { nb_outputs : int; ty_data : Ty.ty; filename : string; + nier : Onnx.G.t option; } type svm_shape = { nb_inputs : int; nb_classes : int; filename : string } @@ -41,7 +42,7 @@ let loaded_svms = Term.Hls.create 10 let lookup_loaded_nets = Term.Hls.find_opt loaded_nets let lookup_loaded_svms = Term.Hls.find_opt loaded_svms -let register_nn_as_tuple nb_inputs nb_outputs filename env = +let register_nn_as_tuple nb_inputs nb_outputs filename nier env = let net = Pmodule.read_module env [ "caisar" ] "NN" in let input_type = Ty.ty_app Theory.(ns_find_ts net.mod_theory.th_export [ "input_type" ]) [] @@ -57,7 +58,7 @@ let register_nn_as_tuple nb_inputs nb_outputs filename env = (Ty.ty_tuple (List.init nb_outputs ~f)) in Term.Hls.add loaded_nets ls_net_apply - { filename; nb_inputs; nb_outputs; ty_data = input_type }; + { filename; nb_inputs; nb_outputs; ty_data = input_type; nier }; let th_uc = Pmodule.add_pdecl ~vc:false th_uc (Pdecl.create_pure_decl (Decl.create_param_decl ls_net_apply)) @@ -86,13 +87,15 @@ let nnet_parser env _ filename _ = let model = Nnet.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env + | Ok model -> + register_nn_as_tuple model.n_inputs model.n_outputs filename None env let onnx_parser env _ filename _ = let model = Onnx.parse filename in match model with | Error s -> Loc.errorm "%s" s - | Ok model -> register_nn_as_tuple model.n_inputs model.n_outputs filename env + | Ok (model, nier) -> + register_nn_as_tuple model.n_inputs model.n_outputs filename (Some nier) env let ovo_parser env _ filename _ = let model = Ovo.parse filename in diff --git a/src/language.mli b/src/language.mli index 4d0a1ca12d2cefbae7dc71099599254ca09a9d94..90c8cc3a1e5e9c3a452d5f8dd5e02c43574a844a 100644 --- a/src/language.mli +++ b/src/language.mli @@ -27,6 +27,7 @@ type nn_shape = { nb_outputs : int; ty_data : Ty.ty; filename : string; + nier : Onnx.G.t option; } type svm_shape = { nb_inputs : int; nb_classes : int; filename : string } diff --git a/src/main.ml b/src/main.ml index 68d86548191017ec3e8a397b43e2451e9cd2f327..2931661c0cfa4e88284905399e673870126534a1 100644 --- a/src/main.ml +++ b/src/main.ml @@ -27,6 +27,7 @@ let caisar = "caisar" let () = Native_nn_prover.init (); + Actual_net_apply.init (); Vars_on_lhs.init () let () = @@ -229,3 +230,4 @@ let () = |> Cmd.eval ~catch:false |> Caml.exit with exn when not (log_level_is_debug ()) -> Logs.err (fun m -> m "@[%a@]" Why3.Exn_printer.exn_printer exn) + diff --git a/src/printers/marabou.ml b/src/printers/marabou.ml index 368bcb37bba3293bb8a4360025ba7aa11860ca7c..c237d726f6dadb162dca53210cd5a7cd3700de10 100644 --- a/src/printers/marabou.ml +++ b/src/printers/marabou.ml @@ -146,12 +146,12 @@ let rec print_tdecl info fmt task = print_tdecl info fmt task_prev; match task_decl.Theory.td_node with | Use _ | Clone _ -> () - | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_input + | Meta (meta, l) when Theory.meta_equal 
meta Utils.meta_input
     -> (
     match l with
     | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
     | _ -> assert false)
-  | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_output
+  | Meta (meta, l) when Theory.meta_equal meta Utils.meta_output
     -> (
     match l with
     | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
diff --git a/src/printers/pyrat.ml b/src/printers/pyrat.ml
index 59705230a5a71728a170c6a1c6ccb1dd280ff7e9..d7d772ab2ed0aaee9af027beafcf1254450a9fab 100644
--- a/src/printers/pyrat.ml
+++ b/src/printers/pyrat.ml
@@ -108,12 +108,12 @@ let rec print_tdecl info fmt task =
   print_tdecl info fmt task_prev;
   match task_decl.Theory.td_node with
   | Use _ | Clone _ -> ()
-  | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_input
+  | Meta (meta, l) when Theory.meta_equal meta Utils.meta_input
     -> (
     match l with
     | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "x%i" i)
     | _ -> assert false)
-  | Meta (meta, l) when Theory.meta_equal meta Native_nn_prover.meta_output
+  | Meta (meta, l) when Theory.meta_equal meta Utils.meta_output
    -> (
     match l with
     | [ MAls ls; MAint i ] -> Term.Hls.add info.variables ls (Fmt.str "y%i" i)
diff --git a/src/prover.ml b/src/prover.ml
index 415a6d70375c03e9e55067dc1be1d0446fa6c6d9..08bd149f32c21fd7b741fa8acf4468050e1855e6 100644
--- a/src/prover.ml
+++ b/src/prover.ml
@@ -20,9 +20,9 @@
 (*                                                                        *)
 (**************************************************************************)
 
-type t = Marabou | Pyrat | Saver
+type t = Marabou | Pyrat | Saver | CVC5
 
-let list_available () = [ Marabou; Pyrat; Saver ]
+let list_available () = [ Marabou; Pyrat; Saver; CVC5 ]
 
 let of_string prover =
   let prover = String.lowercase_ascii prover in
@@ -30,9 +30,11 @@ let of_string prover =
   | "marabou" -> Some Marabou
   | "pyrat" -> Some Pyrat
   | "saver" -> Some Saver
+  | "cvc5" -> Some CVC5
   | _ -> None
 
 let to_string = function
   | Marabou -> "Marabou"
   | Pyrat -> "PyRAT"
   | Saver -> "SAVer"
+  | CVC5 -> "CVC5"
diff --git a/src/prover.mli b/src/prover.mli
index 2d0dba9fb28d59d7dbe18bc28be0ec90c19f2755..cf4ff02423623408a861905c0da8da1c9aebf2ac 100644
--- a/src/prover.mli
+++ b/src/prover.mli
@@ -20,7 +20,7 @@
 (*                                                                        *)
 (**************************************************************************)
 
-type t = private Marabou | Pyrat | Saver
+type t = private Marabou | Pyrat | Saver | CVC5
 
 val list_available : unit -> t list
 val of_string : string -> t option
diff --git a/src/transformations/actual_net_apply.ml b/src/transformations/actual_net_apply.ml
new file mode 100644
index 0000000000000000000000000000000000000000..12418f37c6c649d6645a0418b8e36afc442c5b2f
--- /dev/null
+++ b/src/transformations/actual_net_apply.ml
@@ -0,0 +1,391 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+open Why3
+open Base
+open Utils
+module IR = Ir.Nier_cfg
+module G = Onnx.G
+
+exception UnsupportedOperator of string
+
+let vars = Term.Hvs.create 100
+let _lookup_vars = Term.Hvs.find_opt vars
+
+(* Import the proper theory according to the types of
+ * the variables. *)
+let theory_for ty_vars env =
+  if Ty.ty_equal ty_vars Ty.ty_real
+  then Env.read_theory env [ "real" ] "RealInfix"
+  else Env.read_theory env [ "ieee_float" ] "Float64"
+
+let create_var pfx id ty vars =
+  let preid = Ident.id_fresh (pfx ^ Int.to_string id) in
+  let vsymbol = Term.create_vsymbol preid ty in
+  Term.Hvs.add vars vsymbol id;
+  vsymbol
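+
+(* For example, [create_var "x" 3 Ty.ty_real vars] returns a fresh vsymbol
+ * printed as "x3" and records it in the [vars] table. *)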
+
+(* Conversion from tensor float data to a constant real term. *)
+let term_of_data d ty =
+  let d_s = Printf.sprintf "%h" d in
+  let rl =
+    if String.equal d_s "0x0p+0"
+    then Number.real_literal ~radix:16 ~neg:false ~int:"0" ~frac:"0" ~exp:None
+    else
+      (* Thanks to a similar function in Frama-C's WP plugin *)
+      let re_float =
+        Re__Pcre.regexp "-?0x([0-9a-f]+).([0-9a-f]+)?0*p?([+-]?[0-9a-f]+)?$"
+      in
+      let int, frac, exp =
+        match Re__Core.exec_opt re_float d_s with
+        | Some g -> (
+          ( Re__Core.Group.get g 1,
+            Re__Core.Group.get g 2,
+            try Re__Core.Group.get g 3 with Caml.Not_found -> "" ))
+        | None -> failwith "Wrong assertion in fp regex."
+      in
+      let exp = if String.equal exp "" then None else Some exp in
+      let is_neg = Float.( <= ) d 0. in
+      Number.real_literal ~radix:16 ~neg:is_neg ~int ~frac ~exp
+  in
+  (* TODO: safely check whether the real value can be
+   * expressed in float. Fail if not, with an informative
+   * error message (e.g.: invalid float representation, maybe
+   * try this one with rounding?). *)
+  Term.t_const
+    (Constant.real_const ~pow2:rl.rl_real.rv_pow2 ~pow5:rl.rl_real.rv_pow5
+       rl.rl_real.rv_sig)
+    ty
+
+(* Term describing the sum of two variables v1 and v2 with
+ * their proper types. *)
+let sum v1 v2 ty_vars env =
+  let theory = theory_for ty_vars env in
+  let plus_symbol =
+    if String.equal theory.th_name.id_string "RealInfix"
+    then Theory.ns_find_ls theory.Theory.th_export [ "infix +." ]
+    else Theory.ns_find_ls theory.Theory.th_export [ "infix .+" ]
+  in
+  Term.t_app_infer plus_symbol [ v1; v2 ]
+
+(* Term describing the multiplication of two variables v1 and
+ * v2 with their proper types. *)
+let mul v1 v2 ty_vars env =
+  let theory = theory_for ty_vars env in
+  let mul_symbol =
+    if String.equal theory.th_name.id_string "RealInfix"
+    then Theory.ns_find_ls theory.Theory.th_export [ "infix *." ]
+    else Theory.ns_find_ls theory.Theory.th_export [ "infix .*" ]
+  in
+  Term.t_app_infer mul_symbol [ v1; v2 ]
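+
+(* For instance, with [ty_vars = Ty.ty_real], [sum v1 v2 ty_vars env] looks
+ * up the "infix +." symbol of real.RealInfix and builds the term
+ * [v1 +. v2]; with Float64 variables it selects ".+" instead. *)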
+
+(* Bind variable v to term t in expression e. *)
+let bind v ~t ~e = Term.t_let_close v t e
+
+(* [id_on ~in_vars ~out_vars expr] creates a binding
+ * between two lists of variables [in_vars] and [out_vars].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
+ *)
+let id_on ~in_vars ~out_vars expr =
+  if List.length in_vars <> List.length out_vars
+  then
+    failwith
+      "Error, expecting the same number of variables before declaring equality"
+  else
+    let eq_term =
+      List.foldi ~init:expr in_vars ~f:(fun i e in_var ->
+          bind (List.nth_exn out_vars i) ~t:(Term.t_var in_var) ~e)
+    in
+    eq_term
+
+(* [relu ~in_vars ~out_vars env expr] creates terms applying
+ * the ReLU activation function
+ * between two lists of variables [in_vars] and
+ * [out_vars] in expression [expr].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
+ *)
+let relu ~in_vars ~out_vars env expr =
+  if List.length in_vars <> List.length out_vars
+  then
+    failwith
+      "Error, expecting the same number of variables before declaring equality"
+  else
+    let relu_s =
+      let nn = Pmodule.read_module env [ "caisar" ] "NN" in
+      Theory.(ns_find_ls nn.mod_theory.th_export [ "relu" ])
+    in
+    let eq_term =
+      List.foldi ~init:expr in_vars ~f:(fun i e in_var ->
+          let relu_on = Term.t_app_infer relu_s [ Term.t_var in_var ] in
+          bind (List.nth_exn out_vars i) ~t:relu_on ~e)
+    in
+    eq_term
+
+(* [eltw_sum ~in_vars ~out_vars data_node ty_vars env expr]
+ * creates terms defining the element-wise addition between
+ * two lists of variables [in_vars] and [out_vars], given
+ * a [data_node] holding the numerical
+ * values to add and an expression [expr].
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression.
+ * Assuming in_vars and out_vars have equal sizes, the
+ * resulting term declares
+ * let out_vars[i] = in_vars[i] + data[i] in ... *)
+let eltw_sum ~in_vars ~out_vars data_node ty_vars env expr =
+  let data =
+    match IR.Node.get_tensor data_node with Some t -> t | None -> assert false
+  in
+  let data_vars =
+    List.map ~f:(fun v -> term_of_data v ty_vars) (IR.Tensor.flatten data)
+  in
+  match
+    List.fold2 ~init:(expr, 0) in_vars data_vars ~f:(fun (e, i) in_var d ->
+        ( bind (List.nth_exn out_vars i)
+            ~t:(sum (Term.t_var in_var) d ty_vars env)
+            ~e,
+          i + 1 ))
+  with
+  | List.Or_unequal_lengths.Ok (e, _) -> e
+  | List.Or_unequal_lengths.Unequal_lengths ->
+    failwith "Error in element-wise sum: incoherent list length."
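+
+(* A small illustration (variable and data values assumed): with
+ * in_vars = [x0; x1], out_vars = [y0; y1] and a data tensor flattening to
+ * [1.0; 2.0], [eltw_sum] builds
+ *   let y1 = x1 + 2.0 in let y0 = x0 + 1.0 in expr
+ * where each constant goes through [term_of_data]. *)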
+
+(* Expected semantics:
+ * matrices A [n;m], B [m;p], C [n;p]
+ * C = AB
+ * C[i,j] = sum_k A[i,k] * B[k,j]
+ *
+ * [matmul in_vars out_vars data_node in_shape out_shape ty_vars env expr]
+ * creates terms defining the matrix multiplication between
+ * two lists of variables in_vars, out_vars and a data node.
+ * This function relies on the following assumptions:
+ * * in_vars represents the cells of matrix A (a_vars)
+ * * data stored in data_node is used to build the cells of matrix B (b_vars)
+ * * out_vars represents the cells of matrix C (c_vars)
+ * a_vars are the input variables,
+ * b_vars the data variables,
+ * c_vars the output variables:
+ * c_vars[i,j] = sum_k a_vars[i,k] * b_vars[k,j]
+ * The first variables of both lists are bound in expression
+ * [expr]; each subsequent binding is added on top of the
+ * resulting expression. *)
+let matmul ~in_vars ~out_vars data_node ~in_shape ~out_shape ty_vars env expr =
+  let data =
+    match IR.Node.get_tensor data_node with Some t -> t | None -> assert false
+  in
+  let data_shape = IR.Node.get_shape data_node in
+  let data_vars =
+    List.map ~f:(fun v -> term_of_data v ty_vars) (IR.Tensor.flatten data)
+  in
+  let rec matmul_terms (i, j, t) ~c_var ~a_var ~b_var ~c_shape ~a_shape ~b_shape
+      =
+    match c_var with
+    | [] -> (i, j, t)
+    | x :: y ->
+      (* c[i,j] = sum_k a[i,k]*b[k,j] *)
+      (* a_var_range: a whole row of a *)
+      (* b_var_range: a whole column of b *)
+      (* TODO: be sure that the common dimension is indeed
+       * b_shape[0] *)
+      let k_dim = Array.get b_shape 0 in
+      let a_var_range =
+        List.init
+          ~f:(fun k ->
+            let idx = [| 0; 0; i; k |] in
+            IR.Tensor.get_flatnd_idx ~idx ~sh:a_shape a_var)
+          k_dim
+      and b_var_range =
+        List.init
+          ~f:(fun k ->
+            let idx = [| k; j |] in
+            let data = IR.Tensor.get data idx in
+            term_of_data data ty_vars)
+          k_dim
+      in
+      let muls =
+        match
+          List.map2 a_var_range b_var_range ~f:(fun a b ->
+              mul (Term.t_var a) b ty_vars env)
+        with
+        | List.Or_unequal_lengths.Ok l -> l
+        | List.Or_unequal_lengths.Unequal_lengths ->
+          failwith "Wrong inferred common dimension"
+      in
+      let new_term =
+        List.fold ~init:(term_of_data 0.0 ty_vars)
+          ~f:(fun term mul -> sum term mul ty_vars env)
+          muls
+      in
+      let new_term = bind x ~t:new_term ~e:t in
+      matmul_terms (i, j, new_term) ~c_var:y ~a_var ~b_var ~c_shape ~a_shape
+        ~b_shape
+  in
+  let _, _, terms =
+    matmul_terms (0, 0, expr) ~c_var:out_vars ~b_var:data_vars ~a_var:in_vars
+      ~c_shape:out_shape ~b_shape:data_shape ~a_shape:in_shape
+  in
+  terms
+
+let terms_of_nier g ty_inputs env ~net_output_vars ~net_input_vars =
+  IR.out_cfg_graph g;
+  (* The current NIER generation builds the data nodes after the
+   * output variables, so we drop those since we will access
+   * them anyway later.
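+   * For instance (ordering assumed): if [G.vertex_list g] returns
+   * [b; w; y_out; add; matmul; x_in] by decreasing id, then [vs] below
+   * keeps [y_out; add; matmul; x_in], dropping the leading data vertices
+   * [b] and [w].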
*) + let vs = + let l = G.vertex_list g in + List.drop_while ~f:(fun n -> not (IR.Node.is_output_node n)) l + in + let _, expr = + (* Folding goes by decreasing id order, backward.*) + List.fold vs + ~init: + ( (net_output_vars, IR.Node.get_shape @@ List.nth_exn vs 0), + Term.t_tuple @@ List.map ~f:Term.t_var net_output_vars ) + ~f:(fun ((v_out_vars, out_shape), expr) v -> + let open IR in + let in_shape = + match Node.get_shape v with + | [||] -> G.infer_shape g v out_shape ~on_backward:true + | a -> a + in + let v_id = v.id in + let v_in_vars = + List.init + ~f:(fun i -> + create_var ("n_id_" ^ Int.to_string v_id ^ "_") i ty_inputs vars) + (List.length (Tensor.all_coords in_shape)) + in + let v_term = + (* TODO: axiomatize the resulting term using + * let d = Decl.create_prop_decl Paxiom ps t *) + match IR.Node.get_op v with + | Node.Matmul -> ( + match G.data_node_of v g with + | Some d_n -> + matmul ~in_vars:v_in_vars ~out_vars:v_out_vars d_n ~in_shape + ~out_shape ty_inputs env expr + | None -> failwith "Error, Matmul operator lacks a data node") + | Node.Add -> ( + match G.data_node_of v g with + | Some d_n -> + eltw_sum ~out_vars:v_out_vars ~in_vars:v_in_vars d_n ty_inputs env + expr + | None -> failwith "Error, Add operator lacks a data node") + | Node.NO_OP -> + (* If it is the input vertex, bind neural network + * input variables to the vertex output node. *) + if Node.is_input_node v + then + id_on ~out_vars:v_out_vars ~in_vars:net_input_vars expr + (* If it is the output vertex, the resulting + * term is the tuple of the output variables + * of the net; + * backpropagate those to the previous layer. *) + else if Node.is_output_node v + then + id_on ~out_vars:net_output_vars ~in_vars:v_in_vars + (Term.t_tuple @@ List.map ~f:Term.t_var net_output_vars) + else expr + | IR.Node.ReLu -> + relu ~out_vars:v_out_vars ~in_vars:v_in_vars env expr + | op -> + raise + (UnsupportedOperator + (Fmt.str "Operator %s is not implemented for actual_net_apply." + (IR.Node.str_op op))) + in + ((v_in_vars, out_shape), v_term)) + in + expr + +(* Create logic symbols for input variables and replace + * nnet_apply by control flow terms. + * Assumption for it to work: + * let (out1,out2,...) = net_apply (arg1,arg2,...) + * *) +let actual_nn_flow env = + let rec substitute_net_apply meta (term : Term.term) = + match term.t_node with + | Term.Tcase (t, [ tb ]) -> ( + let may_itg = + (* maybe input type graph *) + match t.t_node with + | Term.Tapp (ls, input_args) -> ( + match Language.lookup_loaded_nets ls with + | None -> None + | Some nn -> + meta := nn.filename :: !meta; + let g = + match nn.nier with + | Some g -> g + | None -> + failwith "Error, call this transform only on an ONNX NN." + in + let ty_inputs = nn.ty_data in + let net_input_vars = + List.map input_args ~f:(fun x -> + (*net_apply should always be followed by a + * non-empty list of arguments*) + match x.Term.t_node with Tvar ts -> ts | _ -> assert false) + in + Some (net_input_vars, ty_inputs, g)) + | _ -> None + in + let p, _ = Term.t_open_branch tb in + let may_o = + (* maybe output *) + match p.pat_node with + | Term.Papp (_, output_args) -> + Some + (List.map output_args ~f:(fun x -> + (* let (args1,args2,...) 
is the only
+                 * supported application for now *)
+                match x.Term.pat_node with Pvar ts -> ts | _ -> assert false))
+      | _ -> None
+    in
+    match (may_itg, may_o) with
+    | Some (net_input_vars, ty_inputs, g), Some net_output_vars ->
+      let cfg_term =
+        terms_of_nier g ty_inputs env ~net_output_vars ~net_input_vars
+      in
+      let t = Term.t_case cfg_term [ tb ] in
+      t
+    | _ -> Term.t_map (substitute_net_apply meta) term)
+    | _ -> Term.t_map (substitute_net_apply meta) term
+  in
+  Trans.fold
+    (fun task_hd task ->
+      match task_hd.task_decl.td_node with
+      | Use _ | Clone _ | Meta _ -> Task.add_tdecl task task_hd.task_decl
+      | Decl decl ->
+        let meta = ref [] in
+        let decl = Decl.decl_map (substitute_net_apply meta) decl in
+        let task =
+          List.fold !meta ~init:task ~f:(fun task s ->
+              Task.add_meta task meta_nn_filename [ MAstr s ])
+        in
+        Task.add_decl task decl)
+    None
+
+let actual_net_apply env = Trans.seq [ actual_nn_flow env ]
+
+let init () =
+  Trans.register_env_transform
+    ~desc:
+      "Transformation for provers that do not support direct loading\n\
+      \ of neural networks. Replaces the function declaration\n\
+      \ net_apply by the actual control flow of the neural\n\
+      \ network." "actual_net_apply" actual_net_apply
diff --git a/src/transformations/actual_net_apply.mli b/src/transformations/actual_net_apply.mli
new file mode 100644
index 0000000000000000000000000000000000000000..e45ca3e9317f1bd26d2c74f04dc67cb3e34baf60
--- /dev/null
+++ b/src/transformations/actual_net_apply.mli
@@ -0,0 +1,22 @@
+(**************************************************************************)
+(*                                                                        *)
+(*  This file is part of CAISAR.                                          *)
+(*                                                                        *)
+(**************************************************************************)
+
+(** This module provides tools to convert a valid NIER into WhyML terms.
+
+    NIER encapsulates parameters in tensor form and a computation graph with
+    various operations. The WhyML language supports multidimensional arrays
+    as well.
+
+    This module provides tools to properly translate NIER data into a list of
+    WhyML terms describing the control flow of the neural network. Variables
+    are stored inside an environment, their shape being either provided by
+    the NIER or inferred from the expected result of ONNX operations. *)
+
+open Why3
+
+val actual_net_apply : Env.env -> Task.task Trans.trans
+
+val init : unit -> unit
+(** Register the transformation. *)
diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml
index 59a0b962c6e5211e7ae3bd6c7d84c7c1b9ce2364..2ec6b2b3b147fb401610277805f33fee032b48fb 100644
--- a/src/transformations/native_nn_prover.ml
+++ b/src/transformations/native_nn_prover.ml
@@ -22,44 +22,7 @@
 
 open Why3
 open Base
-
-let meta_input =
-  Theory.(
-    register_meta "caisar_input"
-      ~desc:"Indicates the input position in the neural network"
-      [ MTlsymbol; MTint ])
-
-let meta_output =
-  Theory.(
-    register_meta "caisar_output"
-      ~desc:"Indicates the output position in the neural network"
-      [ MTlsymbol; MTint ])
-
-let meta_nn_filename =
-  Theory.(
-    register_meta_excl "caisar_nnet_or_onnx"
-      ~desc:"Indicates the filename of the network" [ MTstring ])
-
-(* Retrieve the (input) variables appearing, as arguments, after an 'net_apply'
-   symbol.
*) -let get_input_variables = - let rec aux acc (term : Term.term) = - match term.t_node with - | Term.Tapp (ls, args) -> ( - match Language.lookup_loaded_nets ls with - | None -> acc - | Some _ -> - let add i acc = function - | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc - | arg -> - invalid_arg - (Fmt.str "No direct variable in application: %a" Pretty.print_term - arg) - in - List.foldi ~init:acc ~f:add args) - | _ -> Term.t_fold aux acc term - in - Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty +open Utils (* Create logic symbols for output variables and simplify the formula. *) (* TODO: [Reduction_engine] is probably an overkill and should be replaced. *) diff --git a/src/transformations/native_nn_prover.mli b/src/transformations/native_nn_prover.mli index e211723bdd3b2faf9f1ddefd54790b258b7e6291..694097b3b0dcc9bef16af63c6f16994dd897e63e 100644 --- a/src/transformations/native_nn_prover.mli +++ b/src/transformations/native_nn_prover.mli @@ -20,16 +20,5 @@ (* *) (**************************************************************************) -open Why3 - val init : unit -> unit (** Register the transformation. *) - -val meta_input : Theory.meta -(** Indicate the input position. *) - -val meta_output : Theory.meta -(** Indicate the output position. *) - -val meta_nn_filename : Theory.meta -(** The filename of the nnet or onnx model. *) diff --git a/src/transformations/utils.ml b/src/transformations/utils.ml new file mode 100644 index 0000000000000000000000000000000000000000..562b423c11908377729c5e5b81f3597b77782d09 --- /dev/null +++ b/src/transformations/utils.ml @@ -0,0 +1,46 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. *) +(* *) +(**************************************************************************) + +open Why3 +open Base + +(* Retrieve the (input) variables appearing, as arguments, after an 'nnet_apply' + symbol. *) +let get_input_variables = + let rec aux acc (term : Term.term) = + match term.t_node with + | Term.Tapp (ls, args) -> ( + match Language.lookup_loaded_nets ls with + | None -> acc + | Some _ -> + let add i acc = function + | { Term.t_node = Tapp (vs, []); _ } -> Term.Mls.add vs i acc + | arg -> + invalid_arg + (Fmt.str "No direct variable in application: %a" Pretty.print_term + arg) + in + List.foldi ~init:acc ~f:add args) + | _ -> Term.t_fold aux acc term + in + Trans.fold_decl (fun decl acc -> Decl.decl_fold aux acc decl) Term.Mls.empty + +let meta_input = + Theory.( + register_meta "caisar_input" + ~desc:"Indicates the input position in the neural network" + [ MTlsymbol; MTint ]) + +let meta_output = + Theory.( + register_meta "caisar_output" + ~desc:"Indicates the output position in the neural network" + [ MTlsymbol; MTint ]) + +let meta_nn_filename = + Theory.( + register_meta_excl "caisar_nnet_or_onnx" + ~desc:"Indicates the filename of the network" [ MTstring ]) diff --git a/src/transformations/utils.mli b/src/transformations/utils.mli new file mode 100644 index 0000000000000000000000000000000000000000..ecfa4a031606614555a34788fe57e95cc7083033 --- /dev/null +++ b/src/transformations/utils.mli @@ -0,0 +1,19 @@ +(**************************************************************************) +(* *) +(* This file is part of CAISAR. 
*) +(* *) +(**************************************************************************) + +open Why3 +open Base + +val get_input_variables : int Term.Mls.t Trans.trans + +val meta_input : Theory.meta +(** Indicate the input position. *) + +val meta_output : Theory.meta +(** Indicate the output position. *) + +val meta_nn_filename : Theory.meta +(** The filename of the nnet or onnx model. *) diff --git a/src/verification.ml b/src/verification.ml index 8b7fef1ae22dbbd24d0a3c3ddc29628602982444..d0e5f1a54abd2971fecb2472cf4ec2d0f54efa2d 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -93,18 +93,19 @@ let answer_saver limit config env config_prover dataset_csv task = let additional_info = Fmt.str "(%d/%d)" answer.nb_proved answer.nb_total in (prover_answer, Some additional_info) -let answer_generic limit config prover config_prover driver task = +let answer_generic limit config prover config_prover driver task env = let task_prepared = Driver.prepare_task driver task in let tasks = - (* We make [tasks] as a list (ie, conjunction) of disjunctions. *) match prover with | Prover.Marabou -> Trans.apply Split.split_all task_prepared | Pyrat -> Trans.apply Split.split_premises task_prepared + | Prover.CVC5 -> + [ Trans.apply (Actual_net_apply.actual_net_apply env) task_prepared ] | _ -> [ task_prepared ] in let command = Whyconf.get_complete_command ~with_steps:false config_prover in let nn_file = - match Task.on_meta_excl Native_nn_prover.meta_nn_filename task_prepared with + match Task.on_meta_excl Utils.meta_nn_filename task_prepared with | Some [ MAstr nn_file ] -> nn_file | Some _ -> assert false (* By construction of the meta. *) | None -> invalid_arg "No neural network model found in task" @@ -128,8 +129,8 @@ let call_prover ~limit config env prover config_prover driver dataset_csv task = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset_csv task - | Marabou | Pyrat -> - answer_generic limit config prover config_prover driver task + | Marabou | Pyrat | CVC5 -> + answer_generic limit config prover config_prover driver task env in Logs.app (fun m -> m "@[Goal %a:@ %a%a@]" Pretty.print_pr (Task.task_goal task) diff --git a/stdlib/caisar.mlw b/stdlib/caisar.mlw index 1eb5c6da5c7800551b058762ad9acb7118a9397e..764ec3b5db9c6592e7e6d48ba5a6e86059739574 100644 --- a/stdlib/caisar.mlw +++ b/stdlib/caisar.mlw @@ -21,8 +21,10 @@ (**************************************************************************) theory NN + (** Module defining commonly-met operations in neural networks. 
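+      For instance, [relu x] evaluates to [x] when [x .> (0.0:t)] holds and
+      to [(0.0:t)] otherwise, per the definition below.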
+  *)
   use ieee_float.Float64
   type input_type = t
+  function relu (a: t) : t = if a .> (0.0:t) then a else (0.0:t)
 end
 
 theory Model
diff --git a/tests/autodetect.t b/tests/autodetect.t
index f184ddd6e2a9b06703e6b5b08f590555805411c1..153a29fd6231dcb9591e4fb4d82580b08b0db3eb 100644
--- a/tests/autodetect.t
+++ b/tests/autodetect.t
@@ -4,7 +4,7 @@ Test autodetect
   >   echo "2.4.0"
   > EOF
 
-  $ chmod u+x bin/alt-ergo bin/pyrat.py bin/Marabou bin/saver
+  $ chmod u+x bin/alt-ergo bin/pyrat.py bin/Marabou bin/saver bin/cvc5
 
   $ bin/alt-ergo
   2.4.0
@@ -18,11 +18,15 @@ Test autodetect
   $ bin/saver --version
   v1.0
 
+  $ bin/cvc5 --version
+  This is cvc5 version 1.0.2 [git tag 1.0.2 branch HEAD]
+
   $ PATH=$(pwd)/bin:$PATH
 
   $ caisar config -d -vv 2>&1 | ./filter_tmpdir.sh
   [caisar][DEBUG] Command `config'
   [caisar][DEBUG] Automatic detection
+  <autodetect>Run: ($TESTCASE_ROOT/bin/cvc5 --version 2>&1 | head -1) > $TMPFILE 2>&1
   <autodetect>Run: ($TESTCASE_ROOT/bin/pyrat.py --version) > $TMPFILE 2>&1
   <autodetect>Run: ($TESTCASE_ROOT/bin/saver --version 2>&1 | head -n1 && (which saver > /dev/null 2>&1)) > $TMPFILE 2>&1
   <autodetect>Run: ($TESTCASE_ROOT/bin/alt-ergo --version) > $TMPFILE 2>&1
@@ -30,11 +34,13 @@ Test autodetect
   <autodetect>0 prover(s) added
   <autodetect>Generating strategies:
   <autodetect>Found prover Alt-Ergo version 2.4.0, OK.
+  <autodetect>Found prover CVC5 version 1.0.2, OK.
   <autodetect>Found prover Marabou version 1.0.+, OK.
   <autodetect>Found prover PyRAT version 1.1, OK.
   <autodetect>Found prover SAVer version v1.0, OK.
-  <autodetect>4 prover(s) added
+  <autodetect>5 prover(s) added
 [caisar] Alt-Ergo 2.4.0
+         CVC5 1.0.2
          Marabou 1.0.+
          PyRAT 1.1
          SAVer v1.0
diff --git a/tests/bin/cvc5 b/tests/bin/cvc5
new file mode 100644
index 0000000000000000000000000000000000000000..20cc33fa9411e72238aabfd80131e17df4b6635e
--- /dev/null
+++ b/tests/bin/cvc5
@@ -0,0 +1,15 @@
+#!/bin/sh -e
+
+
+case $1 in
+  --version)
+    echo "This is cvc5 version 1.0.2 [git tag 1.0.2 branch HEAD]"
+    ;;
+  *)
+    echo "PWD: $(pwd)"
+    echo "NN: $1"
+    test -e $1 || (echo "Cannot find the NN file" && exit 1)
+    echo "Goal:"
+    cat $2
+    echo "Unknown"
+esac
diff --git a/tests/dune b/tests/dune
index 90d0754de69c588e1b69167e33b396cfa398fc41..9bbc0b29d75c7181c8c349a5d72408663f65ad1c 100644
--- a/tests/dune
+++ b/tests/dune
@@ -7,5 +7,6 @@
    bin/pyrat.py
    bin/Marabou
    bin/saver
+   bin/cvc5
    filter_tmpdir.sh)
  (package caisar))
diff --git a/tests/simple_cvc5.t b/tests/simple_cvc5.t
new file mode 100644
index 0000000000000000000000000000000000000000..5ae9b39395a3f2ed67efa8e0af20927d8fcba54e
--- /dev/null
+++ b/tests/simple_cvc5.t
@@ -0,0 +1,35 @@
+Test verify
+  $ cat - > bin/alt-ergo << EOF
+  > #!/bin/sh
+  > echo "2.4.0"
+  > EOF
+
+  $ chmod u+x bin/alt-ergo bin/pyrat.py bin/Marabou bin/saver bin/cvc5
+
+  $ bin/alt-ergo
+  2.4.0
+
+  $ bin/pyrat.py --version
+  PyRAT 1.1
+
+  $ bin/Marabou --version
+  1.0.+
+
+  $ bin/saver --version
+  v1.0
+
+  $ PATH=$(pwd)/bin:$PATH
+
+  $ caisar verify -L . --format whyml --prover=CVC5 - 2>&1 <<EOF | sed 's/\/tmp\/[a-z0-9_./]*/$TMPFILE/'
+  > theory T
+  >   use TestNetworkONNX.NNasTuple
+  >   use ieee_float.Float64
+  >   use caisar.NN
+  >
+  >   goal G: forall x1 x2 x3.
+  >     (0.0:t) .< x1 .< (0.5:t) ->
+  >     let (y1,y2) = net_apply x1 x2 x3 in
+  >     (0.0:t) .< y1 .< (0.5:t) /\ (0.0:t) .< y2 .< (1.0:t)
+  > end
+  > EOF
+  [caisar] Goal G: High failure