diff --git a/lib/ir/nier_simple.ml b/lib/ir/nier_simple.ml index f793677d5dd52787587052ac3ac8c5dbb672871c..eba5bd2c92826df32a2185105d756f92d1cd0d60 100644 --- a/lib/ir/nier_simple.ml +++ b/lib/ir/nier_simple.ml @@ -457,6 +457,24 @@ module Node = struct let result = create (Matmul { input1; input2 }) in reshape (Shape.of_array [| 1 |]) result + let div_float ?(encode = true) input f = + if encode + then + let f = Float.one /. f in + mul_float input f + else + let input1 = reshape (Shape.of_array [| 1; 1 |]) input in + let f = Array.create ~len:1 f in + let f = + GenTensor.Float + (Tensor.of_array1 + (Shape.of_array [| Array.length f; 1 |]) + (Bigarray.Array1.of_array Float64 C_layout f)) + in + let input2 = create (Constant { data = f }) in + let result = create (Div { input1; input2 }) in + reshape (Shape.of_array [| 1 |]) result + let concat_0 = function | [ n ] -> n | [] -> failwith "empty concat" diff --git a/lib/ir/nier_simple.mli b/lib/ir/nier_simple.mli index 494f28cad7fce3a404d30f403e5d1c9e0d941eea..c928e3090854ac94dfdae4804089e17ec752b7ba 100644 --- a/lib/ir/nier_simple.mli +++ b/lib/ir/nier_simple.mli @@ -151,10 +151,17 @@ module Node : sig val gather_int : ?encode:bool -> t -> int -> t (** create a node by selection at a given index. *) - (* Implemented via a [Matmul] if [encode]. TODO: [encode] should be not be a - parameter, rather depend on prover. *) + (* Implemented via a [Matmul] if [encode] (true by default). + + TODO: [encode] should be not be a parameter, rather depend on prover. *) val mul_float : t -> float -> t + (* Implemented via a [Matmul]. *) + + val div_float : ?encode:bool -> t -> float -> t + (* Implemented via a [Matmul] if [encode] (true by default). + + TODO: [encode] should be not be a parameter, rather depend on prover. *) val constant_int_array : int array -> t (** create a node for a constant array *) diff --git a/src/transformations/native_nn_prover.ml b/src/transformations/native_nn_prover.ml index 844ee714fc4af220146fcdf1557be8da71f76c28..a119a2fd665f501aec8e5b90d5c7820619f23b88 100644 --- a/src/transformations/native_nn_prover.ml +++ b/src/transformations/native_nn_prover.ml @@ -130,8 +130,8 @@ let create_new_nn env input_vars outputs : string = | Tapp (ls, [ _; a; b ]) when Why3.Term.ls_equal ls th_f64.div -> ( match b.t_node with | Tconst (Why3.Constant.ConstReal r) -> - let f = Float.one /. Utils.float_of_real_constant r in - Ir.Nier_simple.Node.mul_float (convert_term a) f + let f = Utils.float_of_real_constant r in + Ir.Nier_simple.Node.div_float (convert_term a) f | _ -> IR.Node.create (Div { input1 = convert_term a; input2 = convert_term b }))