diff options
-rw-r--r-- | source/comp/markv_codec.cpp | 4 | ||||
-rw-r--r-- | source/val/validation_state.cpp | 39 | ||||
-rw-r--r-- | source/val/validation_state.h | 25 | ||||
-rw-r--r-- | source/validate_arithmetics.cpp | 281 | ||||
-rw-r--r-- | test/comp/markv_codec_test.cpp | 2 | ||||
-rw-r--r-- | test/val/val_arithmetics_test.cpp | 564 |
6 files changed, 904 insertions, 11 deletions
diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp index c99fdd99..b3f0fffb 100644 --- a/source/comp/markv_codec.cpp +++ b/source/comp/markv_codec.cpp @@ -1290,8 +1290,10 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() { } case SpvOpVectorTimesScalar: { - if (operand_index_ == 0) + if (operand_index_ == 0) { + // TODO(atgoo@github.com) Could be narrowed to vector of floats. return GetMtfIdGeneratedByOpcode(SpvOpTypeVector); + } assert(inst_.type_id); if (operand_index_ == 2) diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp index 06f9d989..c7282acf 100644 --- a/source/val/validation_state.cpp +++ b/source/val/validation_state.cpp @@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const { return false; } +bool ValidationState_t::IsFloatMatrixType(uint32_t id) const { + const Instruction* inst = FindDef(id); + assert(inst); + + if (inst->opcode() == SpvOpTypeMatrix) { + return IsFloatScalarType(GetComponentType(id)); + } + + return false; +} + +bool ValidationState_t::GetMatrixTypeInfo( + uint32_t id, uint32_t* num_rows, uint32_t* num_cols, + uint32_t* column_type, uint32_t* component_type) const { + if (!id) + return false; + + const Instruction* mat_inst = FindDef(id); + assert(mat_inst); + if (mat_inst->opcode() != SpvOpTypeMatrix) + return false; + + const uint32_t vec_type = mat_inst->word(2); + const Instruction* vec_inst = FindDef(vec_type); + assert(vec_inst); + + if (vec_inst->opcode() != SpvOpTypeVector) { + assert(0); + return false; + } + + *num_cols = mat_inst->word(3); + *num_rows = vec_inst->word(3); + *column_type = mat_inst->word(2); + *component_type = vec_inst->word(2); + + return true; +} + } /// namespace libspirv diff --git a/source/val/validation_state.h b/source/val/validation_state.h index b3340656..49eda956 100644 --- a/source/val/validation_state.h +++ b/source/val/validation_state.h @@ -333,27 +333,36 @@ class ValidationState_t { // Returns type_id of the scalar component of |id|. // |id| can be either - // - vector type - // - matrix type - // - object of either vector or matrix type + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type uint32_t GetComponentType(uint32_t id) const; - // Returns dimension of scalar, vector or matrix type or object. Will invoke - // assertion and return 0 if |id| is none of the above. - // In case of matrix returns number of columns. + // Returns + // - 1 for scalar types or objects + // - vector size for vector types or objects + // - num columns for matrix types or objects + // Should not be called with any other arguments (will return zero and invoke + // assertion). uint32_t GetDimension(uint32_t id) const; // Returns bit width of scalar or component. // |id| can be - // - scalar type or object - // - vector or matrix type or object + // - scalar, vector or matrix type + // - object of either scalar, vector or matrix type // Will invoke assertion and return 0 if |id| is none of the above. uint32_t GetBitWidth(uint32_t id) const; + // Provides detailed information on matrix type. + // Returns false iff |id| is not matrix type. + bool GetMatrixTypeInfo( + uint32_t id, uint32_t* num_rows, uint32_t* num_cols, + uint32_t* column_type, uint32_t* component_type) const; + // Returns true iff |id| is a type corresponding to the name of the function. // Only works for types not for objects. bool IsFloatScalarType(uint32_t id) const; bool IsFloatVectorType(uint32_t id) const; + bool IsFloatMatrixType(uint32_t id) const; bool IsIntScalarType(uint32_t id) const; bool IsIntVectorType(uint32_t id) const; bool IsUnsignedIntScalarType(uint32_t id) const; diff --git a/source/validate_arithmetics.cpp b/source/validate_arithmetics.cpp index f3f83186..3ac5c222 100644 --- a/source/validate_arithmetics.cpp +++ b/source/validate_arithmetics.cpp @@ -126,6 +126,287 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, break; } + case SpvOpDot: { + if (!_.IsFloatScalarType(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float scalar type as type_id: " + << spvOpcodeString(opcode); + + uint32_t first_vector_num_components = 0; + + for (size_t operand_index = 2; operand_index < inst->num_operands; + ++operand_index) { + const uint32_t type_id = + _.GetTypeId(GetOperandWord(inst, operand_index)); + + if (!type_id || !_.IsFloatVectorType(type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector as operand: " + << spvOpcodeString(opcode) << " operand index " << operand_index; + + + const uint32_t component_type = _.GetComponentType(type_id); + if (component_type != inst->type_id) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component type to be equal to type_id: " + << spvOpcodeString(opcode) << " operand index " << operand_index; + + const uint32_t num_components = _.GetDimension(type_id); + if (operand_index == 2) { + first_vector_num_components = num_components; + } else if (num_components != first_vector_num_components) { + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected operands to have the same number of componenets: " + << spvOpcodeString(opcode); + } + } + break; + } + + case SpvOpVectorTimesScalar: { + if (!_.IsFloatVectorType(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as type_id: " + << spvOpcodeString(opcode); + + const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + if (inst->type_id != vector_type_id) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected vector operand type to be equal to type_id: " + << spvOpcodeString(opcode); + + const uint32_t component_type = _.GetComponentType(vector_type_id); + + const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + if (component_type != scalar_type_id) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected scalar operand type to be equal to the component " + << "type of the vector operand: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpMatrixTimesScalar: { + if (!_.IsFloatMatrixType(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as type_id: " + << spvOpcodeString(opcode); + + const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + if (inst->type_id != matrix_type_id) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected matrix operand type to be equal to type_id: " + << spvOpcodeString(opcode); + + const uint32_t component_type = _.GetComponentType(matrix_type_id); + + const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + if (component_type != scalar_type_id) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected scalar operand type to be equal to the component " + << "type of the matrix operand: " + << spvOpcodeString(opcode); + + break; + } + + case SpvOpVectorTimesMatrix: { + const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + + if (!_.IsFloatVectorType(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as type_id: " + << spvOpcodeString(opcode); + + const uint32_t res_component_type = _.GetComponentType(inst->type_id); + + if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as left operand: " + << spvOpcodeString(opcode); + + if (res_component_type != _.GetComponentType(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component types of type_id and vector to be equal: " + << spvOpcodeString(opcode); + + uint32_t matrix_num_rows = 0; + uint32_t matrix_num_cols = 0; + uint32_t matrix_col_type = 0; + uint32_t matrix_component_type = 0; + if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, + &matrix_num_cols, &matrix_col_type, + &matrix_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as right operand: " + << spvOpcodeString(opcode); + + if (res_component_type != matrix_component_type) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component types of type_id and matrix to be equal: " + << spvOpcodeString(opcode); + + if (matrix_num_cols != _.GetDimension(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of columns of the matrix to be equal to the " + << "type_id vector size: " << spvOpcodeString(opcode); + + if (matrix_num_rows != _.GetDimension(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of rows of the matrix to be equal to the " + << "vector operand size: " << spvOpcodeString(opcode); + + break; + } + + case SpvOpMatrixTimesVector: { + const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + + if (!_.IsFloatVectorType(inst->type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as type_id: " + << spvOpcodeString(opcode); + + uint32_t matrix_num_rows = 0; + uint32_t matrix_num_cols = 0; + uint32_t matrix_col_type = 0; + uint32_t matrix_component_type = 0; + if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows, + &matrix_num_cols, &matrix_col_type, + &matrix_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as left operand: " + << spvOpcodeString(opcode); + + if (inst->type_id != matrix_col_type) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected column type of the matrix to be equal to type_id: " + << spvOpcodeString(opcode); + + if (!vector_type_id || !_.IsFloatVectorType(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as right operand: " + << spvOpcodeString(opcode); + + if (matrix_component_type != _.GetComponentType(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component types of the operands to be equal: " + << spvOpcodeString(opcode); + + if (matrix_num_cols != _.GetDimension(vector_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of columns of the matrix to be equal to the " + << "vector size: " << spvOpcodeString(opcode); + + break; + } + + case SpvOpMatrixTimesMatrix: { + const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + + uint32_t res_num_rows = 0; + uint32_t res_num_cols = 0; + uint32_t res_col_type = 0; + uint32_t res_component_type = 0; + if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols, + &res_col_type, &res_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as type_id: " + << spvOpcodeString(opcode); + + uint32_t left_num_rows = 0; + uint32_t left_num_cols = 0; + uint32_t left_col_type = 0; + uint32_t left_component_type = 0; + if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols, + &left_col_type, &left_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as left operand: " + << spvOpcodeString(opcode); + + uint32_t right_num_rows = 0; + uint32_t right_num_cols = 0; + uint32_t right_col_type = 0; + uint32_t right_component_type = 0; + if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols, + &right_col_type, &right_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as right operand: " + << spvOpcodeString(opcode); + + if (!_.IsFloatScalarType(res_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as type_id: " + << spvOpcodeString(opcode); + + if (res_col_type != left_col_type) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected column types of type_id and left matrix to be " + << "equal: " << spvOpcodeString(opcode); + + if (res_component_type != right_component_type) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component types of type_id and right matrix to be " + << "equal: " << spvOpcodeString(opcode); + + if (res_num_cols != right_num_cols) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of columns of type_id and right matrix to be " + << "equal: " << spvOpcodeString(opcode); + + if (left_num_cols != right_num_rows) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of columns of left matrix and number of rows " + << "of right matrix to be equal: " << spvOpcodeString(opcode); + + assert(left_num_rows == res_num_rows); + break; + } + + case SpvOpOuterProduct: { + const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2)); + const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3)); + + uint32_t res_num_rows = 0; + uint32_t res_num_cols = 0; + uint32_t res_col_type = 0; + uint32_t res_component_type = 0; + if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols, + &res_col_type, &res_component_type)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float matrix type as type_id: " + << spvOpcodeString(opcode); + + if (left_type_id != res_col_type) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected column type of the type_id to be equal to the type " + << "of the left operand: " + << spvOpcodeString(opcode); + + if (!right_type_id || !_.IsFloatVectorType(right_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected float vector type as right operand: " + << spvOpcodeString(opcode); + + if (res_component_type != _.GetComponentType(right_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected component types of the operands to be equal: " + << spvOpcodeString(opcode); + + if (res_num_cols != _.GetDimension(right_type_id)) + return _.diag(SPV_ERROR_INVALID_DATA) + << "Expected number of columns of the matrix to be equal to the " + << "vector size of the right operand: " << spvOpcodeString(opcode); + + break; + } + + // TODO(atgoo@github.com): Support other operations. + default: break; } diff --git a/test/comp/markv_codec_test.cpp b/test/comp/markv_codec_test.cpp index 246e6cb9..711069ca 100644 --- a/test/comp/markv_codec_test.cpp +++ b/test/comp/markv_codec_test.cpp @@ -682,8 +682,6 @@ TEST(Markv, VectorTimesScalar) { %f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0 %res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2 %res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2 -%res3 = OpVectorTimesScalar %u32vec3 %u32vec3_012 %u32_2 -%res4 = OpVectorTimesScalar %s32vec2 %s32vec2_01 %s32_2 )"); } diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp index d5fe8310..6bb568b7 100644 --- a/test/val/val_arithmetics_test.cpp +++ b/test/val/val_arithmetics_test.cpp @@ -35,6 +35,7 @@ R"( OpCapability Shader OpCapability Int64 OpCapability Float64 +OpCapability Matrix %ext_inst = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 OpEntryPoint Fragment %main "main" @@ -66,6 +67,12 @@ OpEntryPoint Fragment %main "main" %f32vec4 = OpTypeVector %f32 4 %f64vec4 = OpTypeVector %f64 4 +%f32mat22 = OpTypeMatrix %f32vec2 2 +%f32mat23 = OpTypeMatrix %f32vec2 3 +%f32mat32 = OpTypeMatrix %f32vec3 2 +%f32mat33 = OpTypeMatrix %f32vec3 3 +%f64mat22 = OpTypeMatrix %f64vec2 2 + %f32_0 = OpConstant %f32 0 %f32_1 = OpConstant %f32 1 %f32_2 = OpConstant %f32 2 @@ -133,6 +140,13 @@ OpEntryPoint Fragment %main "main" %f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3 %f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4 +%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12 +%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12 +%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123 +%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123 + +%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12 + %main = OpFunction %void None %func %main_entry = OpLabel)"; @@ -536,4 +550,554 @@ TEST_F(ValidateArithmetics, UDivWrongOperand2) { "UDiv operand index 3")); } +TEST_F(ValidateArithmetics, DotSuccess) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, DotWrongTypeId) { + const std::string body = R"( +%val = OpDot %u32 %u32vec2_01 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float scalar type as type_id: Dot")); +} + +TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) { + const std::string body = R"( +%val = OpDot %f32 %f32 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector as operand: Dot operand index 2")); +} + +TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) { + const std::string body = R"( +%val = OpDot %f32 %f32vec3_012 %f32_1 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector as operand: Dot operand index 3")); +} + +TEST_F(ValidateArithmetics, DotWrongComponentOperand1) { + const std::string body = R"( +%val = OpDot %f64 %f32vec2_01 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component type to be equal to type_id: Dot operand index 2")); +} + +TEST_F(ValidateArithmetics, DotWrongComponentOperand2) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component type to be equal to type_id: Dot operand index 3")); +} + +TEST_F(ValidateArithmetics, DotDifferentVectorSize) { + const std::string body = R"( +%val = OpDot %f32 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected operands to have the same number of componenets: Dot")); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) { + const std::string body = R"( +%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as type_id: " + "VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected vector operand type to be equal to type_id: " + "VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) { + const std::string body = R"( +%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected scalar operand type to be equal to the component " + "type of the vector operand: VectorTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as type_id: " + "MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected matrix operand type to be equal to type_id: " + "MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) { + const std::string body = R"( +%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected scalar operand type to be equal to the component " + "type of the matrix operand: MatrixTimesScalar")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as type_id: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as left operand: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component types of type_id and vector to be equal: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as right operand: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component types of type_id and matrix to be equal: " + "VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of columns of the matrix to be equal to the type_id " + "vector size: VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) { + const std::string body = R"( +%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of rows of the matrix to be equal to the vector " + "operand size: VectorTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as type_id: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as left operand: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected column type of the matrix to be equal to type_id: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as right operand: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component types of the operands to be equal: " + "MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) { + const std::string body = R"( +%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of columns of the matrix to be equal to the vector " + "size: MatrixTimesVector")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as type_id: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as left operand: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as right operand: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected column types of type_id and left matrix to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component types of type_id and right matrix to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of columns of type_id and right matrix to be equal: " + "MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) { + const std::string body = R"( +%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of columns of left matrix and number of rows of right " + "matrix to be equal: MatrixTimesMatrix")); +} + +TEST_F(ValidateArithmetics, OuterProduct2x2Success) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProduct3x2Success) { + const std::string body = R"( +%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProduct2x3Success) { + const std::string body = R"( +%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateArithmetics, OuterProductWrongTypeId) { + const std::string body = R"( +%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float matrix type as type_id: " + "OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected column type of the type_id to be equal to the type " + "of the left operand: OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected float vector type as right operand: OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected component types of the operands to be equal: OuterProduct")); +} + +TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) { + const std::string body = R"( +%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123 +)"; + + CompileSuccessfully(GenerateCode(body).c_str()); + ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), HasSubstr( + "Expected number of columns of the matrix to be equal to the " + "vector size of the right operand: OuterProduct")); +} + } // anonymous namespace |