summaryrefslogtreecommitdiff
path: root/src/compiler/nir/nir_search.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/compiler/nir/nir_search.c')
-rw-r--r--src/compiler/nir/nir_search.c139
1 files changed, 75 insertions, 64 deletions
diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c
index dec56fee74..c9fcef2832 100644
--- a/src/compiler/nir/nir_search.c
+++ b/src/compiler/nir/nir_search.c
@@ -42,25 +42,24 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
+static bool src_is_type(nir_src src, nir_alu_type type);
+
/**
* Check if a source produces a value of the given type.
*
* Used for satisfying 'a@type' constraints.
*/
static bool
-src_is_type(nir_src src, nir_alu_type type)
+value_is_type(nir_ssa_def *def, nir_alu_type type)
{
assert(type != nir_type_invalid);
- if (!src.is_ssa)
- return false;
-
/* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
if (nir_alu_type_get_base_type(type) == nir_type_bool)
type = nir_type_bool;
- if (src.ssa->parent_instr->type == nir_instr_type_alu) {
- nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
+ if (def->parent_instr->type == nir_instr_type_alu) {
+ nir_alu_instr *src_alu = nir_instr_as_alu(def->parent_instr);
nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
if (type == nir_type_bool) {
@@ -78,8 +77,8 @@ src_is_type(nir_src src, nir_alu_type type)
}
return nir_alu_type_get_base_type(output_type) == type;
- } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
- nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
+ } else if (def->parent_instr->type == nir_instr_type_intrinsic) {
+ nir_intrinsic_instr *intr = nir_instr_as_intrinsic(def->parent_instr);
if (type == nir_type_bool) {
return intr->intrinsic == nir_intrinsic_load_front_face ||
@@ -92,83 +91,61 @@ src_is_type(nir_src src, nir_alu_type type)
}
static bool
-match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
- unsigned num_components, const uint8_t *swizzle,
- struct match_state *state)
+src_is_type(nir_src src, nir_alu_type type)
{
- uint8_t new_swizzle[4];
-
- /* Searching only works on SSA values because, if it's not SSA, we can't
- * know if the value changed between one instance of that value in the
- * expression and another. Also, the replace operation will place reads of
- * that value right before the last instruction in the expression we're
- * replacing so those reads will happen after the original reads and may
- * not be valid if they're register reads.
- */
- if (!instr->src[src].src.is_ssa)
+ if (!src.is_ssa)
return false;
- /* If the source is an explicitly sized source, then we need to reset
- * both the number of components and the swizzle.
- */
- if (nir_op_infos[instr->op].input_sizes[src] != 0) {
- num_components = nir_op_infos[instr->op].input_sizes[src];
- swizzle = identity_swizzle;
- }
-
- for (unsigned i = 0; i < num_components; ++i)
- new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
-
- /* If the value has a specific bit size and it doesn't match, bail */
- if (value->bit_size &&
- nir_src_bit_size(instr->src[src].src) != value->bit_size)
- return false;
+ return value_is_type(src.ssa, type);
+}
+static bool
+match_value(const nir_search_value *value, nir_ssa_def *def,
+ unsigned num_components, const uint8_t *swizzle,
+ struct match_state *state)
+{
switch (value->type) {
case nir_search_value_expression:
- if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
+ if (def->parent_instr->type != nir_instr_type_alu)
return false;
return match_expression(nir_search_value_as_expression(value),
- nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
- num_components, new_swizzle, state);
+ nir_instr_as_alu(def->parent_instr),
+ num_components, swizzle, state);
case nir_search_value_variable: {
nir_search_variable *var = nir_search_value_as_variable(value);
assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
if (state->variables_seen & (1 << var->variable)) {
- if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
+ if (state->variables[var->variable].src.ssa != def)
return false;
- assert(!instr->src[src].abs && !instr->src[src].negate);
-
for (unsigned i = 0; i < num_components; ++i) {
- if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
+ if (state->variables[var->variable].swizzle[i] != swizzle[i])
return false;
}
return true;
} else {
- if (var->is_constant &&
- instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
+ if (var->is_constant && def->parent_instr->type != nir_instr_type_load_const)
return false;
- if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
+ if (var->cond && !var->cond(def, num_components, swizzle))
return false;
if (var->type != nir_type_invalid &&
- !src_is_type(instr->src[src].src, var->type))
+ !value_is_type(def, var->type))
return false;
state->variables_seen |= (1 << var->variable);
- state->variables[var->variable].src = instr->src[src].src;
+ state->variables[var->variable].src.ssa = def;
state->variables[var->variable].abs = false;
state->variables[var->variable].negate = false;
for (unsigned i = 0; i < 4; ++i) {
if (i < num_components)
- state->variables[var->variable].swizzle[i] = new_swizzle[i];
+ state->variables[var->variable].swizzle[i] = swizzle[i];
else
state->variables[var->variable].swizzle[i] = 0;
}
@@ -180,14 +157,11 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
case nir_search_value_constant: {
nir_search_constant *const_val = nir_search_value_as_constant(value);
- if (!instr->src[src].src.is_ssa)
- return false;
-
- if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
+ if (def->parent_instr->type != nir_instr_type_load_const)
return false;
nir_load_const_instr *load =
- nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
+ nir_instr_as_load_const(def->parent_instr);
switch (const_val->type) {
case nir_type_float:
@@ -195,10 +169,10 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
double val;
switch (load->def.bit_size) {
case 32:
- val = load->value.f32[new_swizzle[i]];
+ val = load->value.f32[swizzle[i]];
break;
case 64:
- val = load->value.f64[new_swizzle[i]];
+ val = load->value.f64[swizzle[i]];
break;
default:
unreachable("unknown bit size");
@@ -215,7 +189,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
switch (load->def.bit_size) {
case 32:
for (unsigned i = 0; i < num_components; ++i) {
- if (load->value.u32[new_swizzle[i]] !=
+ if (load->value.u32[swizzle[i]] !=
(uint32_t)const_val->data.u)
return false;
}
@@ -223,7 +197,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
case 64:
for (unsigned i = 0; i < num_components; ++i) {
- if (load->value.u64[new_swizzle[i]] != const_val->data.u)
+ if (load->value.u64[swizzle[i]] != const_val->data.u)
return false;
}
return true;
@@ -243,6 +217,43 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
}
static bool
+match_alu_src(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
+ unsigned num_components, const uint8_t *swizzle,
+ struct match_state *state)
+{
+ uint8_t new_swizzle[4];
+
+ /* Searching only works on SSA values because, if it's not SSA, we can't
+ * know if the value changed between one instance of that value in the
+ * expression and another. Also, the replace operation will place reads of
+ * that value right before the last instruction in the expression we're
+ * replacing so those reads will happen after the original reads and may
+ * not be valid if they're register reads.
+ */
+ if (!instr->src[src].src.is_ssa)
+ return false;
+
+ /* If the source is an explicitly sized source, then we need to reset
+ * both the number of components and the swizzle.
+ */
+ if (nir_op_infos[instr->op].input_sizes[src] != 0) {
+ num_components = nir_op_infos[instr->op].input_sizes[src];
+ swizzle = identity_swizzle;
+ }
+
+ for (unsigned i = 0; i < num_components; ++i)
+ new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
+
+ /* If the value has a specific bit size and it doesn't match, bail */
+ if (value->bit_size &&
+ nir_src_bit_size(instr->src[src].src) != value->bit_size)
+ return false;
+
+ return match_value(value, instr->src[src].src.ssa, num_components, swizzle,
+ state);
+}
+
+static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state)
@@ -287,8 +298,8 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
bool matched = true;
for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
- if (!match_value(expr->srcs[i], instr, i, num_components,
- swizzle, state)) {
+ if (!match_alu_src(expr->srcs[i], instr, i, num_components,
+ swizzle, state)) {
matched = false;
break;
}
@@ -306,12 +317,12 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
*/
state->variables_seen = variables_seen_stash;
- if (!match_value(expr->srcs[0], instr, 1, num_components,
- swizzle, state))
+ if (!match_alu_src(expr->srcs[0], instr, 1, num_components,
+ swizzle, state))
return false;
- return match_value(expr->srcs[1], instr, 0, num_components,
- swizzle, state);
+ return match_alu_src(expr->srcs[1], instr, 0, num_components,
+ swizzle, state);
} else {
return false;
}