diff --git a/lib/nir/node.ml b/lib/nir/node.ml
index f09d211f489c1bdb79699b1f8f56376854676f09..b162f2f0a17eb47df2c348d7a357057e3b994431 100644
--- a/lib/nir/node.ml
+++ b/lib/nir/node.ml
@@ -511,7 +511,7 @@ let partial_dot_product ?shp arr1 arr2 first last =
   let ioob str = failwith @@ "Index out of bound for arr" ^ str in
   if last > Array.length arr1 then ioob "1"
   else if last > Array.length arr2 then ioob "2"
-  else if last < first then 
+  else if last > first then
     let rec aux index acc = 
       if index = last then acc else 
       let acc = acc + (arr1.(index) * arr2.(index))
@@ -519,7 +519,7 @@ let partial_dot_product ?shp arr1 arr2 first last =
       aux Int.(index+1) acc
     in
     aux Int.(first+1) (arr1.(first) * arr2.(first))
-  else 
+  else (* nothing to include, returns a tensor of 0s. *)
     let actual_shape = 
       if Array.length arr1 <> 0 then 
         compute_shape arr1.(0)
@@ -530,21 +530,3 @@ let partial_dot_product ?shp arr1 arr2 first last =
       | None -> failwith "Cannot determine shape of tensor"
     in 
     create @@ (Constant { data = Gentensor.create_const_float actual_shape 0.0})
-    (*
-  if Array.length arr1 < last && Array.length arr2 < last then 
-    failwith "Index out of bound"
-  else
-  let zero_node =
-    create @@ Constant { data = Gentensor.create_1_float 0.0 }
-  in
-  let rec aux index acc =
-    if index = last
-    then acc
-    else
-      let prod = arr1.(index) * arr2.(index) in
-      let new_acc = acc + prod in
-      aux Int.(index + 1) new_acc
-  in
-  aux first zero_node
-*)
- 
\ No newline at end of file
diff --git a/lib/nir/node.mli b/lib/nir/node.mli
index 68918fe1a1e6c6571e21e8b9fc82e8299eb28afd..83f88c88a1eaad7675261a82dbb5ee1a711286c5 100644
--- a/lib/nir/node.mli
+++ b/lib/nir/node.mli
@@ -182,12 +182,12 @@ If [ns] is empty, this returns a tensor of shape [shp] filled with 0s.
 By default, [shp] is a single float. *)
 
 val partial_dot_product : ?shp:Shape.t -> t array -> t array -> int -> int -> t
-(** [partial_dot_product arr1 arr2 first last] 
+(** [partial_dot_product shp arr1 arr2 first last] 
 where [arr1 = [|n11,n12,...,n1k1|]] and [arr2 = [|n21,n22,...,n2k2|]]
 is a node corresponding to 
 [(n1first * n2first) + (n1first+1 * n2first+1) + ... + (n1last-1 * n2last-1)]
 if this exists.
-It is assumed that [arr1] and [arr2] contain tensors of similar shape.
+It is assumed that [arr1] and [arr2] contain tensors with same shape.
 Edge cases include: 
 
 {ul