Skip to content

Commit 04e1b76

Browse files
committed
Merge remote-tracking branch 'origin/feature/fused-ops' into bump_to_1b2c8f10
2 parents 365c2d6 + b3913be commit 04e1b76

File tree

5 files changed

+183
-38
lines changed

5 files changed

+183
-38
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,8 @@ struct ConsolidateTransposeOptimization
430430
auto permsTy =
431431
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
432432
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
433-
Value permsValue =
434-
rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
433+
Value permsValue = rewriter.create<tosa::ConstOp>(transposeOp.getLoc(),
434+
permsTy, permsAttr);
435435

436436
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
437437
transposeOp, transposeOp.getResult().getType(),

mlir/lib/Target/Cpp/TranslateToCpp.cpp

+88-26
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ struct CppEmitter {
189189
/// Return the existing or a new name for a Value.
190190
StringRef getOrCreateName(Value val);
191191

192+
/// Return the existing or a new name for a loop induction variable of an
193+
/// emitc::ForOp.
194+
StringRef getOrCreateName(emitc::ForOp forOp);
195+
192196
// Returns the textual representation of a subscript operation.
193197
std::string getSubscriptName(emitc::SubscriptOp op);
194198

@@ -204,23 +208,39 @@ struct CppEmitter {
204208
/// Whether to map an mlir integer to a unsigned integer in C++.
205209
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
206210

207-
/// RAII helper function to manage entering/exiting C++ scopes.
211+
/// Abstract RAII helper function to manage entering/exiting C++ scopes.
208212
struct Scope {
213+
~Scope() { emitter.labelInScopeCount.pop(); }
214+
215+
private:
216+
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
217+
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
218+
219+
protected:
209220
Scope(CppEmitter &emitter)
210221
: valueMapperScope(emitter.valueMapper),
211222
blockMapperScope(emitter.blockMapper), emitter(emitter) {
212-
emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
213223
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
214224
}
215-
~Scope() {
216-
emitter.valueInScopeCount.pop();
217-
emitter.labelInScopeCount.pop();
225+
CppEmitter &emitter;
226+
};
227+
228+
/// RAII helper function to manage entering/exiting functions, while re-using
229+
/// value names.
230+
struct FunctionScope : Scope {
231+
FunctionScope(CppEmitter &emitter) : Scope(emitter) {
232+
// Re-use value names
233+
emitter.resetValueCounter();
218234
}
235+
};
219236

220-
private:
221-
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
222-
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
223-
CppEmitter &emitter;
237+
/// RAII helper function to manage entering/exiting emitc::forOp loops and
238+
/// handle induction variable naming.
239+
struct LoopScope : Scope {
240+
LoopScope(CppEmitter &emitter) : Scope(emitter) {
241+
emitter.increaseLoopNestingLevel();
242+
}
243+
~LoopScope() { emitter.decreaseLoopNestingLevel(); }
224244
};
225245

226246
/// Returns wether the Value is assigned to a C++ variable in the scope.
@@ -268,6 +288,15 @@ struct CppEmitter {
268288
/// This emitter will only emit translation units whos id matches this value.
269289
StringRef willOnlyEmitTu() { return onlyTu; }
270290

291+
// Resets the value counter to 0
292+
void resetValueCounter();
293+
294+
// Increases the loop nesting level by 1
295+
void increaseLoopNestingLevel();
296+
297+
// Decreases the loop nesting level by 1
298+
void decreaseLoopNestingLevel();
299+
271300
private:
272301
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
273302
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -292,11 +321,19 @@ struct CppEmitter {
292321
/// Map from block to name of C++ label.
293322
BlockMapper blockMapper;
294323

295-
/// The number of values in the current scope. This is used to declare the
296-
/// names of values in a scope.
297-
std::stack<int64_t> valueInScopeCount;
324+
/// Default values representing outermost scope
325+
llvm::ScopedHashTableScope<Value, std::string> defaultValueMapperScope;
326+
llvm::ScopedHashTableScope<Block *, std::string> defaultBlockMapperScope;
327+
298328
std::stack<int64_t> labelInScopeCount;
299329

330+
/// Keeps track of the amount of nested loops the emitter currently operates
331+
/// in.
332+
uint64_t loopNestingLevel{0};
333+
334+
/// Emitter-level count of created values to enable unique identifiers.
335+
unsigned int valueCount{0};
336+
300337
/// State of the current expression being emitted.
301338
ExpressionOp emittedExpression;
302339
SmallVector<int> emittedExpressionPrecedence;
@@ -915,7 +952,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
915952
}
916953

917954
static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
918-
919955
raw_indented_ostream &os = emitter.ostream();
920956

921957
// Utility function to determine whether a value is an expression that will be
@@ -934,12 +970,12 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
934970
emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
935971
return failure();
936972
os << " ";
937-
os << emitter.getOrCreateName(forOp.getInductionVar());
973+
os << emitter.getOrCreateName(forOp);
938974
os << " = ";
939975
if (failed(emitter.emitOperand(forOp.getLowerBound())))
940976
return failure();
941977
os << "; ";
942-
os << emitter.getOrCreateName(forOp.getInductionVar());
978+
os << emitter.getOrCreateName(forOp);
943979
os << " < ";
944980
Value upperBound = forOp.getUpperBound();
945981
bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
@@ -950,13 +986,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
950986
if (upperBoundRequiresParentheses)
951987
os << ")";
952988
os << "; ";
953-
os << emitter.getOrCreateName(forOp.getInductionVar());
989+
os << emitter.getOrCreateName(forOp);
954990
os << " += ";
955991
if (failed(emitter.emitOperand(forOp.getStep())))
956992
return failure();
957993
os << ") {\n";
958994
os.indent();
959995

996+
CppEmitter::LoopScope lScope(emitter);
997+
960998
Region &forRegion = forOp.getRegion();
961999
auto regionOps = forRegion.getOps();
9621000

@@ -1043,8 +1081,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
10431081
}
10441082

10451083
static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
1046-
CppEmitter::Scope scope(emitter);
1047-
10481084
for (Operation &op : moduleOp) {
10491085
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
10501086
return failure();
@@ -1056,8 +1092,6 @@ static LogicalResult printOperation(CppEmitter &emitter, TranslationUnitOp tu) {
10561092
if (!emitter.shouldEmitTu(tu))
10571093
return success();
10581094

1059-
CppEmitter::Scope scope(emitter);
1060-
10611095
for (Operation &op : tu) {
10621096
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
10631097
return failure();
@@ -1220,7 +1254,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12201254
return functionOp.emitOpError() << "cannot emit array type as result type";
12211255
}
12221256

1223-
CppEmitter::Scope scope(emitter);
1257+
CppEmitter::FunctionScope scope(emitter);
12241258
raw_indented_ostream &os = emitter.ostream();
12251259
if (failed(emitter.emitTypes(functionOp.getLoc(),
12261260
functionOp.getFunctionType().getResults())))
@@ -1248,7 +1282,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
12481282
"with multiple blocks needs variables declared at top");
12491283
}
12501284

1251-
CppEmitter::Scope scope(emitter);
1285+
CppEmitter::FunctionScope scope(emitter);
12521286
raw_indented_ostream &os = emitter.ostream();
12531287
if (functionOp.getSpecifiers()) {
12541288
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
@@ -1282,7 +1316,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
12821316

12831317
static LogicalResult printOperation(CppEmitter &emitter,
12841318
DeclareFuncOp declareFuncOp) {
1285-
CppEmitter::Scope scope(emitter);
12861319
raw_indented_ostream &os = emitter.ostream();
12871320

12881321
auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
@@ -1314,8 +1347,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
13141347
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
13151348
StringRef onlyTu, bool constantsAsVariables)
13161349
: os(os), declareVariablesAtTop(declareVariablesAtTop),
1317-
onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables) {
1318-
valueInScopeCount.push(0);
1350+
onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables),
1351+
defaultValueMapperScope(valueMapper),
1352+
defaultBlockMapperScope(blockMapper) {
13191353
labelInScopeCount.push(0);
13201354
}
13211355

@@ -1356,7 +1390,29 @@ StringRef CppEmitter::getOrCreateName(Value val) {
13561390
assert(!hasDeferredEmission(val.getDefiningOp()) &&
13571391
"cacheDeferredOpResult should have been called on this value, "
13581392
"update the emitOperation function.");
1359-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1393+
1394+
valueMapper.insert(val, formatv("v{0}", ++valueCount));
1395+
}
1396+
return *valueMapper.begin(val);
1397+
}
1398+
1399+
/// Return the existing or a new name for a loop induction variable Value.
1400+
/// Loop induction variables follow natural naming: i, j, k,...
1401+
StringRef CppEmitter::getOrCreateName(emitc::ForOp forOp) {
1402+
Value val = forOp.getInductionVar();
1403+
1404+
if (!valueMapper.count(val)) {
1405+
1406+
int64_t identifier = 'i' + loopNestingLevel;
1407+
1408+
if (identifier >= 'i' && identifier <= 'z') {
1409+
valueMapper.insert(val,
1410+
formatv("{0}_{1}", (char)identifier, ++valueCount));
1411+
} else {
1412+
// If running out of letters, continue with zX
1413+
valueMapper.insert(
1414+
val, formatv("z{0}_{1}", identifier - 'z' - 1, ++valueCount));
1415+
}
13601416
}
13611417
return *valueMapper.begin(val);
13621418
}
@@ -1950,6 +2006,12 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
19502006
return success();
19512007
}
19522008

2009+
void CppEmitter::resetValueCounter() { valueCount = 0; }
2010+
2011+
void CppEmitter::increaseLoopNestingLevel() { loopNestingLevel++; }
2012+
2013+
void CppEmitter::decreaseLoopNestingLevel() { loopNestingLevel--; }
2014+
19532015
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
19542016
bool declareVariablesAtTop,
19552017
StringRef onlyTu,

mlir/test/Dialect/Tosa/transpose-fold.mlir

+9-9
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
// CHECK: }
77

88
func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
9-
%0 = arith.constant dense<[1, 2, 0]> : tensor<3xi32>
9+
%0 = "tosa.const"() <{value = dense<[1, 2, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
1010
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<2x3x1xi32>
11-
%2 = arith.constant dense<[2, 0, 1]> : tensor<3xi32>
11+
%2 = "tosa.const"() <{value = dense<[2, 0, 1]> :tensor<3xi32>}> : () -> tensor<3xi32>
1212
%3 = tosa.transpose %1, %2 : (tensor<2x3x1xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
1313
return %3 : tensor<1x2x3xi32>
1414
}
@@ -21,7 +21,7 @@ func.func @test_cancel_transpose_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<
2121
// CHECK: }
2222

2323
func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1x2x3xi32>) {
24-
%0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32>
24+
%0 = "tosa.const"() <{value = dense<[0, 1, 2]> : tensor<3xi32>}> : () -> tensor<3xi32>
2525
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3xi32>, tensor<3xi32>) -> tensor<1x2x3xi32>
2626
return %1 : tensor<1x2x3xi32>
2727
}
@@ -30,15 +30,15 @@ func.func @test_remove_identity_transpose(%arg0: tensor<1x2x3xi32>) -> (tensor<1
3030

3131
// CHECK-LABEL: func.func @test_do_not_cancel_different_transpose(
3232
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x4x5xi32>) -> tensor<5x4x3x2xi32> {
33-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
33+
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
3434
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
3535
// CHECK: return %[[VAL_2]] : tensor<5x4x3x2xi32>
3636
// CHECK: }
3737

3838
func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) -> (tensor<5x4x3x2xi32>) {
39-
%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
39+
%0 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
4040
%1 = tosa.transpose %arg0, %0 : (tensor<2x3x4x5xi32>, tensor<4xi32>) -> tensor<3x4x2x5xi32>
41-
%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
41+
%2 = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
4242
%3 = tosa.transpose %1, %2 : (tensor<3x4x2x5xi32>, tensor<4xi32>) -> tensor<5x4x3x2xi32>
4343
return %3 : tensor<5x4x3x2xi32>
4444
}
@@ -47,15 +47,15 @@ func.func @test_do_not_cancel_different_transpose(%arg0: tensor<2x3x4x5xi32>) ->
4747

4848
// CHECK-LABEL: func.func @test_prefer_compose_transpose(
4949
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4xi32>) -> tensor<4x3x2x1xi32> {
50-
// CHECK: %[[VAL_1:.*]] = arith.constant dense<[3, 2, 1, 0]> : tensor<4xi32>
50+
// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<[3, 2, 1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32>
5151
// CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
5252
// CHECK: return %[[VAL_2]] : tensor<4x3x2x1xi32>
5353
// CHECK: }
5454

5555
func.func @test_prefer_compose_transpose(%arg0: tensor<1x2x3x4xi32>) -> (tensor<4x3x2x1xi32>) {
56-
%0 = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32>
56+
%0 = "tosa.const"() <{value = dense<[1, 2, 0, 3]> : tensor<4xi32>}> : () -> tensor<4xi32>
5757
%1 = tosa.transpose %arg0, %0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<2x3x1x4xi32>
58-
%2 = arith.constant dense<[3, 1, 0, 2]> : tensor<4xi32>
58+
%2 = "tosa.const"() <{value = dense<[3, 1, 0, 2]> : tensor<4xi32>}> : () -> tensor<4xi32>
5959
%3 = tosa.transpose %1, %2 : (tensor<2x3x1x4xi32>, tensor<4xi32>) -> tensor<4x3x2x1xi32>
6060
return %3 : tensor<4x3x2x1xi32>
6161
}

mlir/test/Target/Cpp/emitc-constants-as-variables.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func.func @test() {
1212
return
1313
}
1414
// CPP-DEFAULT-LABEL: void test() {
15-
// CPP-DEFAULT-NEXT: for (size_t v1 = (size_t) 0; v1 < (size_t) 10; v1 += (size_t) 1) {
15+
// CPP-DEFAULT-NEXT: for (size_t [[V1:[^ ]*]] = (size_t) 0; [[V1]] < (size_t) 10; [[V1]] += (size_t) 1) {
1616
// CPP-DEFAULT-NEXT: }
1717
// CPP-DEFAULT-NEXT: return;
1818
// CPP-DEFAULT-NEXT: }

0 commit comments

Comments
 (0)