summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorConnor Abbott <cwabbott0@gmail.com>2017-10-06 19:21:39 -0400
committerConnor Abbott <cwabbott0@gmail.com>2017-10-06 19:21:39 -0400
commitf3f77c6d646c4776df535433b187832417d42a29 (patch)
treef902582aef06489db61fce8fae6d78cdcb221b18
parent83a4b45e45cdca2cf396bd7d5f6e7b5ffe42ba4a (diff)
nir/algebraic: refactor to match on values instead of sourcesnir-equality-saturation
Conceptually, given a match_value, we should be able to know whether it matches an ssa_def without caring about which instruction uses the ssa_def. But we were passing around an instruction + source combo everywhere, even if all we care about is the actual value. This becomes more important for equality saturation, where we want to try and match many different values in the same equivalence class without changing the parent instruction which uses that value. Unfortunately, this got a little complicated to untangle, since a bunch of nir_search helpers actually inspect the parent instruction, even though they don't need to.
-rw-r--r--src/compiler/nir/nir.c20
-rw-r--r--src/compiler/nir/nir.h1
-rw-r--r--src/compiler/nir/nir_opt_algebraic.py4
-rw-r--r--src/compiler/nir/nir_search.c139
-rw-r--r--src/compiler/nir/nir_search.h3
-rw-r--r--src/compiler/nir/nir_search_helpers.h71
6 files changed, 126 insertions, 112 deletions
diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index afd4d1a723..58f7649e7f 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -1375,19 +1375,25 @@ nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state)
}
nir_const_value *
-nir_src_as_const_value(nir_src src)
+nir_ssa_def_as_const_value(nir_ssa_def *def)
{
- if (!src.is_ssa)
- return NULL;
-
- if (src.ssa->parent_instr->type != nir_instr_type_load_const)
- return NULL;
+ if (def->parent_instr->type != nir_instr_type_load_const)
+ return false;
- nir_load_const_instr *load = nir_instr_as_load_const(src.ssa->parent_instr);
+ nir_load_const_instr *load = nir_instr_as_load_const(def->parent_instr);
return &load->value;
}
+nir_const_value *
+nir_src_as_const_value(nir_src src)
+{
+ if (!src.is_ssa)
+ return NULL;
+
+ return nir_ssa_def_as_const_value(src.ssa);
+}
+
/**
* Returns true if the source is known to be dynamically uniform. Otherwise it
* returns false which means it may or may not be dynamically uniform but it
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index bb5aba605a..2546ad51e0 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2214,6 +2214,7 @@ bool nir_foreach_dest(nir_instr *instr, nir_foreach_dest_cb cb, void *state);
bool nir_foreach_src(nir_instr *instr, nir_foreach_src_cb cb, void *state);
nir_const_value *nir_src_as_const_value(nir_src src);
+nir_const_value *nir_ssa_def_as_const_value(nir_ssa_def *def);
bool nir_src_is_dynamically_uniform(nir_src src);
bool nir_srcs_equal(nir_src src1, nir_src src2);
void nir_instr_rewrite_src(nir_instr *instr, nir_src *src, nir_src new_src);
diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index ad75228a50..df94c8f156 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -69,10 +69,10 @@ optimizations = [
(('idiv', a, 1), a),
(('umod', a, 1), 0),
(('imod', a, 1), 0),
- (('udiv', a, '#b@32(is_pos_power_of_two)'), ('ushr', a, ('find_lsb', b))),
+ (('udiv', a, '#b@32(is_power_of_two)'), ('ushr', a, ('find_lsb', b))),
(('idiv', a, '#b@32(is_pos_power_of_two)'), ('imul', ('isign', a), ('ushr', ('iabs', a), ('find_lsb', b))), 'options->lower_idiv'),
(('idiv', a, '#b@32(is_neg_power_of_two)'), ('ineg', ('imul', ('isign', a), ('ushr', ('iabs', a), ('find_lsb', ('iabs', b))))), 'options->lower_idiv'),
- (('umod', a, '#b(is_pos_power_of_two)'), ('iand', a, ('isub', b, 1))),
+ (('umod', a, '#b(is_power_of_two)'), ('iand', a, ('isub', b, 1))),
(('fneg', ('fneg', a)), a),
(('ineg', ('ineg', a)), a),
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index dec56fee74..c9fcef2832 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -42,25 +42,24 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
+static bool src_is_type(nir_src src, nir_alu_type type);
+
/**
* Check if a source produces a value of the given type.
*
* Used for satisfying 'a@type' constraints.
*/
static bool
-src_is_type(nir_src src, nir_alu_type type)
+value_is_type(nir_ssa_def *def, nir_alu_type type)
{
assert(type != nir_type_invalid);
- if (!src.is_ssa)
- return false;
-
/* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
if (nir_alu_type_get_base_type(type) == nir_type_bool)
type = nir_type_bool;
- if (src.ssa->parent_instr->type == nir_instr_type_alu) {
- nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
+ if (def->parent_instr->type == nir_instr_type_alu) {
+ nir_alu_instr *src_alu = nir_instr_as_alu(def->parent_instr);
nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
if (type == nir_type_bool) {
@@ -78,8 +77,8 @@ src_is_type(nir_src src, nir_alu_type type)
}
return nir_alu_type_get_base_type(output_type) == type;
- } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
- nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
+ } else if (def->parent_instr->type == nir_instr_type_intrinsic) {
+ nir_intrinsic_instr *intr = nir_instr_as_intrinsic(def->parent_instr);
if (type == nir_type_bool) {
return intr->intrinsic == nir_intrinsic_load_front_face ||
@@ -92,83 +91,61 @@ src_is_type(nir_src src, nir_alu_type type)
}
static bool
-match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
- unsigned num_components, const uint8_t *swizzle,
- struct match_state *state)
+src_is_type(nir_src src, nir_alu_type type)
{
- uint8_t new_swizzle[4];
-
- /* Searching only works on SSA values because, if it's not SSA, we can't
- * know if the value changed between one instance of that value in the
- * expression and another. Also, the replace operation will place reads of
- * that value right before the last instruction in the expression we're
- * replacing so those reads will happen after the original reads and may
- * not be valid if they're register reads.
- */
- if (!instr->src[src].src.is_ssa)
+ if (!src.is_ssa)
return false;
- /* If the source is an explicitly sized source, then we need to reset
- * both the number of components and the swizzle.
- */
- if (nir_op_infos[instr->op].input_sizes[src] != 0) {
- num_components = nir_op_infos[instr->op].input_sizes[src];
- swizzle = identity_swizzle;
- }
-
- for (unsigned i = 0; i < num_components; ++i)
- new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
-
- /* If the value has a specific bit size and it doesn't match, bail */
- if (value->bit_size &&
- nir_src_bit_size(instr->src[src].src) != value->bit_size)
- return false;
+ return value_is_type(src.ssa, type);
+}
+static bool
+match_value(const nir_search_value *value, nir_ssa_def *def,
+ unsigned num_components, const uint8_t *swizzle,
+ struct match_state *state)
+{
switch (value->type) {
case nir_search_value_expression:
- if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
+ if (def->parent_instr->type != nir_instr_type_alu)
return false;
return match_expression(nir_search_value_as_expression(value),
- nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
- num_components, new_swizzle, state);
+ nir_instr_as_alu(def->parent_instr),
+ num_components, swizzle, state);
case nir_search_value_variable: {
nir_search_variable *var = nir_search_value_as_variable(value);
assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
if (state->variables_seen & (1 << var->variable)) {
- if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
+ if (state->variables[var->variable].src.ssa != def)
return false;
- assert(!instr->src[src].abs && !instr->src[src].negate);
-
for (unsigned i = 0; i < num_components; ++i) {
- if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
+ if (state->variables[var->variable].swizzle[i] != swizzle[i])
return false;
}
return true;
} else {
- if (var->is_constant &&
- instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
+ if (var->is_constant && def->parent_instr->type != nir_instr_type_load_const)
return false;
- if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
+ if (var->cond && !var->cond(def, num_components, swizzle))
return false;
if (var->type != nir_type_invalid &&
- !src_is_type(instr->src[src].src, var->type))
+ !value_is_type(def, var->type))
return false;
state->variables_seen |= (1 << var->variable);
- state->variables[var->variable].src = instr->src[src].src;
+ state->variables[var->variable].src.ssa = def;
state->variables[var->variable].abs = false;
state->variables[var->variable].negate = false;
for (unsigned i = 0; i < 4; ++i) {
if (i < num_components)
- state->variables[var->variable].swizzle[i] = new_swizzle[i];
+ state->variables[var->variable].swizzle[i] = swizzle[i];
else
state->variables[var->variable].swizzle[i] = 0;
}
@@ -180,14 +157,11 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
case nir_search_value_constant: {
nir_search_constant *const_val = nir_search_value_as_constant(value);
- if (!instr->src[src].src.is_ssa)
- return false;
-
- if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
+ if (def->parent_instr->type != nir_instr_type_load_const)
return false;
nir_load_const_instr *load =
- nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
+ nir_instr_as_load_const(def->parent_instr);
switch (const_val->type) {
case nir_type_float:
@@ -195,10 +169,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
double val;
switch (load->def.bit_size) {
case 32:
- val = load->value.f32[new_swizzle[i]];
+ val = load->value.f32[swizzle[i]];
break;
case 64:
- val = load->value.f64[new_swizzle[i]];
+ val = load->value.f64[swizzle[i]];
break;
default:
unreachable("unknown bit size");
@@ -215,7 +189,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
switch (load->def.bit_size) {
case 32:
for (unsigned i = 0; i < num_components; ++i) {
- if (load->value.u32[new_swizzle[i]] !=
+ if (load->value.u32[swizzle[i]] !=
(uint32_t)const_val->data.u)
return false;
}
@@ -223,7 +197,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
case 64:
for (unsigned i = 0; i < num_components; ++i) {
- if (load->value.u64[new_swizzle[i]] != const_val->data.u)
+ if (load->value.u64[swizzle[i]] != const_val->data.u)
return false;
}
return true;
@@ -243,6 +217,43 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
}
static bool
+match_alu_src(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
+ unsigned num_components, const uint8_t *swizzle,
+ struct match_state *state)
+{
+ uint8_t new_swizzle[4];
+
+ /* Searching only works on SSA values because, if it's not SSA, we can't
+ * know if the value changed between one instance of that value in the
+ * expression and another. Also, the replace operation will place reads of
+ * that value right before the last instruction in the expression we're
+ * replacing so those reads will happen after the original reads and may
+ * not be valid if they're register reads.
+ */
+ if (!instr->src[src].src.is_ssa)
+ return false;
+
+ /* If the source is an explicitly sized source, then we need to reset
+ * both the number of components and the swizzle.
+ */
+ if (nir_op_infos[instr->op].input_sizes[src] != 0) {
+ num_components = nir_op_infos[instr->op].input_sizes[src];
+ swizzle = identity_swizzle;
+ }
+
+ for (unsigned i = 0; i < num_components; ++i)
+ new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
+
+ /* If the value has a specific bit size and it doesn't match, bail */
+ if (value->bit_size &&
+ nir_src_bit_size(instr->src[src].src) != value->bit_size)
+ return false;
+
+ return match_value(value, instr->src[src].src.ssa, num_components, swizzle,
+ state);
+}
+
+static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state)
@@ -287,8 +298,8 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
bool matched = true;
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
- if (!match_value(expr->srcs[i], instr, i, num_components,
- swizzle, state)) {
+ if (!match_alu_src(expr->srcs[i], instr, i, num_components,
+ swizzle, state)) {
matched = false;
break;
}
@@ -306,12 +317,12 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
*/
state->variables_seen = variables_seen_stash;
- if (!match_value(expr->srcs[0], instr, 1, num_components,
- swizzle, state))
+ if (!match_alu_src(expr->srcs[0], instr, 1, num_components,
+ swizzle, state))
return false;
- return match_value(expr->srcs[1], instr, 0, num_components,
- swizzle, state);
+ return match_alu_src(expr->srcs[1], instr, 0, num_components,
+ swizzle, state);
} else {
return false;
}
diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h
index 94a71635de..35a39ae9e2 100644
--- a/src/compiler/nir/nir_search.h
+++ b/src/compiler/nir/nir_search.h
@@ -75,8 +75,7 @@ typedef struct {
* used for 'is_constant' variables to require, for example, power-of-two
* in order for the search to match.
*/
- bool (*cond)(nir_alu_instr *instr, unsigned src,
- unsigned num_components, const uint8_t *swizzle);
+ bool (*cond)(nir_ssa_def *def, unsigned num_components, const uint8_t *swizzle);
} nir_search_variable;
typedef struct {
diff --git a/src/compiler/nir/nir_search_helpers.h b/src/compiler/nir/nir_search_helpers.h
index 200f2471f8..cf0bae4058 100644
--- a/src/compiler/nir/nir_search_helpers.h
+++ b/src/compiler/nir/nir_search_helpers.h
@@ -36,89 +36,86 @@ __is_power_of_two(unsigned int x)
}
static inline bool
-is_pos_power_of_two(nir_alu_instr *instr, unsigned src, unsigned num_components,
+is_pos_power_of_two(nir_ssa_def *def, unsigned num_components,
const uint8_t *swizzle)
{
- nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
+ nir_const_value *val = nir_ssa_def_as_const_value(def);
/* only constant srcs: */
if (!val)
return false;
for (unsigned i = 0; i < num_components; i++) {
- switch (nir_op_infos[instr->op].input_types[src]) {
- case nir_type_int:
- if (val->i32[swizzle[i]] < 0)
- return false;
- if (!__is_power_of_two(val->i32[swizzle[i]]))
- return false;
- break;
- case nir_type_uint:
- if (!__is_power_of_two(val->u32[swizzle[i]]))
- return false;
- break;
- default:
+ if (val->i32[swizzle[i]] < 0)
+ return false;
+ if (!__is_power_of_two(val->i32[swizzle[i]]))
return false;
- }
}
return true;
}
static inline bool
-is_neg_power_of_two(nir_alu_instr *instr, unsigned src, unsigned num_components,
+is_power_of_two(nir_ssa_def *def, unsigned num_components,
+ const uint8_t *swizzle)
+{
+ nir_const_value *val = nir_ssa_def_as_const_value(def);
+
+ /* only constant srcs: */
+ if (!val)
+ return false;
+
+ for (unsigned i = 0; i < num_components; i++) {
+ if (!__is_power_of_two(val->u32[swizzle[i]]))
+ return false;
+ }
+
+ return true;
+}
+
+
+static inline bool
+is_neg_power_of_two(nir_ssa_def *def, unsigned num_components,
const uint8_t *swizzle)
{
- nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
+ nir_const_value *val = nir_ssa_def_as_const_value(def);
/* only constant srcs: */
if (!val)
return false;
for (unsigned i = 0; i < num_components; i++) {
- switch (nir_op_infos[instr->op].input_types[src]) {
- case nir_type_int:
- if (val->i32[swizzle[i]] > 0)
- return false;
- if (!__is_power_of_two(abs(val->i32[swizzle[i]])))
- return false;
- break;
- default:
+ if (val->i32[swizzle[i]] > 0)
+ return false;
+ if (!__is_power_of_two(abs(val->i32[swizzle[i]])))
return false;
- }
}
return true;
}
static inline bool
-is_zero_to_one(nir_alu_instr *instr, unsigned src, unsigned num_components,
+is_zero_to_one(nir_ssa_def *def, unsigned num_components,
const uint8_t *swizzle)
{
- nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
+ nir_const_value *val = nir_ssa_def_as_const_value(def);
if (!val)
return false;
for (unsigned i = 0; i < num_components; i++) {
- switch (nir_op_infos[instr->op].input_types[src]) {
- case nir_type_float:
- if (val->f32[swizzle[i]] < 0.0f || val->f32[swizzle[i]] > 1.0f)
- return false;
- break;
- default:
+ if (val->f32[swizzle[i]] < 0.0f || val->f32[swizzle[i]] > 1.0f)
return false;
- }
}
return true;
}
static inline bool
-is_not_const(nir_alu_instr *instr, unsigned src, unsigned num_components,
+is_not_const(nir_ssa_def *def, unsigned num_components,
const uint8_t *swizzle)
{
- nir_const_value *val = nir_src_as_const_value(instr->src[src].src);
+ nir_const_value *val = nir_ssa_def_as_const_value(def);
if (val)
return false;