From b0002177e57a63f2c6353b47bcaf5a53e50f9603 Mon Sep 17 00:00:00 2001 From: Julien Girard <julien.girard2@cea.fr> Date: Thu, 30 Nov 2023 11:03:57 +0100 Subject: [PATCH] Basic structure for ONNX output through command line. --- lib/onnx/onnx.ml | 45 ++++++++++++++++++++++++++++++++++++++++++++ lib/onnx/onnx.mli | 5 ++++- src/language.ml | 3 ++- src/main.ml | 15 ++++++++++----- src/verification.ml | 42 ++++++++++++++++++++++++++++++++++------- src/verification.mli | 1 + 6 files changed, 97 insertions(+), 14 deletions(-) diff --git a/lib/onnx/onnx.ml b/lib/onnx/onnx.ml index 355aa43..a3b631e 100644 --- a/lib/onnx/onnx.ml +++ b/lib/onnx/onnx.ml @@ -504,6 +504,46 @@ let nier_of_onnx_protoc (model : Oprotom.t) = | Some g -> produce_cfg g | None -> raise (ParseError "No graph in ONNX input file found") +let nier_to_onnx_protoc nier = + (* TODO: write a simple ONNX model from a dummy NIER *) + let vertices = G.vertex_list nier in + let protocs = + let vertex_to_protoc v = + let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in + let name = NCFG.Node.get_name v in + let domain = "" in + let input, output = + (NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v) + in + Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain + ~attribute:[] ~doc_string:"" () + in + List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices + in + let protog = + Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[] + ~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[] + ~quantization_annotation:[] () + in + let protom = + Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[] + ~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:"" + ~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[] + ~training_info:[] ~functions:[] () + in + let writer = Oprotom.to_proto protom in + Ocaml_protoc_plugin.Writer.contents writer + +let write_nier_to_onnx _nier out_channel = + let nier = G.init_cfg in + let n = + Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP + ~op_p:None ~pred:[] ~succ:[] ~tensor:None + in + G.add_vertex nier n; + let onnx = nier_to_onnx_protoc nier in + Stdio.Out_channel.output_string out_channel onnx + let parse_in_channel in_channel = let open Result in try @@ -528,3 +568,8 @@ let parse filename = Fun.protect ~finally:(fun () -> Stdlib.close_in in_channel) (fun () -> parse_in_channel in_channel) + +let write nier filename = + let out_chan = Stdlib.open_out filename in + write_nier_to_onnx nier out_chan; + Stdlib.close_out out_chan diff --git a/lib/onnx/onnx.mli b/lib/onnx/onnx.mli index e8d102c..0011ac1 100644 --- a/lib/onnx/onnx.mli +++ b/lib/onnx/onnx.mli @@ -30,4 +30,7 @@ type t = private { (** ONNX model metadata and intermediate representation. *) val parse : string -> (t, string) Result.t -(** Parse an ONNX file. *) +(** Parse an ONNX file into a NIER. *) + +val write : G.t -> string -> unit +(** Write a NIER into an ONNX file. *) diff --git a/src/language.ml b/src/language.ml index 6d9a186..09c4813 100644 --- a/src/language.ml +++ b/src/language.ml @@ -244,7 +244,8 @@ let create_nn_onnx = Logs.warn (fun m -> m "Cannot build network intermediate representation:@ %s" msg); None - | Ok nier -> Some nier + | Ok nier -> + Some nier in { nn_nb_inputs = n_inputs; diff --git a/src/main.ml b/src/main.ml index 73bcadc..e1b74da 100644 --- a/src/main.ml +++ b/src/main.ml @@ -123,14 +123,14 @@ let log_theory_answer = additional_info))) let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern - ?def_constants ?theories ?goals files = + ?def_constants ?theories ?goals ?onnx_out_file files = let memlimit = Option.map memlimit ~f:memlimit_of_string in let timelimit = Option.map timelimit ~f:timelimit_of_string in let theory_answers = List.map files ~f: (Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset - prover ?prover_altern ?def_constants ?theories ?goals) + prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_file) in List.iter theory_answers ~f:log_theory_answer; theory_answers @@ -256,6 +256,10 @@ let verify_cmd = let doc = "Dataset $(docv) (CSV format only)." in Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE") in + let onnx_out_file = + let doc = "Path where to save the ONNX outputs from NIER." in + Arg.(value & opt (some string) None & info [ "onnx-out-file" ] ~doc) + in let define_constants = let doc = "Define a declared constant with the given value." in Arg.( @@ -295,16 +299,17 @@ let verify_cmd = in let verify_term = let verify format loadpath memlimit timelimit prover prover_altern dataset - def_constants theories goals files () = + def_constants theories goals onnx_out_file files () = ignore (verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover - ?prover_altern ~def_constants ~theories ~goals files) + ?prover_altern ~def_constants ~theories ~goals ?onnx_out_file files) in Term.( const (fun _ -> exec_cmd cmdname) $ setup_logs $ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover - $ prover_altern $ dataset $ define_constants $ theories $ goals $ files)) + $ prover_altern $ dataset $ define_constants $ theories $ goals + $ onnx_out_file $ files)) in Cmd.v info verify_term diff --git a/src/verification.ml b/src/verification.ml index e1ba5bf..e59c1bc 100644 --- a/src/verification.ml +++ b/src/verification.ml @@ -243,8 +243,8 @@ let answer_dataset limit config env prover config_prover driver dataset task = in (prover_answer, additional_info) -let answer_generic limit config prover config_prover driver ~proof_strategy task - = +let answer_generic limit config prover config_prover driver ~proof_strategy + onnx_out_file task = let tasks = proof_strategy task in let answers = List.concat_map tasks ~f:(fun task -> @@ -255,6 +255,33 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task | Some _ -> assert false (* By construction of the meta. *) | None -> invalid_arg "No neural network model found in task" in + let () = + let save_onnx_in ls f = + match Language.lookup_nn ls with + | Some { nn_nier = Some g; _ } -> ( + try + Onnx.write g f; + Logs.info (fun m -> m "@[Wrote ONNX file at '%s'@]" f) + with Sys_error msg -> + Logs.err (fun m -> + m "@[System error: tried to write ONNX file a '%s', got '%s'@]" + f msg)) + | None -> () + | _ -> () + in + match onnx_out_file with + | Some f -> + Task.task_iter + (fun decl -> + match decl.td_node with + | Use _ | Clone _ | Meta _ -> () + | Decl decl -> ( + match decl.d_node with + | Dparam ls -> save_onnx_in ls f + | _ -> ())) + task + | _ -> () + in let tasks = (* Turn [task] into a list (ie, conjunction) of disjunctions of tasks. *) @@ -271,7 +298,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task (prover_answer, additional_info) let call_prover ~cwd ~limit config env prover config_prover driver ?dataset - def_constants task = + def_constants onnx_out_file task = let prover_answer, additional_info = match prover with | Prover.Saver -> answer_saver limit config env config_prover dataset task @@ -283,12 +310,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset let task = Interpretation.interpret_task ~cwd env ~def_constants task in let proof_strategy = Proof_strategy.apply_native_nn_prover in answer_generic limit config prover config_prover driver ~proof_strategy - task + onnx_out_file task | CVC5 -> let task = Interpretation.interpret_task ~cwd env ~def_constants task in let proof_strategy = Proof_strategy.apply_classic_prover env in + (* Not outputting ONNX for CVC5 as it does not handle such format*) answer_generic limit config prover config_prover driver ~proof_strategy - task + None task in let id = Task.task_goal task in { id; prover_answer; additional_info } @@ -328,7 +356,7 @@ let tasks_of_theory ~theories ~goals theory = List.exists goals_of_theory ~f:(String.equal task_goal_id)) let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern - ?(def_constants = []) ?(theories = []) ?(goals = []) file = + ?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_file file = let debug = Logging.(is_debug_level src_prover_call) in (if debug then Debug.(set_flag (lookup_flag "call_prover"))); let env, config = create_env loadpath in @@ -391,6 +419,6 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern List.map ~f: (call_prover ~cwd ~limit main env prover config_prover driver ?dataset - def_constants) + def_constants onnx_out_file) tasks) mstr_theory diff --git a/src/verification.mli b/src/verification.mli index 92bbcd3..87c9043 100644 --- a/src/verification.mli +++ b/src/verification.mli @@ -55,6 +55,7 @@ val verify : ?def_constants:(string * string) list -> ?theories:string list -> ?goals:(string * string list) list -> + ?onnx_out_file: string -> File.t -> verification_result (** Starts a verification of the given [file] with the provided [prover]. -- GitLab