Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/onnxruntime/core/framework/data_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,10 @@ class PrimitiveDataTypeBase : public DataTypeImpl {
return nullptr;
}

void SetDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
*const_cast<int32_t*>(&data_type_) = data_type;
}

int32_t GetDataType() const {
return data_type_;
}
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ class Tensor final {
*/
MLDataType DataType() const { return dtype_; }

/**
Sets the data type to an enum constant
*/
void SetElementType(ONNX_NAMESPACE::TensorProto_DataType data_type) {
const_cast<PrimitiveDataTypeBase*>(dtype_)->SetDataType(data_type);
}

/**
Returns the data type enum constant
@remarks Use utils::ToTensorProtoElementType<T> for comparison.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,7 @@ struct ProviderHost {
virtual ptrdiff_t Tensor__ByteOffset(const Tensor* p) = 0;
virtual size_t Tensor__SizeInBytes(const Tensor* p) = 0;
virtual const OrtMemoryInfo& Tensor__Location(const Tensor* p) = 0;
virtual void Tensor__SetElementType(Tensor* p, ONNX_NAMESPACE::TensorProto_DataType data_type) = 0;
virtual int32_t Tensor__GetElementType(const Tensor* p) = 0;
virtual MLDataType Tensor__DataType(const Tensor* p) = 0;
#ifdef ENABLE_STRIDED_TENSORS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,7 @@ struct Tensor final {
const OrtMemoryInfo& Location() const { return g_host->Tensor__Location(this); }

int32_t GetElementType() const { return g_host->Tensor__GetElementType(this); }
void SetElementType(ONNX_NAMESPACE::TensorProto_DataType data_type) { g_host->Tensor__SetElementType(this, data_type); }
MLDataType DataType() const { return g_host->Tensor__DataType(this); }
bool IsDataTypeString() const { return g_host->Tensor__IsDataTypeString(this); }

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,7 @@ struct ProviderHostImpl : ProviderHost {
ptrdiff_t Tensor__ByteOffset(const Tensor* p) override { return p->ByteOffset(); }
size_t Tensor__SizeInBytes(const Tensor* p) override { return p->SizeInBytes(); }
const OrtMemoryInfo& Tensor__Location(const Tensor* p) override { return p->Location(); }
void Tensor__SetElementType(Tensor* p, ONNX_NAMESPACE::TensorProto_DataType data_type) override { p->SetElementType(data_type); }
int32_t Tensor__GetElementType(const Tensor* p) override { return p->GetElementType(); }
MLDataType Tensor__DataType(const Tensor* p) override { return p->DataType(); }
#ifdef ENABLE_STRIDED_TENSORS
Expand Down