diff options
author | Andrey Tuganov <andreyt@google.com> | 2017-09-06 14:30:27 -0400 |
---|---|---|
committer | David Neto <dneto@google.com> | 2017-09-08 11:08:41 -0400 |
commit | c6dfc11880653f8291f92a20e8a464dd051ca50d (patch) | |
tree | b40175db493b91d4700275eda6415098771f51a0 /source/val | |
parent | 44421022472e8f4de4cb6dc7fde99bf1a6a0fee1 (diff) |
Add new checks to validate arithmetics pass
New operations:
- OpDot
- OpVectorTimesScalar
- OpMatrixTimesScalar
- OpVectorTimesMatrix
- OpMatrixTimesVector
- OpMatrixTimesMatrix
- OpOuterProduct
Diffstat (limited to 'source/val')
-rw-r--r-- | source/val/validation_state.cpp | 39 | ||||
-rw-r--r-- | source/val/validation_state.h | 25 |
2 files changed, 56 insertions, 8 deletions
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 06f9d989..c7282acf 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const { return false; } +bool ValidationState_t::IsFloatMatrixType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeMatrix) { + return IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::GetMatrixTypeInfo( + uint32_t id, uint32_t* num_rows, uint32_t* num_cols, + uint32_t* column_type, uint32_t* component_type) const { + if (!id) + return false; + + const Instruction* mat_inst = FindDef(id); + assert(mat_inst); + if (mat_inst->opcode() != SpvOpTypeMatrix) + return false; + + const uint32_t vec_type = mat_inst->word(2); + const Instruction* vec_inst = FindDef(vec_type); + assert(vec_inst); + + if (vec_inst->opcode() != SpvOpTypeVector) { + assert(0); + return false; + } + + *num_cols = mat_inst->word(3); + *num_rows = vec_inst->word(3); + *column_type = mat_inst->word(2); + *component_type = vec_inst->word(2); + + return true; +} + } /// namespace libspirv diff --git a/source/val/validation_state.h b/source/val/validation_state.h index b3340656..49eda956 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -333,27 +333,36 @@ class ValidationState_t { // Returns type_id of the scalar component of |id|. // |id| can be either - // - vector type - // - matrix type - // - object of either vector or matrix type + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type uint32_t GetComponentType(uint32_t id) const; - // Returns dimension of scalar, vector or matrix type or object. Will invoke - // assertion and return 0 if |id| is none of the above. - // In case of matrix returns number of columns. + // Returns + // - 1 for scalar types or objects + // - vector size for vector types or objects + // - num columns for matrix types or objects + // Should not be called with any other arguments (will return zero and invoke + // assertion). uint32_t GetDimension(uint32_t id) const; // Returns bit width of scalar or component. // |id| can be - // - scalar type or object - // - vector or matrix type or object + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type // Will invoke assertion and return 0 if |id| is none of the above. uint32_t GetBitWidth(uint32_t id) const; + // Provides detailed information on matrix type. + // Returns false iff |id| is not matrix type. + bool GetMatrixTypeInfo( + uint32_t id, uint32_t* num_rows, uint32_t* num_cols, + uint32_t* column_type, uint32_t* component_type) const; + // Returns true iff |id| is a type corresponding to the name of the function. // Only works for types not for objects. bool IsFloatScalarType(uint32_t id) const; bool IsFloatVectorType(uint32_t id) const; + bool IsFloatMatrixType(uint32_t id) const; bool IsIntScalarType(uint32_t id) const; bool IsIntVectorType(uint32_t id) const; bool IsUnsignedIntScalarType(uint32_t id) const; |