diff --git a/lib/nir/node.ml b/lib/nir/node.ml index f09d211f489c1bdb79699b1f8f56376854676f09..b162f2f0a17eb47df2c348d7a357057e3b994431 100644 --- a/lib/nir/node.ml +++ b/lib/nir/node.ml @@ -511,7 +511,7 @@ let partial_dot_product ?shp arr1 arr2 first last = let ioob str = failwith @@ "Index out of bound for arr" ^ str in if last > Array.length arr1 then ioob "1" else if last > Array.length arr2 then ioob "2" - else if last < first then + else if last > first then let rec aux index acc = if index = last then acc else let acc = acc + (arr1.(index) * arr2.(index)) @@ -519,7 +519,7 @@ let partial_dot_product ?shp arr1 arr2 first last = aux Int.(index+1) acc in aux Int.(first+1) (arr1.(first) * arr2.(first)) - else + else (* nothing to include, returns a tensor of 0s. *) let actual_shape = if Array.length arr1 <> 0 then compute_shape arr1.(0) @@ -530,21 +530,3 @@ let partial_dot_product ?shp arr1 arr2 first last = | None -> failwith "Cannot determine shape of tensor" in create @@ (Constant { data = Gentensor.create_const_float actual_shape 0.0}) - (* - if Array.length arr1 < last && Array.length arr2 < last then - failwith "Index out of bound" - else - let zero_node = - create @@ Constant { data = Gentensor.create_1_float 0.0 } - in - let rec aux index acc = - if index = last - then acc - else - let prod = arr1.(index) * arr2.(index) in - let new_acc = acc + prod in - aux Int.(index + 1) new_acc - in - aux first zero_node -*) - \ No newline at end of file diff --git a/lib/nir/node.mli b/lib/nir/node.mli index 68918fe1a1e6c6571e21e8b9fc82e8299eb28afd..83f88c88a1eaad7675261a82dbb5ee1a711286c5 100644 --- a/lib/nir/node.mli +++ b/lib/nir/node.mli @@ -182,12 +182,12 @@ If [ns] is empty, this returns a tensor of shape [shp] filled with 0s. By default, [shp] is a single float. *) val partial_dot_product : ?shp:Shape.t -> t array -> t array -> int -> int -> t -(** [partial_dot_product arr1 arr2 first last] +(** [partial_dot_product shp arr1 arr2 first last] where [arr1 = [|n11,n12,...,n1k1|]] and [arr2 = [|n21,n22,...,n2k2|]] is a node corresponding to [(n1first * n2first) + (n1first+1 * n2first+1) + ... + (n1last-1 * n2last-1)] if this exists. -It is assumed that [arr1] and [arr2] contain tensors of similar shape. +It is assumed that [arr1] and [arr2] contain tensors with same shape. Edge cases include: {ul