summaryrefslogtreecommitdiff
path: root/source/val
diff options
context:
space:
mode:
authorAndrey Tuganov <andreyt@google.com>2017-09-06 14:30:27 -0400
committerDavid Neto <dneto@google.com>2017-09-08 11:08:41 -0400
commitc6dfc11880653f8291f92a20e8a464dd051ca50d (patch)
treeb40175db493b91d4700275eda6415098771f51a0 /source/val
parent44421022472e8f4de4cb6dc7fde99bf1a6a0fee1 (diff)
Add new checks to validate arithmetics pass
New operations: - OpDot - OpVectorTimesScalar - OpMatrixTimesScalar - OpVectorTimesMatrix - OpMatrixTimesVector - OpMatrixTimesMatrix - OpOuterProduct
Diffstat (limited to 'source/val')
-rw-r--r--source/val/validation_state.cpp39
-rw-r--r--source/val/validation_state.h25
2 files changed, 56 insertions, 8 deletions
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 06f9d989..c7282acf 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
return false;
}
+bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
+ const Instruction* inst = FindDef(id);
+ assert(inst);
+
+ if (inst->opcode() == SpvOpTypeMatrix) {
+ return IsFloatScalarType(GetComponentType(id));
+ }
+
+ return false;
+}
+
+bool ValidationState_t::GetMatrixTypeInfo(
+ uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
+ uint32_t* column_type, uint32_t* component_type) const {
+ if (!id)
+ return false;
+
+ const Instruction* mat_inst = FindDef(id);
+ assert(mat_inst);
+ if (mat_inst->opcode() != SpvOpTypeMatrix)
+ return false;
+
+ const uint32_t vec_type = mat_inst->word(2);
+ const Instruction* vec_inst = FindDef(vec_type);
+ assert(vec_inst);
+
+ if (vec_inst->opcode() != SpvOpTypeVector) {
+ assert(0);
+ return false;
+ }
+
+ *num_cols = mat_inst->word(3);
+ *num_rows = vec_inst->word(3);
+ *column_type = mat_inst->word(2);
+ *component_type = vec_inst->word(2);
+
+ return true;
+}
+
} /// namespace libspirv
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index b3340656..49eda956 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -333,27 +333,36 @@ class ValidationState_t {
// Returns type_id of the scalar component of |id|.
// |id| can be either
- // - vector type
- // - matrix type
- // - object of either vector or matrix type
+ // - scalar, vector or matrix type
+ // - object of either scalar, vector or matrix type
uint32_t GetComponentType(uint32_t id) const;
- // Returns dimension of scalar, vector or matrix type or object. Will invoke
- // assertion and return 0 if |id| is none of the above.
- // In case of matrix returns number of columns.
+ // Returns
+ // - 1 for scalar types or objects
+ // - vector size for vector types or objects
+ // - num columns for matrix types or objects
+ // Should not be called with any other arguments (will return zero and invoke
+ // assertion).
uint32_t GetDimension(uint32_t id) const;
// Returns bit width of scalar or component.
// |id| can be
- // - scalar type or object
- // - vector or matrix type or object
+ // - scalar, vector or matrix type
+ // - object of either scalar, vector or matrix type
// Will invoke assertion and return 0 if |id| is none of the above.
uint32_t GetBitWidth(uint32_t id) const;
+ // Provides detailed information on matrix type.
+ // Returns false iff |id| is not matrix type.
+ bool GetMatrixTypeInfo(
+ uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
+ uint32_t* column_type, uint32_t* component_type) const;
+
// Returns true iff |id| is a type corresponding to the name of the function.
// Only works for types not for objects.
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
+ bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
bool IsIntVectorType(uint32_t id) const;
bool IsUnsignedIntScalarType(uint32_t id) const;