diff --git a/lib/onnx/reader.ml b/lib/onnx/reader.ml index ed5c87f18032e1fe3ca78e0a9b88da0d81096076..b110d3d62048e1ed7763887b909ae31725608143 100644 --- a/lib/onnx/reader.ml +++ b/lib/onnx/reader.ml @@ -189,37 +189,50 @@ end = struct (module String) (List.map ~f:(fun a -> (Option.value_exn a.name, a)) n.attribute) in - let get_float name : float = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } - -> - f - | _ -> failwith "Attribute wrongly typed" + let get_attr ?default name m = + match Hashtbl.find attrs name with + | Some v -> m v + | None -> ( + match default with + | Some v -> v + | None -> Fmt.failwith "Required attribute %s missing" name) in - let get_int name : int = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> - Int64.to_int_exn i - | _ -> failwith "Attribute wrongly typed" + let get_float ?default name : float = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.FLOAT; f = Some f; _ } + -> + f + | _ -> failwith "Attribute wrongly typed") in - let get_ints name : int list = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } -> - List.map ~f:Int64.to_int_exn l - | _ -> failwith "Attribute wrongly typed" + let get_int ?default name : int = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } + -> + Int64.to_int_exn i + | _ -> failwith "Attribute wrongly typed") in - let get_bool name : bool = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } -> - not (Int64.equal i 0L) - | _ -> failwith "Attribute wrongly typed" + let get_ints ?default name : int list = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.INTS; ints = l; _ } -> + List.map ~f:Int64.to_int_exn l + | _ -> failwith "Attribute wrongly typed") in - let get_tensor name : Nir.Gentensor.t = - match Hashtbl.find_exn attrs name with - | { type' = Some AttributeProto.AttributeType.TENSOR; t = Some t; _ } - -> - convert_tensor t - | _ -> failwith "Attribute wrongly typed" + let get_bool ?default name : bool = + get_attr ?default name (function + | { type' = Some AttributeProto.AttributeType.INT; i = Some i; _ } + -> + not (Int64.equal i 0L) + | _ -> failwith "Attribute wrongly typed") + in + let get_tensor ?default name : Nir.Gentensor.t = + get_attr ?default name (function + | { + type' = Some AttributeProto.AttributeType.TENSOR; + t = Some t; + _; + } -> + convert_tensor t + | _ -> failwith "Attribute wrongly typed") in let n' = match n.op_type with @@ -257,10 +270,10 @@ end = struct inputA = convert inputA; inputB = convert inputB; inputC = Option.map ~f:convert inputC; - alpha = get_float "alpha"; - beta = get_float "beta"; - transA = get_bool "transA"; - transB = get_bool "transB"; + alpha = get_float ~default:1.0 "alpha"; + beta = get_float ~default:1.0 "beta"; + transA = get_bool ~default:false "transA"; + transB = get_bool ~default:false "transB"; } | "LogSoftmax" -> Nir.Node.LogSoftmax | "Transpose" ->