forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpython_variable.cpp
More file actions
3863 lines (3546 loc) · 134 KB
/
python_variable.cpp
File metadata and controls
3863 lines (3546 loc) · 134 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <ATen/DTensorState.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/Resize.h>
#include <c10/core/DeviceType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/util/FbcodeMaps.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
#include <pybind11/pytypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/Size.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/Types.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_torch_functions.h>
#include <torch/csrc/autograd/python_variable_indexing.h>
#include <torch/csrc/autograd/utils/error_messages.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/distributed/Placement.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/pyobject_preservation.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>
#include <ATen/ATen.h>
#include <structmember.h>
#include <cstdint>
#include <memory>
#include <sstream>
#include <utility>
#include <vector>
using namespace at;
using namespace torch;
using namespace torch::autograd;
using torch::utils::PyObjectPreservation;
namespace {
// Read-only view that splits an operator's flattened IValue arguments into the
// Python-style (args, kwargs) shape used by parseIValuesToPyArgsKwargs:
//   * positional args: the non-kwarg-only prefix of `arguments`, minus any
//     trailing run of arguments whose value equals the schema default;
//   * kwargs: the kwarg-only arguments, skipping every one whose value equals
//     its schema default.
// The view stores a reference to `op` and an ArrayRef over `arguments`, so both
// must outlive it.
class OperatorArgsKwargsView {
 public:
  OperatorArgsKwargsView(
      const c10::OperatorHandle& op,
      const std::vector<c10::IValue>& arguments);

  using args_iterator = const c10::IValue*;

  // [args_begin(), args_end()) is the range of positional arguments to emit.
  args_iterator args_begin() const {
    return arguments_.data();
  }
  args_iterator args_end() const {
    return arguments_.data() + positional_default_start_;
  }
  auto num_positional_args() const {
    return positional_default_start_;
  }
  // Index (into `arguments`) of the first kwarg-only argument that differs
  // from its default, or arguments.size() when there is none.
  auto kwarg_start_index() const {
    return first_non_default_kwarg_;
  }

  // Forward iterator over the kwarg-only arguments that differ from their
  // schema default. underlying_index() is the position in `arguments`, which
  // is also the position in the operator schema.
  struct kwargs_iterator {
    kwargs_iterator() = default;
    kwargs_iterator(const OperatorArgsKwargsView* parent, size_t current)
        : parent_(parent), current_(current) {}
    kwargs_iterator(const kwargs_iterator&) = default;
    kwargs_iterator& operator=(const kwargs_iterator&) = default;

    kwargs_iterator& operator++() {
      // Advance at least once, then keep skipping defaulted kwargs.
      do {
        current_++;
      } while (current_ < parent_->arguments_.size() &&
               parent_->is_default(current_));
      return *this;
    }
    kwargs_iterator operator++(int) {
      auto copy = *this;
      ++(*this);
      return copy;
    }
    const c10::IValue& operator*() const {
      return parent_->arguments_[current_];
    }
    const c10::IValue* operator->() const {
      return &operator*();
    }
    int64_t underlying_index() const {
      return static_cast<int64_t>(current_);
    }
    bool operator==(const kwargs_iterator& rhs) const {
      return parent_ == rhs.parent_ && current_ == rhs.current_;
    }
    // const (like operator==) so const iterators can be compared too.
    bool operator!=(const kwargs_iterator& rhs) const {
      return !(*this == rhs);
    }

   private:
    const OperatorArgsKwargsView* parent_ = nullptr;
    size_t current_ = 0;
  };

  kwargs_iterator kwargs_begin() const {
    return kwargs_iterator(this, first_non_default_kwarg_);
  }
  kwargs_iterator kwargs_end() const {
    return kwargs_iterator(this, arguments_.size());
  }

 private:
  // True iff the schema declares a default for argument `idx` and the actual
  // value compares equal to that default.
  bool is_default(size_t idx) const {
    const auto& arg = op_.schema().arguments()[idx];
    if (!arg.default_value().has_value()) {
      return false;
    }
    const auto& default_ivalue = *arg.default_value();
    const auto& ivalue = arguments_[idx];
    if (default_ivalue != ivalue) {
      return false;
    }
    return true;
  }

  const c10::OperatorHandle& op_;
  c10::ArrayRef<c10::IValue> arguments_;
  // About all the pointers:
  //
  // f(int x, int y = 0, *, int z = 0)
  //                                 ^- arguments.size()
  //                        ^- kwarg_only_start
  //          ^- positional_default_start
  //   ^- 0
  int64_t positional_default_start_;
  int64_t first_non_default_kwarg_;
};

OperatorArgsKwargsView::OperatorArgsKwargsView(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments)
    : op_(op), arguments_(arguments) {
  // Find the split point between kwarg-only and regular. Since most functions
  // don't have kwarg-only arguments, it is more efficient to scan from the
  // right (but ideally, this would just be precomputed in FunctionSchema
  // itself). (NB: minus one in the loop is because we're testing if the
  // *next* argument is kwarg-only before we advance the starting index)
  const int64_t signed_arguments_size = static_cast<int64_t>(arguments.size());
  int64_t kwarg_only_start = signed_arguments_size;
  for (; kwarg_only_start > 0; kwarg_only_start--) {
    const auto& arg = op.schema().arguments()[kwarg_only_start - 1];
    if (!arg.kwarg_only()) {
      break;
    }
  }

  // Find the first positional argument that isn't defaulted
  positional_default_start_ = kwarg_only_start;
  for (; positional_default_start_ > 0; positional_default_start_--) {
    if (!is_default(positional_default_start_ - 1)) {
      break;
    }
  }

  // kwargs_iterator will skip default kwargs when incremented, but we
  // need to skip any initial run of default kwargs ourselves.
  first_non_default_kwarg_ = kwarg_only_start;
  for (; first_non_default_kwarg_ < signed_arguments_size;
       ++first_non_default_kwarg_) {
    if (!is_default(first_non_default_kwarg_)) {
      break;
    }
  }
}
} // namespace
// Converts the IValue `arguments` of a call to `op` into a Python
// (args tuple, kwargs dict) pair. The operator schema decides which values are
// positional and which are keyword-only. Arguments equal to their schema
// default are dropped: trailing defaulted positionals and defaulted kwarg-only
// arguments (see OperatorArgsKwargsView).
// Values whose declared type is a ScalarType, Layout or MemoryFormat (or an
// Optional of one) are stored as ints in the IValue. They are converted to the
// matching Python dtype/layout/memory_format objects instead of plain ints.
// Requires the GIL. Throws python_error if the args tuple cannot be allocated.
std::pair<py::object, py::dict> parseIValuesToPyArgsKwargs(
    const c10::OperatorHandle& op,
    const std::vector<c10::IValue>& arguments) {
  TORCH_CHECK(
      PyGILState_Check(),
      "GIL must be held before you call parseIValuesToPyArgsKwargs");
  const auto& schema = op.schema();
  py::dict kwargs;

  OperatorArgsKwargsView args_kwargs(op, arguments);
  auto args = py::reinterpret_steal<py::object>(
      PyTuple_New(args_kwargs.num_positional_args()));
  if (!args) {
    // PyTuple_New failed and set a Python exception; filling a null tuple
    // below would dereference nullptr.
    throw python_error();
  }

  // Converts argument `idx`, consulting its declared schema type so that
  // enum-like values round-trip as their Python objects.
  auto schemaAwareToPyObject =
      [&schema](size_t idx, const c10::IValue& argument) -> py::object {
    const auto& arg = schema.arguments()[idx];
    // Matches either `kind` itself or Optional[`kind`].
    auto match = [&](c10::TypeKind kind) {
      const auto& t = arg.real_type();
      if (t->kind() == kind)
        return true;
      if (auto opt_t = t->cast<c10::OptionalType>()) {
        if (opt_t->getElementType()->kind() == kind)
          return true;
      }
      return false;
    };
    if (argument.isNone()) {
      return py::none();
    } else if (match(c10::ScalarTypeType::Kind)) {
      auto* obj = getTHPDtype(static_cast<c10::ScalarType>(argument.toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::LayoutType::Kind)) {
      auto* obj = getTHPLayout(static_cast<c10::Layout>(argument.toInt()));
      return py::reinterpret_borrow<py::object>(
          reinterpret_cast<PyObject*>(obj));
    } else if (match(c10::MemoryFormatType::Kind)) {
      return py::cast(static_cast<c10::MemoryFormat>(argument.toInt()));
    } else {
      return torch::jit::toPyObject(argument);
    }
  };

  // Populate positional arguments. PyTuple_SET_ITEM steals the reference
  // handed over by release().
  size_t idx = 0;
  for (auto argument_it = args_kwargs.args_begin();
       argument_it != args_kwargs.args_end();
       ++argument_it) {
    PyTuple_SET_ITEM(
        args.ptr(),
        idx,
        schemaAwareToPyObject(idx, *argument_it).release().ptr());
    idx++;
  }

  // Populate keyword arguments
  for (auto argument_it = args_kwargs.kwargs_begin();
       argument_it != args_kwargs.kwargs_end();
       ++argument_it) {
    const auto& arg = schema.arguments()[argument_it.underlying_index()];
    kwargs[py::cast(arg.name())] =
        schemaAwareToPyObject(argument_it.underlying_index(), *argument_it);
  }
  return std::make_pair(std::move(args), std::move(kwargs));
}
// Pushes `out`, the Python result of a Python-implemented kernel for `op`,
// onto `stack`. It is converted according to op's declared return types:
//   * 0 returns: `out` must be None and nothing is pushed;
//   * 1 return:  `out` is converted and pushed as a single IValue;
//   * N returns: `out` must be a sequence of exactly N items, pushed in order.
// `msg` names the caller in error messages. Requires the GIL.
void pushPyOutToStack(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack,
    py::object out,
    const char* msg) {
  TORCH_CHECK(
      PyGILState_Check(), "GIL must be held before you call pushPyOutToStack");
  const auto& schema_returns = op.schema().returns();
  const auto num_returns = schema_returns.size();
  if (num_returns == 0) {
    // Check that we got a None return from Python. Anything else is an error.
    TORCH_CHECK(
        out.is_none(),
        "Expected ",
        msg,
        " for ",
        op.operator_name(),
        " to return None but it returned something else instead.");
  } else if (num_returns == 1) {
    torch::jit::push(
        stack, torch::jit::toIValue(out.ptr(), schema_returns[0].real_type()));
  } else {
    auto outs = py::cast<py::sequence>(out);
    // Without this check a longer sequence would index past the end of
    // schema_returns (UB), and a shorter one would leave the stack short of
    // the values the caller pops.
    TORCH_CHECK(
        outs.size() == num_returns,
        "Expected ",
        msg,
        " for ",
        op.operator_name(),
        " to return a sequence of ",
        num_returns,
        " elements but it returned ",
        outs.size(),
        " elements instead.");
    for (const auto idx : c10::irange(num_returns)) {
      torch::jit::push(
          stack,
          torch::jit::toIValue(
              outs[idx].ptr(), schema_returns[idx].real_type()));
    }
  }
}
namespace {
// Maps the user-facing `dispatch_sizes_strides_policy` string onto the
// corresponding TensorImpl policy. Only "sizes" and "strides" are accepted;
// any other spelling fails via TORCH_CHECK_VALUE.
c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument(
    std::string_view arg) {
  using Policy = c10::TensorImpl::SizesStridesPolicy;
  // The two accepted spellings are mutually exclusive, so the order of these
  // tests does not matter.
  if (arg == "sizes") {
    return Policy::CustomSizes;
  }
  if (arg == "strides") {
    return Policy::CustomStrides;
  }
  TORCH_CHECK_VALUE(
      false,
      "Unknown sizes_strides_policy: ",
      arg,
      "; expected 'strides' or 'sizes'");
}
} // anonymous namespace
// Global handles to Python classes. THPVariableClass is the default type
// instantiated by THPVariable_WrapWithType. Both pointers start out null here.
// NOTE(review): they are presumably assigned during module initialization
// elsewhere in this file, and ParameterClass presumably refers to
// torch.nn.Parameter; neither is visible in this section, so confirm.
PyObject* THPVariableClass = nullptr;
PyObject* ParameterClass = nullptr;
// clang-tidy gets confused by static const
// Text for the warning about the removed `volatile` flag (the warning is
// emitted outside this section).
static constexpr const char* VOLATILE_WARNING =
    "volatile was removed and now has no effect. Use "
    "`with torch.no_grad():` instead.";
// Forward declaration. The function has internal linkage, so it is defined
// later in this translation unit. It is called on the user-supplied `cls`
// before that object is treated as a PyTypeObject (presumably it checks that
// `cls` is a Tensor subtype; confirm at the definition).
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls);
// Reports whether `obj` exposes a usable `__torch_dispatch__`: the attribute
// must exist and must not be torch's "disabled" sentinel implementation.
// Objects accepted by THPVariable_CheckExact short-circuit to false without
// any attribute lookup.
static bool check_has_torch_dispatch(PyObject* obj) {
  if (THPVariable_CheckExact(obj)) {
    return false;
  }
  const py::object dispatch_attr =
      PyObject_FastGetAttrString(obj, "__torch_dispatch__");
  PyObject* const dispatch_impl = dispatch_attr.ptr();
  if (dispatch_impl == nullptr) {
    return false;
  }
  return dispatch_impl != torch::disabled_torch_dispatch_impl();
}
// NOLINTNEXTLINE(*-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyObject* device_to_py_class_[static_cast<size_t>(
c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)];
void registerPythonTensorClass(
const std::string& device,
PyObject* python_tensor_class) {
c10::Device dev(device);
TORCH_CHECK(
dev.type() == kXLA, "Only the python class for XLA can be overridden");
if (device_to_py_class_[static_cast<size_t>(dev.type())] != nullptr) {
TORCH_WARN(
"Overriding a previously registered python class for ", dev.str());
}
device_to_py_class_[static_cast<size_t>(dev.type())] = python_tensor_class;
}
static PyObject* getPythonTensorClass(c10::Device d) {
return device_to_py_class_[static_cast<size_t>(d.type())];
}
void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}
// Fails when the Python object `obj` already associated with a tensor is not
// an instance of `type`, the subclass the caller is trying to create.
static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) {
  const char* const requested_name = type->tp_name;
  const char* const existing_name = Py_TYPE(obj)->tp_name;
  TORCH_CHECK(
      PyObject_TypeCheck(obj, type),
      "Creating a new Tensor subclass ",
      requested_name,
      " but the raw Tensor object is already associated to a python object ",
      "of type ",
      existing_name,
      " which is not a subclass of the requested type");
}
// Generic for const Tensor& or Tensor&&
//
// Returns a new reference to the Python object wrapping `var`:
//   * None if `var` is undefined;
//   * the existing wrapper stored in the TensorImpl's PyObject slot, if any,
//     after checking it is an instance of `desired_type` when one is given;
//   * otherwise a freshly allocated object of `desired_type`, or of the
//     registered XLA override class, or of THPVariableClass, in that order.
// When T is an rvalue reference, `var` is moved into the new wrapper.
template <typename T>
static PyObject* THPVariable_WrapWithType(
    T&& var,
    std::optional<PyTypeObject*> desired_type) {
  if (!var.defined()) {
    Py_RETURN_NONE;
  }
  c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl();
  c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot();
  PyObject* obj = pyobj_slot->load_pyobj();
  if (obj) {
    // Reuse the existing wrapper; only validate its type against the request.
    if (desired_type) {
      check_tensor_subclass(obj, *desired_type);
    }
    return Py_NewRef(obj);
  }
  // Choose the class to instantiate: an explicit request wins; otherwise XLA
  // tensors may use a class registered via registerPythonTensorClass.
  PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPVariableClass);
  if (desired_type) {
    type = *desired_type;
  } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) {
    if (auto clazz = getPythonTensorClass(var.device())) {
      type = reinterpret_cast<PyTypeObject*>(clazz);
    }
  }
  obj = type->tp_alloc(type, 0);
  TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
  // Ensure that PyUnstable_TryIncref calls don't fail spuriously in
  // free-threaded Python.
  PyUnstable_EnableTryIncRef(obj);
  // Construct the Tensor member in place inside the freshly allocated object.
  auto v = reinterpret_cast<THPVariable*>(obj);
  new (&v->cdata) Tensor(std::forward<T>(var));
  if (THPVariable_Unpack(obj).is_uniquely_owned()) {
    // We can use a faster non-atomic code path if we have the only reference to
    // a fresh Tensor.
    PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj);
    return obj;
  }
  // Otherwise publish `obj` into the slot; init_once returns whichever wrapper
  // ended up stored there.
  PyObject* wrapper =
      PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj);
  if (wrapper != obj) {
    // Another thread beat us to it
    // Discard our allocation and hand back the wrapper that won.
    Py_DECREF(obj);
    if (desired_type) {
      check_tensor_subclass(wrapper, *desired_type);
    }
    return Py_NewRef(wrapper);
  }
  return obj;
}
// Public wrapping entry points. Each returns a new reference to `var`'s Python
// object (None for an undefined tensor); see THPVariable_WrapWithType.

// Rvalue overload: `var` is moved into the wrapper if a new one is created.
PyObject* THPVariable_Wrap(at::TensorBase&& var) {
  return THPVariable_WrapWithType(
      std::move(var), std::optional<PyTypeObject*>{});
}

// Lvalue overload: uses the default class selection.
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
  return THPVariable_WrapWithType(var, std::optional<PyTypeObject*>{});
}

// Requests that the wrapper be (or already be) an instance of `type`.
PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) {
  return THPVariable_WrapWithType(var, std::optional<PyTypeObject*>{type});
}
// Forward declaration. The function has internal linkage, so its definition
// appears later in this translation unit (outside this section).
static PyObject* THPVariable_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs);
// Wraps `self`'s tensor via THPVariable_Wrap, then immediately drops the
// returned reference; only the wrapping call's effect is wanted. Returns None.
// HANDLE_TH_ERRORS is required because THPVariable_Wrap can throw (e.g. when
// tp_alloc fails). A C++ exception must never unwind through the CPython
// interpreter.
static PyObject* THPVariable_fix_weakref(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  const auto& var = THPVariable_Unpack(self);
  Py_DECREF(THPVariable_Wrap(var));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Calls the Python callable `func` on each element of `items`, in order, and
// converts every result back to T. The output has the same length and order
// as the input.
template <typename T>
static std::vector<T> map_py_func(
    const py::function& func,
    const std::vector<T>& items) {
  std::vector<T> mapped;
  mapped.reserve(items.size());
  for (size_t i = 0; i < items.size(); ++i) {
    mapped.emplace_back(py::cast<T>(func(items[i])));
  }
  return mapped;
}

// Tensor specialization: a Python None result becomes an undefined
// at::Tensor instead of being handed to py::cast.
template <>
std::vector<at::Tensor> map_py_func(
    const py::function& func,
    const std::vector<at::Tensor>& items) {
  std::vector<at::Tensor> mapped;
  mapped.reserve(items.size());
  for (const auto& tensor : items) {
    const py::object result = func(tensor);
    mapped.emplace_back(
        result.is_none() ? at::Tensor() : py::cast<at::Tensor>(result));
  }
  return mapped;
}
// Shared implementation of the view-replay bindings.
// Parses `(Tensor new_base, symint_visitor_fn=None, tensor_visitor_fn=None)`
// and, if `self` is a backward-differentiable view, replays that view on
// `new_base`:
//   * with the stored view function. If visitor callables are given, the
//     function is first cloned with its SymInts / tensors mapped through them;
//   * or, when no view function is stored, via as_strided using self's sizes,
//     strides and storage offset.
// When `check_has_same_meta` is true, the replay only happens if `new_base`
// has the same metadata as self's original base.
// Returns None if no replay happens (the undefined result wraps to None).
static PyObject* view_func_impl(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs,
    bool check_has_same_meta) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);

  static PythonArgParser parser({
      "_view_func(Tensor new_base, PyObject* symint_visitor_fn=None, PyObject* tensor_visitor_fn=None)",
  });
  ParsedArgs<3> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  auto new_base = r.tensor(0);
  PyObject* symint_visitor_fn = r.pyobject(1);
  PyObject* tensor_visitor_fn = r.pyobject(2);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Ensure that the newly provided base is similar to the original base
    if (!check_has_same_meta ||
        torch::autograd::utils::has_same_meta(new_base, view_info.base_)) {
      // Do the actual view replay
      if (view_info.has_view_fn()) {
        auto& view_func = view_info.view_fn();

        // Determine new SymInt / tensor state as needed.
        std::optional<std::vector<c10::SymInt>> new_symints = std::nullopt;
        if (symint_visitor_fn != Py_None) {
          new_symints = map_py_func(
              py::cast<py::function>(symint_visitor_fn),
              view_func.get_symints());
        }

        std::optional<std::vector<at::Tensor>> new_tensors = std::nullopt;
        if (tensor_visitor_fn != Py_None) {
          new_tensors = map_py_func(
              py::cast<py::function>(tensor_visitor_fn),
              view_func.get_tensors());
        }

        // call view func
        // Only clone the view function when a visitor supplied new state.
        if (new_symints.has_value() || new_tensors.has_value()) {
          out = (*view_func.clone_and_set(new_symints, new_tensors))(new_base);
        } else {
          out = view_func(new_base);
        }
      } else {
        out = new_base.as_strided(
            self.sizes(), self.strides(), self.storage_offset());
      }
    }
  }
  return THPVariable_Wrap(out);
  END_HANDLE_TH_ERRORS
}
// Python-facing view replay that only replays when `new_base` has the same
// metadata as self's original base.
static PyObject* THPVariable_view_func(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  constexpr bool kRequireSameBaseMeta = true;
  return view_func_impl(self_, args, kwargs, kRequireSameBaseMeta);
}

// Python-facing view replay that skips the base-metadata comparison.
static PyObject* THPVariable_view_func_unsafe(
    PyObject* self_,
    PyObject* args,
    PyObject* kwargs) {
  constexpr bool kRequireSameBaseMeta = false;
  return view_func_impl(self_, args, kwargs, kRequireSameBaseMeta);
}
// Reverse view replay. `arg` must be a Tensor (called new_view below). If
// `self` is a backward-differentiable view, self's stored reverse view
// function is applied to new_view and the result is returned. Otherwise the
// result is None.
static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(self_);
  TORCH_CHECK(
      THPVariable_Check(arg),
      "_rev_view_func expect a single argument that is a Tensor");

  const auto& new_view = THPVariable_Unpack(arg);

  // Ensure that self is indeed a backward differentiable view
  // If not, we return an undefined Tensor (None) and let the user handle it.
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  at::Tensor out;
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    const auto& view_info = diff_view_meta->get_backward_view();
    // Do the actual view replay
    // NOTE(review): availability is checked with has_view_fn(), but
    // rev_view_fn() is what gets called. This relies on the forward and
    // reverse view functions always being set together; confirm in ViewInfo.
    TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found");
    out = view_info.rev_view_fn()(new_view);
  }
  return THPVariable_Wrap(out);
  END_HANDLE_TH_ERRORS
}

// Python-facing wrapper around rev_view_func_impl.
static PyObject* THPVariable_rev_view_func_unsafe(
    PyObject* self_,
    PyObject* arg) {
  return rev_view_func_impl(self_, arg);
}
// Instantiates a subclass of self with the same data.
// Parses `as_subclass(cls)` and wraps an alias of `self` as an instance of
// `cls`. The new tensor's Python-dispatch flag is set when `cls` defines an
// enabled `__torch_dispatch__`.
static PyObject* THPVariable_as_subclass(
    PyObject* _self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  const auto& self = THPVariable_Unpack(_self);
  static PythonArgParser parser({
      "as_subclass(PyObject* cls)",
  });
  ParsedArgs<1> parsed_args{};
  auto r = parser.parse(_self, args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TENSOR_SUBTYPE(cls);

  // While the alias is created and wrapped, torch dispatch modes are stashed
  // entirely (not merely popped off the stack) and the Python dispatcher is
  // disabled.
  torch_dispatch_mode::StashTorchDispatchStackGuard stash_modes_guard;
  c10::impl::DisablePythonDispatcher disable_python_dispatcher_guard;

  auto* const subclass_type = reinterpret_cast<PyTypeObject*>(cls);
  PyObject* wrapped = THPVariable_WrapWithType(self.alias(), subclass_type);
  if (check_has_torch_dispatch(wrapped)) {
    THPVariable_Unpack(wrapped).unsafeGetTensorImpl()->set_python_dispatch(
        true);
  }
  return wrapped;
  END_HANDLE_TH_ERRORS
}
// Implements `_make_subclass(cls, data, require_grad=False, *, ...)`.
// Detaches `data` and applies the requested requires_grad flag and optional
// custom sizes/strides, device, layout and backend-key settings to the
// detached tensor. It is then wrapped as an instance of `cls`, and Python
// dispatch is enabled when `cls` defines an enabled `__torch_dispatch__`.
static PyObject* THPVariable_make_subclass(
    PyObject* _ignored,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, std::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, Device? device_for_backend_keys=None)",
  });
  ParsedArgs<7> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);
  TORCH_CHECK_TENSOR_SUBTYPE(cls);
  // guard completely turns off torch dispatch modes, doesn't just pop off the
  // stack
  // (Both guards must be constructed before detach() below so they cover it.)
  torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
  c10::impl::DisablePythonDispatcher dpd_g;
  auto data =
      r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
  // We set `data`'s `allow_tensor_metadata_change` to true here, because we
  // want to allow the following use case for backward compatibility:
  //
  // ```python
  // rnn = torch.nn.RNN(100, 100, 2)
  // # The following calls `torch._cudnn_rnn_flatten_weight(rnn._flat_weights,
  // ...)`, # which changes storage of `rnn`'s weights in-place
  // rnn.flatten_parameters()
  // ```
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  data.set_requires_grad(r.toBool(2));
  // Optional per-tensor overrides, indexed as in the signature above.
  const auto sizes_strides_policy = r.stringViewOptional(3);
  if (sizes_strides_policy.has_value()) {
    data.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
        parseSizesStridesPolicyArgument(*sizes_strides_policy));
  }
  if (r.toBool(4)) {
    data.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(5)) {
    data.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }
  if (!r.isNone(6)) {
    data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
  }

  PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls);
  if (check_has_torch_dispatch(obj)) {
    THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
  }
  return obj;
  END_HANDLE_TH_ERRORS
}
// Shared code factored out of THPVariable_make_wrapper_subclass and
// THPVariable_dtensor__new__.
//
// Builds a data-less Tensor with the given sizes, optional strides and storage
// offset, dtype/device (from `options`) and dispatch keys (options' key plus
// `extra_dispatch_keys`). Its storage has a null DataPtr and is sized in bytes:
//   * as `storage_size` if that is given;
//   * otherwise computed from sizes/strides/offset;
//   * otherwise, when there are no strides, computed contiguously.
// Throws if a storage offset is given without strides.
static Tensor make_tensor_for_subclass_helper(
    SymIntArrayRef sym_sizes,
    OptionalSymIntArrayRef sym_strides,
    const std::optional<c10::SymInt>& sym_storage_offset,
    const TensorOptions& options,
    const std::optional<c10::SymInt>& storage_size,
    std::optional<DispatchKeySet> extra_dispatch_keys) {
  // Both guards are scoped to the whole function body.
  AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
  tracer::impl::NoTracerDispatchMode tracer_guard{};

  c10::SymInt size_bytes;
  auto dtype_itemsize = static_cast<int64_t>(options.dtype().itemsize());

  if (storage_size.has_value()) {
    size_bytes = storage_size.value();
  } else if (sym_strides.has_value()) {
    size_bytes = at::detail::computeStorageNbytes(
        sym_sizes,
        sym_strides.value(),
        dtype_itemsize,
        sym_storage_offset.value_or(0));
  } else {
    size_bytes = at::detail::computeStorageNbytesContiguous(
        sym_sizes, dtype_itemsize, sym_storage_offset.value_or(0));
  }

  // We use storages **only** to track aliasing of subclasses during tracing.
  // The actual data pointers are not valid.
  Storage storage{
      Storage::use_byte_size_t{},
      size_bytes,
      at::DataPtr{nullptr, options.device()},
      /*allocator=*/c10::GetAllocator(c10::kMeta),
      /*resizable=*/true};

  auto keys = c10::DispatchKeySet({options.computeDispatchKey()});
  if (extra_dispatch_keys.has_value()) {
    keys = keys | *extra_dispatch_keys;
  }
  Tensor tensor = at::detail::make_tensor<TensorImpl>(
      std::move(storage), keys, options.dtype());

  TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();

  if (sym_strides.has_value()) {
    tensor_impl->set_sizes_and_strides(
        sym_sizes, sym_strides.value(), sym_storage_offset);
  } else {
    TORCH_CHECK(
        !sym_storage_offset.has_value(),
        "setting storage offset without stride not supported");
    tensor_impl->generic_set_sizes_contiguous(sym_sizes);
  }
  return tensor;
}
// Implements `_make_wrapper_subclass(cls, size, ...)`: creates a data-less
// tensor (see make_tensor_for_subclass_helper) described by the given
// metadata and wraps it as an instance of `cls`. The tensor always has Python
// dispatch enabled, so `cls` must define an enabled `__torch_dispatch__`.
static PyObject* THPVariable_make_wrapper_subclass(
    PyObject* /*unused*/,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  // NB: pin_memory doesn't actually do anything
  // cls: Python subclass type
  // size, strides, storage_offset, memory_format, dtype: self-explanatory
  //   (NB: memory_format is parsed but currently not used below)
  // layout: memory layout, e.g. for types of Nested Tensors or other sparse
  //         tensors
  // pin_memory, requires_grad: self-explanatory
  // dispatch_sizes_strides_policy: string - which sizes/strides we should
  //                                dispatch to a custom python implementation.
  // dispatch_device: whether to dispatch to a custom python implementation
  //                  for device
  // dispatch_layout: whether to dispatch to a custom python implementation
  //                  for layout
  // _extra_dispatch_keys: additional dispatch keys to add to the tensor
  // storage_size: if provided, skip storage size calculation and just use the
  //               value provided. One use case is for Nested Tensor, where the
  //               storage size cannot be calculated from the sizes/strides
  //               (because they contain a NestedInt).
  static PythonArgParser parser({
      "_make_wrapper_subclass(PyObject* cls, SymIntArrayRef size, SymIntArrayRef? strides=None, "
      "SymInt? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
      "Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
      "std::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False, bool dispatch_layout=False, "
      "DispatchKeySet _extra_dispatch_keys=None, SymInt? storage_size=None)",
  });
  ParsedArgs<15> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  PyObject* cls = r.pyobject(0);

  TORCH_CHECK_TENSOR_SUBTYPE(cls);
  auto* const subclass_type = reinterpret_cast<PyTypeObject*>(cls);

  // This is an important safety check; without it, the default behavior will be
  // to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
  // key, which will immediately segfault because the data pointer is null. By
  // forcing users to define __torch_dispatch__ we ensure this does not happen
  // TODO: This check is not complete; because the user can disable torch
  // dispatch and then go again, triggering segfault. TBH I'm thinking I want
  // to delete this function entirely
  py::object attr = PyObject_FastGetAttrString(cls, "__torch_dispatch__");
  TORCH_CHECK_TYPE(
      attr.ptr() != nullptr &&
          attr.ptr() != torch::disabled_torch_dispatch_impl(),
      subclass_type->tp_name,
      " must define __torch_dispatch__");

  const auto options = TensorOptions()
                           .dtype(r.scalartype(5))
                           .device(r.device(7))
                           .layout(r.layoutOptional(6))
                           // NB: long standing issue, requires_grad is not
                           // respected here; you have to set it post facto, see
                           // https://github.com/pytorch/pytorch/issues/26428
                           // .requires_grad(r.toBool(9))
                           .pinned_memory(r.toBool(8));

  // don't bother releasing GIL here, as we are not allocating any nontrivial
  // data
  // Each argument is parsed exactly once. The parsed temporaries live until
  // the end of this full-expression, which outlasts the helper call (the
  // helper copies sizes/strides into the TensorImpl).
  Tensor tensor = make_tensor_for_subclass_helper(
      /*sym_sizes=*/r.symintlist(1),
      /*sym_strides=*/r.symintlistOptional(2),
      /*sym_storage_offset=*/r.toSymIntOptional(3),
      options,
      /*storage_size=*/r.toSymIntOptional(14),
      r.toDispatchKeySetOptional(13));

  tensor.unsafeGetTensorImpl()->set_python_dispatch(true);

  const auto sizes_strides_policy = r.stringViewOptional(10);
  if (sizes_strides_policy.has_value()) {
    tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
        parseSizesStridesPolicyArgument(*sizes_strides_policy));
  }

  tensor.set_requires_grad(r.toBool(9));

  if (r.toBool(11)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_device(true);
  }
  if (r.toBool(12)) {
    tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
  }

  return THPVariable_WrapWithType(std::move(tensor), subclass_type);
  END_HANDLE_TH_ERRORS
}
// DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr) defines
// `static py::handle name()`. The first call evaluates `import_expr`, and the
// resulting object is cached and returned by every later call.
//   * pybind11 >= 2.13: the object lives in pybind11's GIL-safe call-once
//     storage (py::gil_safe_call_once_and_store).
//   * older pybind11: a function-local static py::handle holds a reference
//     obtained via release(), i.e. the object is intentionally never freed.
#if IS_PYBIND_2_13_PLUS
#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr)             \
  static py::handle name() {                                               \
    PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> \
        storage;                                                           \
    return storage                                                         \
        .call_once_and_store_result(                                       \
            []() -> py::object { return import_expr; })                    \
        .get_stored();                                                     \
  }
#else
#define DEFINE_CACHING_PYTHON_IMPORT_GETTER(name, import_expr)      \
  static py::handle name() {                                        \
    static py::handle storage = py::object(import_expr).release(); \
    return storage;                                                 \
  }
#endif
// Cached accessors for Python objects defined under torch.distributed.tensor.
// Each import expression is evaluated on the getter's first call only. Every
// getter goes through py::module::import("torch.distributed.tensor"), so the
// submodule is imported on demand rather than assumed to be loaded already.
DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_class_impl,
    py::module::import("torch.distributed.tensor").attr("DTensor"))

// Externally visible (non-static) accessor for the DTensor class.
py::handle get_dtensor_class() {
  return get_dtensor_class_impl();
}

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_spec_class,
    py::module::import("torch.distributed.tensor")
        .attr("_dtensor_spec")
        .attr("DTensorSpec"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_replicate_class,
    py::module::import("torch.distributed.tensor")
        .attr("placement_types")
        .attr("Replicate"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_tensor_meta_class,
    py::module::import("torch.distributed.tensor")
        .attr("_dtensor_spec")
        .attr("TensorMeta"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_op_dispatcher,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")
        .attr("_op_dispatcher"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_dispatch,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")
        .attr("_op_dispatcher")
        .attr("_dispatch_fast_path_python_tail"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_dispatcher_wrap,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")
        .attr("_op_dispatcher")
        .attr("wrap"))

// Previously reached via import("torch").attr("distributed").attr("tensor"),
// which only works if that submodule had already been imported elsewhere.
// It now imports the submodule explicitly, like every other getter here.
DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_dtensor_get_local_results_slow_path,
    py::module::import("torch.distributed.tensor")
        .attr("DTensor")
        .attr("_op_dispatcher")
        .attr("_dispatch_get_local_results_slow_path"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_output_sharding_class,
    py::module::import("torch.distributed.tensor")
        .attr("_op_schema")
        .attr("OutputSharding"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_op_strategy_class,
    py::module::import("torch.distributed.tensor")
        .attr("_op_schema")
        .attr("OpStrategy"))

DEFINE_CACHING_PYTHON_IMPORT_GETTER(
    get_tuple_strategy_class,
    py::module::import("torch.distributed.tensor")
        .attr("_op_schema")
        .attr("TupleStrategy"))
// True when `arg` is a DTensorSpec / OpStrategy / TupleStrategy instance, or a
// Python list whose every element is None or one of those types. An empty list
// therefore qualifies.
static bool arg_type_tensor_or_tensor_list_like(py::handle arg) {
  const py::handle spec_cls = get_dtensor_spec_class();
  const py::handle op_strategy_cls = get_op_strategy_class();
  const py::handle tuple_strategy_cls = get_tuple_strategy_class();

  // Checked in the same order as the classes were fetched above.
  const auto is_spec_or_strategy = [&](py::handle candidate) {
    return py::isinstance(candidate, spec_cls) ||
        py::isinstance(candidate, op_strategy_cls) ||
        py::isinstance(candidate, tuple_strategy_cls);
  };

  if (is_spec_or_strategy(arg)) {
    return true;
  }
  if (!PyList_Check(arg.ptr())) {
    return false;
  }
  for (py::handle item : py::reinterpret_borrow<py::list>(arg)) {
    if (item.is_none()) {
      continue;
    }
    if (!is_spec_or_strategy(item)) {
      return false;
    }
  }
  return true;
}
// X-macro list of attribute/method names related to DTensor that are interned
// once as Python strings (see intern_dtensor_strings, compiled only with
// USE_DISTRIBUTED). On Python 3.10 and older the list also contains
// `__name__`. NOTE(review): the interned objects are presumably used for
// attribute lookups elsewhere in this file; the users are outside this section.
#if IS_PYTHON_3_11_PLUS
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_)
#else
#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) _(__name__)
#endif

#define FOR_EACH_DTENSOR_INTERNED_STRING(_)                   \
  MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \
  _(_comparison_key)                                          \
  _(_custom_op_handlers)                                      \
  _(_local_tensor)                                            \
  _(_spec)                                                    \
  _(_unwrap_to_op_info_impl)                                  \
  _(args_schema)                                              \
  _(compute_mesh)                                             \
  _(device_mesh)                                              \
  _(dtype)                                                    \
  _(get_coordinate)                                           \
  _(kwargs_schema)                                            \
  _(ndim)                                                     \
  _(needs_pytree)                                             \
  _(needs_redistribute)                                       \
  _(op)                                                       \
  _(op_to_schema_info)                                        \
  _(op_to_schema_info_for_single_dim_strategy)                \
  _(output_sharding)                                          \
  _(output_spec)                                              \
  _(schema_info)                                              \
  _(shape)                                                    \
  _(sharding_propagator)                                      \
  _(size)                                                     \
  _(static_argnum)                                            \
  _(static_kwargkey)                                          \
  _(stride)                                                   \
  _(tensor_meta)

// One PyObject* member per name in FOR_EACH_DTENSOR_INTERNED_STRING, named
// after the string it holds.
struct DTensorInternedStrings {
#define DECLARE_INTERNED_STRING_VARIABLE(s) PyObject* s;
  FOR_EACH_DTENSOR_INTERNED_STRING(DECLARE_INTERNED_STRING_VARIABLE)
#undef DECLARE_INTERNED_STRING_VARIABLE
};

// Static storage duration: every member is zero-initialized (nullptr) until
// intern_dtensor_strings fills it in.
static DTensorInternedStrings dtensor_interned_strings;
#ifdef USE_DISTRIBUTED
// Interns every name from FOR_EACH_DTENSOR_INTERNED_STRING into the matching
// dtensor_interned_strings member. Debug builds assert that each member is
// still null, i.e. this runs at most once.
// Returns false as soon as PyUnicode_InternFromString returns NULL (CPython
// sets an exception in that case). Strings interned before the failure are
// left in place. Returns true once every name has been interned.
static bool intern_dtensor_strings() {
#define INTERN_DTENSOR_STRING(s)                                           \
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dtensor_interned_strings.s == nullptr); \
  dtensor_interned_strings.s = PyUnicode_InternFromString(#s);             \
  if (dtensor_interned_strings.s == nullptr) {                             \
    return false;                                                          \
  }

  FOR_EACH_DTENSOR_INTERNED_STRING(INTERN_DTENSOR_STRING);
#undef INTERN_DTENSOR_STRING
  return true;
}
#endif