Skip to content
Snippets Groups Projects
Commit b0002177 authored by Julien Girard-Satabin's avatar Julien Girard-Satabin Committed by Michele Alberti
Browse files

Basic structure for ONNX output through command line.

parent 8c2c18fa
No related branches found
No related tags found
No related merge requests found
......@@ -504,6 +504,46 @@ let nier_of_onnx_protoc (model : Oprotom.t) =
| Some g -> produce_cfg g
| None -> raise (ParseError "No graph in ONNX input file found")
let nier_to_onnx_protoc nier =
(* TODO: write a simple ONNX model from a dummy NIER *)
let vertices = G.vertex_list nier in
let protocs =
let vertex_to_protoc v =
let op_type = NCFG.Node.str_op (NCFG.Node.get_op v) in
let name = NCFG.Node.get_name v in
let domain = "" in
let input, output =
(NCFG.Node.get_pred_list v, NCFG.Node.get_succ_list v)
in
Oproto.Onnx.NodeProto.make ~input ~output ~name ~op_type ~domain
~attribute:[] ~doc_string:"" ()
in
List.fold ~init:[] ~f:(fun acc v -> vertex_to_protoc v :: acc) vertices
in
let protog =
Oproto.Onnx.GraphProto.make ~name:"" ~node:protocs ~initializer':[]
~sparse_initializer:[] ~doc_string:"" ~input:[] ~output:[] ~value_info:[]
~quantization_annotation:[] ()
in
let protom =
Oproto.Onnx.ModelProto.make ~ir_version:13 ~opset_import:[]
~producer_name:"CAISAR" ~producer_version:"1.0" ~domain:""
~model_version:(-1) ~doc_string:"" ~graph:protog ~metadata_props:[]
~training_info:[] ~functions:[] ()
in
let writer = Oprotom.to_proto protom in
Ocaml_protoc_plugin.Writer.contents writer
let write_nier_to_onnx _nier out_channel =
let nier = G.init_cfg in
let n =
Ir.Nier_cfg.Node.create ~id:0 ~name:None ~sh:[||] ~op:Ir.Nier_cfg.Node.NO_OP
~op_p:None ~pred:[] ~succ:[] ~tensor:None
in
G.add_vertex nier n;
let onnx = nier_to_onnx_protoc nier in
Stdio.Out_channel.output_string out_channel onnx
let parse_in_channel in_channel =
let open Result in
try
......@@ -528,3 +568,8 @@ let parse filename =
Fun.protect
~finally:(fun () -> Stdlib.close_in in_channel)
(fun () -> parse_in_channel in_channel)
let write nier filename =
let out_chan = Stdlib.open_out filename in
write_nier_to_onnx nier out_chan;
Stdlib.close_out out_chan
......@@ -30,4 +30,7 @@ type t = private {
(** ONNX model metadata and intermediate representation. *)
val parse : string -> (t, string) Result.t
(** Parse an ONNX file. *)
(** Parse an ONNX file into a NIER. *)
val write : G.t -> string -> unit
(** Write a NIER into an ONNX file. *)
......@@ -244,7 +244,8 @@ let create_nn_onnx =
Logs.warn (fun m ->
m "Cannot build network intermediate representation:@ %s" msg);
None
| Ok nier -> Some nier
| Ok nier ->
Some nier
in
{
nn_nb_inputs = n_inputs;
......
......@@ -123,14 +123,14 @@ let log_theory_answer =
additional_info)))
let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
?def_constants ?theories ?goals files =
?def_constants ?theories ?goals ?onnx_out_file files =
let memlimit = Option.map memlimit ~f:memlimit_of_string in
let timelimit = Option.map timelimit ~f:timelimit_of_string in
let theory_answers =
List.map files
~f:
(Verification.verify ?format ~loadpath ?memlimit ?timelimit ?dataset
prover ?prover_altern ?def_constants ?theories ?goals)
prover ?prover_altern ?def_constants ?theories ?goals ?onnx_out_file)
in
List.iter theory_answers ~f:log_theory_answer;
theory_answers
......@@ -256,6 +256,10 @@ let verify_cmd =
let doc = "Dataset $(docv) (CSV format only)." in
Arg.(value & opt (some file) None & info [ "dataset" ] ~doc ~docv:"FILE")
in
let onnx_out_file =
let doc = "Path where to save the ONNX outputs from NIER." in
Arg.(value & opt (some string) None & info [ "onnx-out-file" ] ~doc)
in
let define_constants =
let doc = "Define a declared constant with the given value." in
Arg.(
......@@ -295,16 +299,17 @@ let verify_cmd =
in
let verify_term =
let verify format loadpath memlimit timelimit prover prover_altern dataset
def_constants theories goals files () =
def_constants theories goals onnx_out_file files () =
ignore
(verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover
?prover_altern ~def_constants ~theories ~goals files)
?prover_altern ~def_constants ~theories ~goals ?onnx_out_file files)
in
Term.(
const (fun _ -> exec_cmd cmdname)
$ setup_logs
$ (const verify $ format $ loadpath $ memlimit $ timelimit $ prover
$ prover_altern $ dataset $ define_constants $ theories $ goals $ files))
$ prover_altern $ dataset $ define_constants $ theories $ goals
$ onnx_out_file $ files))
in
Cmd.v info verify_term
......
......@@ -243,8 +243,8 @@ let answer_dataset limit config env prover config_prover driver dataset task =
in
(prover_answer, additional_info)
let answer_generic limit config prover config_prover driver ~proof_strategy task
=
let answer_generic limit config prover config_prover driver ~proof_strategy
onnx_out_file task =
let tasks = proof_strategy task in
let answers =
List.concat_map tasks ~f:(fun task ->
......@@ -255,6 +255,33 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
| Some _ -> assert false (* By construction of the meta. *)
| None -> invalid_arg "No neural network model found in task"
in
let () =
let save_onnx_in ls f =
match Language.lookup_nn ls with
| Some { nn_nier = Some g; _ } -> (
try
Onnx.write g f;
Logs.info (fun m -> m "@[Wrote ONNX file at '%s'@]" f)
with Sys_error msg ->
Logs.err (fun m ->
m "@[System error: tried to write ONNX file a '%s', got '%s'@]"
f msg))
| None -> ()
| _ -> ()
in
match onnx_out_file with
| Some f ->
Task.task_iter
(fun decl ->
match decl.td_node with
| Use _ | Clone _ | Meta _ -> ()
| Decl decl -> (
match decl.d_node with
| Dparam ls -> save_onnx_in ls f
| _ -> ()))
task
| _ -> ()
in
let tasks =
(* Turn [task] into a list (ie, conjunction) of disjunctions of
tasks. *)
......@@ -271,7 +298,7 @@ let answer_generic limit config prover config_prover driver ~proof_strategy task
(prover_answer, additional_info)
let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
def_constants task =
def_constants onnx_out_file task =
let prover_answer, additional_info =
match prover with
| Prover.Saver -> answer_saver limit config env config_prover dataset task
......@@ -283,12 +310,13 @@ let call_prover ~cwd ~limit config env prover config_prover driver ?dataset
let task = Interpretation.interpret_task ~cwd env ~def_constants task in
let proof_strategy = Proof_strategy.apply_native_nn_prover in
answer_generic limit config prover config_prover driver ~proof_strategy
task
onnx_out_file task
| CVC5 ->
let task = Interpretation.interpret_task ~cwd env ~def_constants task in
let proof_strategy = Proof_strategy.apply_classic_prover env in
(* Not outputting ONNX for CVC5 as it does not handle such format*)
answer_generic limit config prover config_prover driver ~proof_strategy
task
None task
in
let id = Task.task_goal task in
{ id; prover_answer; additional_info }
......@@ -328,7 +356,7 @@ let tasks_of_theory ~theories ~goals theory =
List.exists goals_of_theory ~f:(String.equal task_goal_id))
let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
?(def_constants = []) ?(theories = []) ?(goals = []) file =
?(def_constants = []) ?(theories = []) ?(goals = []) ?onnx_out_file file =
let debug = Logging.(is_debug_level src_prover_call) in
(if debug then Debug.(set_flag (lookup_flag "call_prover")));
let env, config = create_env loadpath in
......@@ -391,6 +419,6 @@ let verify ?format ~loadpath ?memlimit ?timelimit ?dataset prover ?prover_altern
List.map
~f:
(call_prover ~cwd ~limit main env prover config_prover driver ?dataset
def_constants)
def_constants onnx_out_file)
tasks)
mstr_theory
......@@ -55,6 +55,7 @@ val verify :
?def_constants:(string * string) list ->
?theories:string list ->
?goals:(string * string list) list ->
?onnx_out_file: string ->
File.t ->
verification_result
(** Starts a verification of the given [file] with the provided [prover].
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment