@@ -27,20 +27,24 @@ struct RestrictPtrTraits {
27
27
};
28
28
#endif
29
29
30
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
30
+ template <
31
+ typename T,
32
+ size_t N,
33
+ template <typename U> class PtrTraits = DefaultPtrTraits,
34
+ typename index_t = int64_t >
31
35
class TensorAccessorBase {
32
- public:
36
+ public:
33
37
typedef typename PtrTraits<T>::PtrType PtrType;
34
38
35
39
C10_HOST_DEVICE TensorAccessorBase (
36
40
PtrType data_,
37
41
const index_t * sizes_,
38
42
const index_t * strides_)
39
- : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/ {
43
+ : data_(data_) /* , sizes_(sizes_), strides_(strides_)*/ {
40
44
// Originally, TensorAccessor is a view of sizes and strides as
41
45
// these are ArrayRef instances. Until torch::stable supports
42
46
// ArrayRef-like features, we store copies of sizes and strides:
43
- for (auto i= 0 ; i < N; ++i) {
47
+ for (auto i = 0 ; i < N; ++i) {
44
48
this ->sizes_ [i] = sizes_[i];
45
49
this ->strides_ [i] = strides_[i];
46
50
}
@@ -52,7 +56,8 @@ class TensorAccessorBase {
52
56
C10_HOST_DEVICE const PtrType data () const {
53
57
return data_;
54
58
}
55
- protected:
59
+
60
+ protected:
56
61
PtrType data_;
57
62
/*
58
63
const index_t* sizes_;
@@ -64,48 +69,65 @@ class TensorAccessorBase {
64
69
index_t strides_[N];
65
70
};
66
71
67
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
68
- class TensorAccessor : public TensorAccessorBase <T,N,PtrTraits,index_t > {
69
- public:
72
+ template <
73
+ typename T,
74
+ size_t N,
75
+ template <typename U> class PtrTraits = DefaultPtrTraits,
76
+ typename index_t = int64_t >
77
+ class TensorAccessor : public TensorAccessorBase <T, N, PtrTraits, index_t > {
78
+ public:
70
79
typedef typename PtrTraits<T>::PtrType PtrType;
71
80
72
81
C10_HOST_DEVICE TensorAccessor (
73
82
PtrType data_,
74
83
const index_t * sizes_,
75
84
const index_t * strides_)
76
- : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}
85
+ : TensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
77
86
78
- C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) {
79
- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
87
+ C10_HOST_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
88
+ index_t i) {
89
+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
90
+ this ->data_ + this ->strides_ [0 ] * i,
91
+ this ->sizes_ + 1 ,
92
+ this ->strides_ + 1 );
80
93
}
81
94
82
- C10_HOST_DEVICE const TensorAccessor<T, N-1 , PtrTraits, index_t > operator [](index_t i) const {
83
- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i,this ->sizes_ +1 ,this ->strides_ +1 );
95
+ C10_HOST_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
96
+ index_t i) const {
97
+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
98
+ this ->data_ + this ->strides_ [0 ] * i,
99
+ this ->sizes_ + 1 ,
100
+ this ->strides_ + 1 );
84
101
}
85
102
};
86
103
87
- template <typename T, template <typename U> class PtrTraits , typename index_t >
88
- class TensorAccessor <T,1 ,PtrTraits,index_t > : public TensorAccessorBase<T,1 ,PtrTraits,index_t > {
89
- public:
104
+ template <typename T, template <typename U> class PtrTraits , typename index_t >
105
+ class TensorAccessor <T, 1 , PtrTraits, index_t >
106
+ : public TensorAccessorBase<T, 1 , PtrTraits, index_t > {
107
+ public:
90
108
typedef typename PtrTraits<T>::PtrType PtrType;
91
109
92
110
C10_HOST_DEVICE TensorAccessor (
93
111
PtrType data_,
94
112
const index_t * sizes_,
95
113
const index_t * strides_)
96
- : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}
97
- C10_HOST_DEVICE T & operator [](index_t i) {
114
+ : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
115
+ C10_HOST_DEVICE T& operator [](index_t i) {
98
116
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
99
- return this ->data_ [this ->strides_ [0 ]* i];
117
+ return this ->data_ [this ->strides_ [0 ] * i];
100
118
}
101
- C10_HOST_DEVICE const T & operator [](index_t i) const {
102
- return this ->data_ [this ->strides_ [0 ]* i];
119
+ C10_HOST_DEVICE const T& operator [](index_t i) const {
120
+ return this ->data_ [this ->strides_ [0 ] * i];
103
121
}
104
122
};
105
123
106
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
124
+ template <
125
+ typename T,
126
+ size_t N,
127
+ template <typename U> class PtrTraits = DefaultPtrTraits,
128
+ typename index_t = int64_t >
107
129
class GenericPackedTensorAccessorBase {
108
- public:
130
+ public:
109
131
typedef typename PtrTraits<T>::PtrType PtrType;
110
132
C10_HOST GenericPackedTensorAccessorBase (
111
133
PtrType data_,
@@ -116,13 +138,15 @@ class GenericPackedTensorAccessorBase {
116
138
std::copy (strides_, strides_ + N, std::begin (this ->strides_ ));
117
139
}
118
140
119
- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
141
+ template <
142
+ typename source_index_t ,
143
+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
120
144
C10_HOST GenericPackedTensorAccessorBase (
121
145
PtrType data_,
122
146
const source_index_t * sizes_,
123
147
const source_index_t * strides_)
124
148
: data_(data_) {
125
- for (auto i= 0 ; i < N; ++i) {
149
+ for (auto i = 0 ; i < N; ++i) {
126
150
this ->sizes_ [i] = sizes_[i];
127
151
this ->strides_ [i] = strides_[i];
128
152
}
@@ -134,7 +158,8 @@ class GenericPackedTensorAccessorBase {
134
158
C10_HOST_DEVICE const PtrType data () const {
135
159
return data_;
136
160
}
137
- protected:
161
+
162
+ protected:
138
163
PtrType data_;
139
164
// NOLINTNEXTLINE(*c-arrays*)
140
165
index_t sizes_[N];
@@ -150,68 +175,101 @@ class GenericPackedTensorAccessorBase {
150
175
}
151
176
};
152
177
153
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t >
154
- class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase <T,N,PtrTraits,index_t > {
155
- public:
178
+ template <
179
+ typename T,
180
+ size_t N,
181
+ template <typename U> class PtrTraits = DefaultPtrTraits,
182
+ typename index_t = int64_t >
183
+ class GenericPackedTensorAccessor
184
+ : public GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t > {
185
+ public:
156
186
typedef typename PtrTraits<T>::PtrType PtrType;
157
187
158
188
C10_HOST GenericPackedTensorAccessor (
159
189
PtrType data_,
160
190
const index_t * sizes_,
161
191
const index_t * strides_)
162
- : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
192
+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
193
+ data_,
194
+ sizes_,
195
+ strides_) {}
163
196
164
197
// if index_t is not int64_t, we want to have an int64_t constructor
165
- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
198
+ template <
199
+ typename source_index_t ,
200
+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
166
201
C10_HOST GenericPackedTensorAccessor (
167
202
PtrType data_,
168
203
const source_index_t * sizes_,
169
204
const source_index_t * strides_)
170
- : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}
205
+ : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(
206
+ data_,
207
+ sizes_,
208
+ strides_) {}
171
209
172
- C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) {
210
+ C10_DEVICE TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
211
+ index_t i) {
173
212
index_t * new_sizes = this ->sizes_ + 1 ;
174
213
index_t * new_strides = this ->strides_ + 1 ;
175
- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i, new_sizes, new_strides);
214
+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
215
+ this ->data_ + this ->strides_ [0 ] * i, new_sizes, new_strides);
176
216
}
177
217
178
- C10_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](index_t i) const {
218
+ C10_DEVICE const TensorAccessor<T, N - 1 , PtrTraits, index_t > operator [](
219
+ index_t i) const {
179
220
const index_t * new_sizes = this ->sizes_ + 1 ;
180
221
const index_t * new_strides = this ->strides_ + 1 ;
181
- return TensorAccessor<T,N-1 ,PtrTraits,index_t >(this ->data_ + this ->strides_ [0 ]*i, new_sizes, new_strides);
222
+ return TensorAccessor<T, N - 1 , PtrTraits, index_t >(
223
+ this ->data_ + this ->strides_ [0 ] * i, new_sizes, new_strides);
182
224
}
183
225
};
184
226
185
- template <typename T, template <typename U> class PtrTraits , typename index_t >
186
- class GenericPackedTensorAccessor <T,1 ,PtrTraits,index_t > : public GenericPackedTensorAccessorBase<T,1 ,PtrTraits,index_t > {
187
- public:
227
+ template <typename T, template <typename U> class PtrTraits , typename index_t >
228
+ class GenericPackedTensorAccessor <T, 1 , PtrTraits, index_t >
229
+ : public GenericPackedTensorAccessorBase<T, 1 , PtrTraits, index_t > {
230
+ public:
188
231
typedef typename PtrTraits<T>::PtrType PtrType;
189
232
C10_HOST GenericPackedTensorAccessor (
190
233
PtrType data_,
191
234
const index_t * sizes_,
192
235
const index_t * strides_)
193
- : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
236
+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
237
+ data_,
238
+ sizes_,
239
+ strides_) {}
194
240
195
- template <typename source_index_t , class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
241
+ template <
242
+ typename source_index_t ,
243
+ class = std::enable_if_t <std::is_same_v<source_index_t , int64_t >>>
196
244
C10_HOST GenericPackedTensorAccessor (
197
245
PtrType data_,
198
246
const source_index_t * sizes_,
199
247
const source_index_t * strides_)
200
- : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}
248
+ : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(
249
+ data_,
250
+ sizes_,
251
+ strides_) {}
201
252
202
- C10_DEVICE T & operator [](index_t i) {
253
+ C10_DEVICE T& operator [](index_t i) {
203
254
return this ->data_ [this ->strides_ [0 ] * i];
204
255
}
205
256
C10_DEVICE const T& operator [](index_t i) const {
206
- return this ->data_ [this ->strides_ [0 ]* i];
257
+ return this ->data_ [this ->strides_ [0 ] * i];
207
258
}
208
-
209
259
};
210
260
211
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
212
- using PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
261
+ template <
262
+ typename T,
263
+ size_t N,
264
+ template <typename U> class PtrTraits = DefaultPtrTraits>
265
+ using PackedTensorAccessor32 =
266
+ GenericPackedTensorAccessor<T, N, PtrTraits, int32_t >;
213
267
214
- template <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
215
- using PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
268
+ template <
269
+ typename T,
270
+ size_t N,
271
+ template <typename U> class PtrTraits = DefaultPtrTraits>
272
+ using PackedTensorAccessor64 =
273
+ GenericPackedTensorAccessor<T, N, PtrTraits, int64_t >;
216
274
217
- } // namespace torchaudio::stable
275
+ } // namespace torchaudio::stable
0 commit comments