Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xls/estimators/delay_model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ cc_test(
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:source_location",
"//xls/ir:type",
"//xls/ir:value",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@googletest//:gtest",
],
Expand Down
124 changes: 88 additions & 36 deletions xls/estimators/delay_model/analyze_critical_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

#include "xls/estimators/delay_model/analyze_critical_path.h"

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <optional>
#include <string>
#include <utility>
Expand All @@ -40,50 +42,32 @@

namespace xls {

absl::StatusOr<std::vector<CriticalPathEntry>> AnalyzeCriticalPath(
namespace {

absl::StatusOr<NodeDelayEntries> AccumulateNodeDelays(
FunctionBase* f, std::optional<int64_t> clock_period_ps,
const DelayEstimator& delay_estimator,
absl::AnyInvocable<bool(Node*)> source_filter,
absl::AnyInvocable<bool(Node*)> sink_filter) {
struct NodeEntry {
Node* node;

// Delay of the node.
int64_t node_delay;

// The delay of the critical path in the graph up to and including this node
// (includes this node's delay).
int64_t critical_path_delay;

// The predecessor on the critical path through this node.
std::optional<Node*> critical_path_predecessor;

// Whether this node was delayed by a cycle boundary.
bool delayed_by_cycle_boundary;
};

// Map from each node to it's corresponding entry.
absl::flat_hash_map<Node*, NodeEntry> node_entries;
absl::AnyInvocable<bool(Node*)>& source_filter,
absl::AnyInvocable<bool(Node*)>& sink_filter) {
NodeDelayEntries entries;
entries.topo_sorted_nodes = TopoSort(f);

// The node with the greatest critical path delay.
std::optional<NodeEntry> latest_entry;

for (Node* node : TopoSort(f)) {
for (Node* node : entries.topo_sorted_nodes) {
if (!source_filter(node) &&
!absl::c_any_of(node->operands(), [&](Node* operand) {
return node_entries.contains(operand);
return entries.node_entries.contains(operand);
})) {
// This node is neither a source nor on a path from a source.
continue;
}
NodeEntry& entry = node_entries[node];
NodeDelayEntry& entry = entries.node_entries[node];
entry.node = node;

// The maximum delay from any path up to but not including `node`.
int64_t max_path_delay = 0;
for (Node* operand : node->operands()) {
auto it = node_entries.find(operand);
if (it == node_entries.end()) {
auto it = entries.node_entries.find(operand);
if (it == entries.node_entries.end()) {
// This operand is neither a source nor on a path from a source.
continue;
}
Expand Down Expand Up @@ -113,22 +97,35 @@ absl::StatusOr<std::vector<CriticalPathEntry>> AnalyzeCriticalPath(
if (!sink_filter(node)) {
continue;
}
if (!latest_entry.has_value() ||
latest_entry->critical_path_delay <= entry.critical_path_delay) {
latest_entry = entry;
if (!entries.latest.has_value() ||
entries.latest->critical_path_delay <= entry.critical_path_delay) {
entries.latest = entry;
}
}
return entries;
}

} // anonymous namespace

absl::StatusOr<std::vector<CriticalPathEntry>> AnalyzeCriticalPath(
FunctionBase* f, std::optional<int64_t> clock_period_ps,
const DelayEstimator& delay_estimator,
absl::AnyInvocable<bool(Node*)> source_filter,
absl::AnyInvocable<bool(Node*)> sink_filter) {
XLS_ASSIGN_OR_RETURN(NodeDelayEntries entries,
AccumulateNodeDelays(f, clock_period_ps, delay_estimator,
source_filter, sink_filter));

// `entries.latest` has no value for empty FunctionBases or if the source &
// sink filters removed all nodes.
if (!latest_entry.has_value()) {
if (!entries.latest.has_value()) {
return std::vector<CriticalPathEntry>();
}

// Starting with the operation with the longest path delay, walk back up its
// critical path constructing CriticalPathEntry's as we go.
std::vector<CriticalPathEntry> critical_path;
NodeEntry* entry = &(latest_entry.value());
NodeDelayEntry* entry = &(entries.latest.value());
while (true) {
critical_path.push_back(CriticalPathEntry{
.node = entry->node,
Expand All @@ -138,12 +135,67 @@ absl::StatusOr<std::vector<CriticalPathEntry>> AnalyzeCriticalPath(
if (!entry->critical_path_predecessor.has_value()) {
break;
}
entry = &node_entries.at(entry->critical_path_predecessor.value());
entry = &entries.node_entries.at(entry->critical_path_predecessor.value());
}

return std::move(critical_path);
}

// Computes the slack of every analyzed node: the amount by which the node's
// critical-path delay could grow before the overall critical path delay would
// change, assuming all other nodes stay as they are.
//
// Only nodes that have an entry in the accumulated delay table (i.e. nodes
// accepted by `source_filter` or fed by such a node) appear in the result.
// Returns an empty map when no node passed `sink_filter` (or the FunctionBase
// is empty), because there is then no critical path to measure slack against.
// Errors from the delay accumulation are propagated.
absl::StatusOr<absl::flat_hash_map<Node*, int64_t>> SlackFromCriticalPath(
    FunctionBase* f, std::optional<int64_t> clock_period_ps,
    const DelayEstimator& delay_estimator,
    absl::AnyInvocable<bool(Node*)> source_filter,
    absl::AnyInvocable<bool(Node*)> sink_filter) {
  XLS_ASSIGN_OR_RETURN(NodeDelayEntries entries,
                       AccumulateNodeDelays(f, clock_period_ps, delay_estimator,
                                            source_filter, sink_filter));

  absl::flat_hash_map<Node*, int64_t> node_slack;

  // `entries.latest` is only set for nodes accepted by `sink_filter`, while
  // `entries.node_entries` is filled regardless of the sink filter. Bail out
  // here instead of dereferencing an empty optional further down.
  if (!entries.latest.has_value()) {
    return node_slack;
  }
  const int64_t overall_critical_path_delay =
      entries.latest->critical_path_delay;

  // Walk in reverse topological order so that every user's slack is already
  // known by the time its operands are visited.
  for (auto node_iter = entries.topo_sorted_nodes.rbegin();
       node_iter != entries.topo_sorted_nodes.rend(); ++node_iter) {
    Node* node = *node_iter;
    auto entry_it = entries.node_entries.find(node);
    if (entry_it == entries.node_entries.end()) {
      // This node is neither a source nor on a path from a source.
      continue;
    }
    const int64_t node_path_delay = entry_it->second.critical_path_delay;

    // Smallest slack over all analyzed users; empty if there are none.
    std::optional<int64_t> min_slack;
    for (Node* user : node->users()) {
      if (!entries.node_entries.contains(user)) {
        continue;
      }

      // Largest critical-path delay among the user's analyzed operands
      // (including `node` itself).
      int64_t max_operand_delay = node_path_delay;
      for (Node* operand : user->operands()) {
        auto operand_it = entries.node_entries.find(operand);
        if (operand_it == entries.node_entries.end()) {
          continue;
        }
        max_operand_delay = std::max(max_operand_delay,
                                     operand_it->second.critical_path_delay);
      }

      // A node's slack w.r.t a user is the user's slack plus how much less
      // this node's delay is than the largest delay of the user's operands.
      const int64_t slack_via_user =
          node_slack.at(user) + max_operand_delay - node_path_delay;
      min_slack = min_slack.has_value() ? std::min(*min_slack, slack_via_user)
                                        : slack_via_user;
    }

    // If at the end of the def-use chain, the slack is how much less this
    // node's delay is than the critical path delay (never negative).
    node_slack[node] =
        min_slack.has_value()
            ? *min_slack
            : std::max(int64_t{0},
                       overall_critical_path_delay - node_path_delay);
  }

  return node_slack;
}

std::string CriticalPathToString(
absl::Span<const CriticalPathEntry> critical_path,
std::optional<std::function<std::string(Node*)>> extra_info) {
Expand Down
53 changes: 53 additions & 0 deletions xls/estimators/delay_model/analyze_critical_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
#ifndef XLS_ESTIMATORS_DELAY_MODEL_ANALYZE_CRITICAL_PATH_H_
#define XLS_ESTIMATORS_DELAY_MODEL_ANALYZE_CRITICAL_PATH_H_

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <optional>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
Expand All @@ -47,6 +50,33 @@ struct CriticalPathEntry {
bool delayed_by_cycle_boundary;
};

struct NodeDelayEntry {
Node* node;

// Delay of the node.
int64_t node_delay;

// The delay of the critical path in the graph up to and including this node
// (includes this node's delay).
int64_t critical_path_delay;

// The predecessor on the critical path through this node.
std::optional<Node*> critical_path_predecessor;

// Whether this node was delayed by a cycle boundary.
bool delayed_by_cycle_boundary;
};

// Result of accumulating per-node delays over a FunctionBase: the node
// ordering that was walked, the per-node timing records, and the record with
// the greatest critical path delay (if any).
struct NodeDelayEntries {
  // The nodes of the FunctionBase in the topological order used for the
  // analysis.
  std::vector<Node*> topo_sorted_nodes;

  // Map from each node to its corresponding entry.
  absl::flat_hash_map<Node*, NodeDelayEntry> node_entries;

  // The entry with the greatest critical path delay. Holds no value when no
  // entry was selected.
  std::optional<NodeDelayEntry> latest;
};

// Returns the critical path, decorated with the delay to produce the output of
// that node on the critical path.
//
Expand All @@ -65,6 +95,29 @@ absl::StatusOr<std::vector<CriticalPathEntry>> AnalyzeCriticalPath(
absl::AnyInvocable<bool(Node*)> source_filter = [](Node*) { return true; },
absl::AnyInvocable<bool(Node*)> sink_filter = [](Node*) { return true; });

// Returns the additional delay a node could have before it would alter the
// critical path. Any one node's slack assumes all other nodes remain unchanged.
//
// The returned map holds one slack value per analyzed node. A node is analyzed
// if `source_filter` accepts it or if at least one of its operands is
// analyzed; all other nodes have no entry in the map. `sink_filter` selects
// the nodes that may end the critical path which slack is measured against.
// `clock_period_ps` and `delay_estimator` are forwarded to the same delay
// accumulation that AnalyzeCriticalPath uses.
//
// As an example, consider nodes with the following delays:
//   a: 2
//   b: 3
//   c: 1
//   d: 5
//   e: 2
//
//   a -> b -> d
//   |--> c ---^
//   |--> e
//
// The critical path goes through a, b, d with a critical path delay of 10. The
// slack on c is 2; any more than that and it would take b's place on the
// critical path. The slack on e is 6; any more than that would result in a
// critical path through a, e instead of a, b, d. The slack on a, b, and d is 0.
absl::StatusOr<absl::flat_hash_map<Node*, int64_t>> SlackFromCriticalPath(
FunctionBase* f, std::optional<int64_t> clock_period_ps,
const DelayEstimator& delay_estimator,
absl::AnyInvocable<bool(Node*)> source_filter = [](Node*) { return true; },
absl::AnyInvocable<bool(Node*)> sink_filter = [](Node*) { return true; });

// Returns a string representation of the critical-path. Includes delay
// information for each node as well as cumulative delay.
//
Expand Down
82 changes: 82 additions & 0 deletions xls/estimators/delay_model/analyze_critical_path_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

#include "xls/estimators/delay_model/analyze_critical_path.h"

#include <cstdint>
#include <optional>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "xls/common/status/matchers.h"
#include "xls/estimators/delay_model/delay_estimator.h"
Expand All @@ -29,6 +31,7 @@
#include "xls/ir/ir_matcher.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/package.h"
#include "xls/ir/source_location.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"

Expand Down Expand Up @@ -169,5 +172,84 @@ TEST_F(AnalyzeCriticalPathTest, EmptyProc) {
EXPECT_TRUE(cp.empty());
}

// Builds the graph from the example in the SlackFromCriticalPath header
// comment, expanding each example node into a chain of IR nodes, and checks
// the slack reported for every node.
TEST_F(AnalyzeCriticalPathTest, SlackFromCriticalPathFromExampleComment) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());
  Type* bits32 = package->GetBitsType(32);
  auto x = builder.Param("x", bits32);
  auto y = builder.Param("y", bits32);
  auto z = builder.Param("z", bits32);

  // Chain 'a': contributes 2 to delay.
  auto a1 = builder.Negate(x, SourceInfo(), "a1");
  auto a2 = builder.Reverse(a1, SourceInfo(), "a2");

  // Chain 'b': contributes 3 to delay.
  auto b1 = builder.Add(a2, y, SourceInfo(), "b1");
  auto b2 = builder.Negate(b1, SourceInfo(), "b2");
  auto b3 = builder.Reverse(b2, SourceInfo(), "b3");

  // Chain 'c': contributes 1 to delay.
  BValue c1 = builder.Add(a2, z, SourceInfo(), "c1");

  // Chain 'd': contributes 5 to delay.
  auto d1 = builder.And(b3, c1, SourceInfo(), "d1");
  auto d2 = builder.Negate(d1, SourceInfo(), "d2");
  auto d3 = builder.Reverse(d2, SourceInfo(), "d3");
  auto d4 = builder.Add(d3, z, SourceInfo(), "d4");
  auto d5 = builder.Negate(d4, SourceInfo(), "d5");

  // Chain 'e': contributes 2 to delay.
  auto e1 = builder.And(a2, y, SourceInfo(), "e1");
  auto e2 = builder.And(e1, z, SourceInfo(), "e2");

  auto result_tuple = builder.Tuple({d5, e2});
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  XLS_ASSERT_OK_AND_ASSIGN(
      (absl::flat_hash_map<Node*, int64_t> slack_by_node),
      SlackFromCriticalPath(f, /*clock_period_ps=*/std::nullopt,
                            *delay_estimator_));

  // The returned tuple ends the critical path.
  EXPECT_EQ(slack_by_node.at(result_tuple.node()), 0);

  // Off-critical-path chains keep the slack given in the example.
  EXPECT_EQ(slack_by_node.at(c1.node()), 2);
  EXPECT_EQ(slack_by_node.at(e1.node()), 6);
  EXPECT_EQ(slack_by_node.at(e2.node()), 6);

  // Everything on the a -> b -> d path has no slack at all.
  for (Node* critical_node :
       {a1.node(), a2.node(), b1.node(), b2.node(), b3.node(), d1.node(),
        d2.node(), d3.node(), d4.node(), d5.node()}) {
    EXPECT_EQ(slack_by_node.at(critical_node), 0);
  }
}

// Restricts the analysis to a subset of the graph through the source filter
// and checks that slack is computed only for that subset, ignoring the longer
// path outside of it.
TEST_F(AnalyzeCriticalPathTest, SlackFromCriticalPathWithPartialView) {
  auto package = CreatePackage();
  FunctionBuilder builder(TestName(), package.get());
  Type* bits32 = package->GetBitsType(32);
  auto x = builder.Param("x", bits32);
  auto y = builder.Param("y", bits32);

  // The global critical path; the source filter below excludes it.
  auto a1 = builder.Add(x, y, SourceInfo(), "a1");
  auto a2 = builder.Negate(a1, SourceInfo(), "a2");
  auto a3 = builder.Reverse(a2, SourceInfo(), "a3");
  auto a4 = builder.And(a3, y, SourceInfo(), "a4");
  auto a5 = builder.Or(a4, x, SourceInfo(), "a5");

  // Nodes on the shorter paths that the analysis is restricted to.
  auto b = builder.Subtract(x, y, SourceInfo(), "b");
  auto c1 = builder.SDiv(x, y, SourceInfo(), "c1");
  auto c2 = builder.Negate(c1, SourceInfo(), "c");
  auto d = builder.And({b, c2, a5}, SourceInfo(), "d");
  builder.Tuple({a5, d});
  XLS_ASSERT_OK_AND_ASSIGN(Function * f, builder.Build());

  XLS_ASSERT_OK_AND_ASSIGN(
      (absl::flat_hash_map<Node*, int64_t> slack_by_node),
      SlackFromCriticalPath(f, /*clock_period_ps=*/std::nullopt,
                            *delay_estimator_,
                            [b, c1, c2, d](Node* candidate) {
                              return candidate == b.node() ||
                                     candidate == c1.node() ||
                                     candidate == c2.node() ||
                                     candidate == d.node();
                            }));

  // Within the restricted view, c1 -> c2 -> d is the longest path.
  EXPECT_EQ(slack_by_node.at(b.node()), 1);
  EXPECT_EQ(slack_by_node.at(c1.node()), 0);
  EXPECT_EQ(slack_by_node.at(c2.node()), 0);
  EXPECT_EQ(slack_by_node.at(d.node()), 0);

  // Nodes outside the view must not be reported.
  for (Node* excluded_node :
       {a1.node(), a2.node(), a3.node(), a4.node(), a5.node()}) {
    EXPECT_FALSE(slack_by_node.contains(excluded_node));
  }
}

} // namespace
} // namespace xls