Skip to content
Snippets Groups Projects
Commit 2c3af7f6 authored by Aymeric Varasse's avatar Aymeric Varasse :innocent:
Browse files

[onnx] Update reader utils

parent 9ce419ce
No related branches found
No related tags found
No related merge requests found
......@@ -189,37 +189,50 @@ end = struct
(module String)
(List.map ~f:(fun a -> (Option.value_exn a.name, a)) n.attribute)
in
let get_float name : float =
match Hashtbl.find_exn attrs name with
| { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ }
->
f
| _ -> failwith "Attribute wrongly typed"
let get_attr ?default name m =
match Hashtbl.find attrs name with
| Some v -> m v
| None -> (
match default with
| Some v -> v
| None -> Fmt.failwith "Required attribute %s missing" name)
in
let get_int name : int =
match Hashtbl.find_exn attrs name with
| { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } ->
Int64.to_int_exn i
| _ -> failwith "Attribute wrongly typed"
let get_float ?default name : float =
get_attr ?default name (function
| { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ }
->
f
| _ -> failwith "Attribute wrongly typed")
in
let get_ints name : int list =
match Hashtbl.find_exn attrs name with
| { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } ->
List.map ~f:Int64.to_int_exn l
| _ -> failwith "Attribute wrongly typed"
let get_int ?default name : int =
get_attr ?default name (function
| { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ }
->
Int64.to_int_exn i
| _ -> failwith "Attribute wrongly typed")
in
let get_bool name : bool =
match Hashtbl.find_exn attrs name with
| { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } ->
not (Int64.equal i 0L)
| _ -> failwith "Attribute wrongly typed"
let get_ints ?default name : int list =
get_attr ?default name (function
| { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } ->
List.map ~f:Int64.to_int_exn l
| _ -> failwith "Attribute wrongly typed")
in
let get_tensor name : Nir.Gentensor.t =
match Hashtbl.find_exn attrs name with
| { type' = Some AttributeProto.AttributeType.TENSOR; t = Some t; _ }
->
convert_tensor t
| _ -> failwith "Attribute wrongly typed"
let get_bool ?default name : bool =
get_attr ?default name (function
| { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ }
->
not (Int64.equal i 0L)
| _ -> failwith "Attribute wrongly typed")
in
let get_tensor ?default name : Nir.Gentensor.t =
get_attr ?default name (function
| {
type' = Some AttributeProto.AttributeType.TENSOR;
t = Some t;
_;
} ->
convert_tensor t
| _ -> failwith "Attribute wrongly typed")
in
let n' =
match n.op_type with
......@@ -257,10 +270,10 @@ end = struct
inputA = convert inputA;
inputB = convert inputB;
inputC = Option.map ~f:convert inputC;
alpha = get_float "alpha";
beta = get_float "beta";
transA = get_bool "transA";
transB = get_bool "transB";
alpha = get_float ~default:1.0 "alpha";
beta = get_float ~default:1.0 "beta";
transA = get_bool ~default:false "transA";
transB = get_bool ~default:false "transB";
}
| "LogSoftmax" -> Nir.Node.LogSoftmax
| "Transpose" ->
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment