Skip to content

Commit c9b6bdf

Browse files
committed
fix(tensor): 更新 ndarray-layout
Signed-off-by: YdrMaster <[email protected]>
1 parent d80b378 commit c9b6bdf

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ causal-lm.path = "causal-lm"
3131
test-utils.path = "test-utils"
3232

3333
ggus = "0.4"
34-
ndarray-layout = "0.0"
34+
ndarray-layout = "0.1"
3535
log = "0.4"
3636
regex = "1.11"
3737
itertools = "0.13"

tensor/src/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl<T> Tensor<T> {
8484

8585
let merged = self
8686
.layout
87-
.merge(0..self.layout.ndim())
87+
.merge_be(0, self.layout.ndim())
8888
.expect("dense tensor is castable");
8989
let &[d] = merged.shape() else { unreachable!() };
9090
let &[s] = merged.strides() else {
@@ -177,7 +177,7 @@ impl<T> Tensor<T> {
177177
#[inline]
178178
pub fn is_contiguous(&self) -> bool {
179179
// 任意相邻的两个维度可以合并表示张量完全连续
180-
(2..=self.layout.ndim()).all(|i| self.layout.merge(i - 2..i).is_some())
180+
(2..=self.layout.ndim()).all(|i| self.layout.merge_be(i - 2, 2).is_some())
181181
}
182182

183183
/// 判断张量是否稠密存储。
@@ -186,7 +186,7 @@ impl<T> Tensor<T> {
186186
// 张量为稠密存储,当:
187187
self.layout
188188
// 所有维度可以合并成一个
189-
.merge(0..self.layout.ndim())
189+
.merge_be(0, self.layout.ndim())
190190
// 合并后元素之间步长等于元素的长度
191191
.is_some_and(|layout| {
192192
let [s] = layout.strides() else {
@@ -198,7 +198,7 @@ impl<T> Tensor<T> {
198198

199199
#[inline]
200200
pub fn get_contiguous_bytes(&self) -> Option<usize> {
201-
let layout = self.layout.merge(0..self.layout.ndim())?;
201+
let layout = self.layout.merge_be(0, self.layout.ndim())?;
202202
let &[size] = layout.shape() else {
203203
unreachable!()
204204
};
@@ -303,7 +303,7 @@ impl<T> Tensor<T> {
303303
#[inline]
304304
pub fn merge(self, range: Range<usize>) -> Option<Self> {
305305
self.layout
306-
.merge(range)
306+
.merge_be(range.start, range.len())
307307
.map(|layout| Self { layout, ..self })
308308
}
309309
}

0 commit comments

Comments
 (0)