diff --git a/com.unity.ml-agents/Runtime/Inference/TensorExtensions.cs b/com.unity.ml-agents/Runtime/Inference/TensorExtensions.cs index 9e358cf157..94a717eea5 100644 --- a/com.unity.ml-agents/Runtime/Inference/TensorExtensions.cs +++ b/com.unity.ml-agents/Runtime/Inference/TensorExtensions.cs @@ -57,9 +57,9 @@ public static int Index(this TensorShape shape, int n, int c, int h, int w) { int index = n * shape.Height() * shape.Width() * shape.Channels() + - h * shape.Width() * shape.Channels() + - w * shape.Channels() + - c; + c * shape.Height() * shape.Width() + + h * shape.Width() + + w; return index; } }