summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source/comp/markv_codec.cpp4
-rw-r--r--source/val/validation_state.cpp39
-rw-r--r--source/val/validation_state.h25
-rw-r--r--source/validate_arithmetics.cpp281
-rw-r--r--test/comp/markv_codec_test.cpp2
-rw-r--r--test/val/val_arithmetics_test.cpp564
6 files changed, 904 insertions, 11 deletions
diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp
index c99fdd99..b3f0fffb 100644
--- a/source/comp/markv_codec.cpp
+++ b/source/comp/markv_codec.cpp
@@ -1290,8 +1290,10 @@ uint64_t MarkvCodecBase::GetRuleBasedMtf() {
}
case SpvOpVectorTimesScalar: {
- if (operand_index_ == 0)
+ if (operand_index_ == 0) {
+ // TODO(atgoo@github.com) Could be narrowed to vector of floats.
return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
+ }
assert(inst_.type_id);
if (operand_index_ == 2)
diff --git a/source/val/validation_state.cpp b/source/val/validation_state.cpp
index 06f9d989..c7282acf 100644
--- a/source/val/validation_state.cpp
+++ b/source/val/validation_state.cpp
@@ -591,4 +591,43 @@ bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
return false;
}
+bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
+ const Instruction* inst = FindDef(id);
+ assert(inst);
+
+ if (inst->opcode() == SpvOpTypeMatrix) {
+ return IsFloatScalarType(GetComponentType(id));
+ }
+
+ return false;
+}
+
+bool ValidationState_t::GetMatrixTypeInfo(
+ uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
+ uint32_t* column_type, uint32_t* component_type) const {
+ if (!id)
+ return false;
+
+ const Instruction* mat_inst = FindDef(id);
+ assert(mat_inst);
+ if (mat_inst->opcode() != SpvOpTypeMatrix)
+ return false;
+
+ const uint32_t vec_type = mat_inst->word(2);
+ const Instruction* vec_inst = FindDef(vec_type);
+ assert(vec_inst);
+
+ if (vec_inst->opcode() != SpvOpTypeVector) {
+ assert(0);
+ return false;
+ }
+
+ *num_cols = mat_inst->word(3);
+ *num_rows = vec_inst->word(3);
+ *column_type = mat_inst->word(2);
+ *component_type = vec_inst->word(2);
+
+ return true;
+}
+
} /// namespace libspirv
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index b3340656..49eda956 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -333,27 +333,36 @@ class ValidationState_t {
// Returns type_id of the scalar component of |id|.
// |id| can be either
- // - vector type
- // - matrix type
- // - object of either vector or matrix type
+ // - scalar, vector or matrix type
+ // - object of either scalar, vector or matrix type
uint32_t GetComponentType(uint32_t id) const;
- // Returns dimension of scalar, vector or matrix type or object. Will invoke
- // assertion and return 0 if |id| is none of the above.
- // In case of matrix returns number of columns.
+ // Returns
+ // - 1 for scalar types or objects
+ // - vector size for vector types or objects
+ // - num columns for matrix types or objects
+ // Should not be called with any other arguments (will return zero and invoke
+ // assertion).
uint32_t GetDimension(uint32_t id) const;
// Returns bit width of scalar or component.
// |id| can be
- // - scalar type or object
- // - vector or matrix type or object
+ // - scalar, vector or matrix type
+ // - object of either scalar, vector or matrix type
// Will invoke assertion and return 0 if |id| is none of the above.
uint32_t GetBitWidth(uint32_t id) const;
+ // Provides detailed information on matrix type.
+ // Returns false iff |id| is not matrix type.
+ bool GetMatrixTypeInfo(
+ uint32_t id, uint32_t* num_rows, uint32_t* num_cols,
+ uint32_t* column_type, uint32_t* component_type) const;
+
// Returns true iff |id| is a type corresponding to the name of the function.
// Only works for types not for objects.
bool IsFloatScalarType(uint32_t id) const;
bool IsFloatVectorType(uint32_t id) const;
+ bool IsFloatMatrixType(uint32_t id) const;
bool IsIntScalarType(uint32_t id) const;
bool IsIntVectorType(uint32_t id) const;
bool IsUnsignedIntScalarType(uint32_t id) const;
diff --git a/source/validate_arithmetics.cpp b/source/validate_arithmetics.cpp
index f3f83186..3ac5c222 100644
--- a/source/validate_arithmetics.cpp
+++ b/source/validate_arithmetics.cpp
@@ -126,6 +126,287 @@ spv_result_t ArithmeticsPass(ValidationState_t& _,
break;
}
+ case SpvOpDot: {
+ if (!_.IsFloatScalarType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float scalar type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t first_vector_num_components = 0;
+
+ for (size_t operand_index = 2; operand_index < inst->num_operands;
+ ++operand_index) {
+ const uint32_t type_id =
+ _.GetTypeId(GetOperandWord(inst, operand_index));
+
+ if (!type_id || !_.IsFloatVectorType(type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector as operand: "
+ << spvOpcodeString(opcode) << " operand index " << operand_index;
+
+
+ const uint32_t component_type = _.GetComponentType(type_id);
+ if (component_type != inst->type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component type to be equal to type_id: "
+ << spvOpcodeString(opcode) << " operand index " << operand_index;
+
+ const uint32_t num_components = _.GetDimension(type_id);
+ if (operand_index == 2) {
+ first_vector_num_components = num_components;
+ } else if (num_components != first_vector_num_components) {
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected operands to have the same number of componenets: "
+ << spvOpcodeString(opcode);
+ }
+ }
+ break;
+ }
+
+ case SpvOpVectorTimesScalar: {
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ if (inst->type_id != vector_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected vector operand type to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t component_type = _.GetComponentType(vector_type_id);
+
+ const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+ if (component_type != scalar_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected scalar operand type to be equal to the component "
+ << "type of the vector operand: "
+ << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesScalar: {
+ if (!_.IsFloatMatrixType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ if (inst->type_id != matrix_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected matrix operand type to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t component_type = _.GetComponentType(matrix_type_id);
+
+ const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+ if (component_type != scalar_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected scalar operand type to be equal to the component "
+ << "type of the matrix operand: "
+ << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpVectorTimesMatrix: {
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t res_component_type = _.GetComponentType(inst->type_id);
+
+ if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as left operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != _.GetComponentType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and vector to be equal: "
+ << spvOpcodeString(opcode);
+
+ uint32_t matrix_num_rows = 0;
+ uint32_t matrix_num_cols = 0;
+ uint32_t matrix_col_type = 0;
+ uint32_t matrix_component_type = 0;
+ if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+ &matrix_num_cols, &matrix_col_type,
+ &matrix_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != matrix_component_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and matrix to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_num_cols != _.GetDimension(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "type_id vector size: " << spvOpcodeString(opcode);
+
+ if (matrix_num_rows != _.GetDimension(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of rows of the matrix to be equal to the "
+ << "vector operand size: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesVector: {
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t matrix_num_rows = 0;
+ uint32_t matrix_num_cols = 0;
+ uint32_t matrix_col_type = 0;
+ uint32_t matrix_component_type = 0;
+ if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+ &matrix_num_cols, &matrix_col_type,
+ &matrix_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as left operand: "
+ << spvOpcodeString(opcode);
+
+ if (inst->type_id != matrix_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column type of the matrix to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_component_type != _.GetComponentType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of the operands to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_num_cols != _.GetDimension(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "vector size: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesMatrix: {
+ const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ uint32_t res_num_rows = 0;
+ uint32_t res_num_cols = 0;
+ uint32_t res_col_type = 0;
+ uint32_t res_component_type = 0;
+ if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
+ &res_col_type, &res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t left_num_rows = 0;
+ uint32_t left_num_cols = 0;
+ uint32_t left_col_type = 0;
+ uint32_t left_component_type = 0;
+ if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
+ &left_col_type, &left_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as left operand: "
+ << spvOpcodeString(opcode);
+
+ uint32_t right_num_rows = 0;
+ uint32_t right_num_cols = 0;
+ uint32_t right_col_type = 0;
+ uint32_t right_component_type = 0;
+ if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
+ &right_col_type, &right_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (!_.IsFloatScalarType(res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ if (res_col_type != left_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column types of type_id and left matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (res_component_type != right_component_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and right matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (res_num_cols != right_num_cols)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of type_id and right matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (left_num_cols != right_num_rows)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of left matrix and number of rows "
+ << "of right matrix to be equal: " << spvOpcodeString(opcode);
+
+ assert(left_num_rows == res_num_rows);
+ break;
+ }
+
+ case SpvOpOuterProduct: {
+ const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ uint32_t res_num_rows = 0;
+ uint32_t res_num_cols = 0;
+ uint32_t res_col_type = 0;
+ uint32_t res_component_type = 0;
+ if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
+ &res_col_type, &res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ if (left_type_id != res_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column type of the type_id to be equal to the type "
+ << "of the left operand: "
+ << spvOpcodeString(opcode);
+
+ if (!right_type_id || !_.IsFloatVectorType(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != _.GetComponentType(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of the operands to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (res_num_cols != _.GetDimension(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "vector size of the right operand: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ // TODO(atgoo@github.com): Support other operations.
+
default:
break;
}
diff --git a/test/comp/markv_codec_test.cpp b/test/comp/markv_codec_test.cpp
index 246e6cb9..711069ca 100644
--- a/test/comp/markv_codec_test.cpp
+++ b/test/comp/markv_codec_test.cpp
@@ -682,8 +682,6 @@ TEST(Markv, VectorTimesScalar) {
%f32vec4_3210 = OpCompositeConstruct %f32vec4 %f32_3 %f32_2 %f32_1 %f32_0
%res1 = OpVectorTimesScalar %f32vec4 %f32vec4_0123 %f32_2
%res2 = OpVectorTimesScalar %f32vec4 %f32vec4_3210 %f32_2
-%res3 = OpVectorTimesScalar %u32vec3 %u32vec3_012 %u32_2
-%res4 = OpVectorTimesScalar %s32vec2 %s32vec2_01 %s32_2
)");
}
diff --git a/test/val/val_arithmetics_test.cpp b/test/val/val_arithmetics_test.cpp
index d5fe8310..6bb568b7 100644
--- a/test/val/val_arithmetics_test.cpp
+++ b/test/val/val_arithmetics_test.cpp
@@ -35,6 +35,7 @@ R"(
OpCapability Shader
OpCapability Int64
OpCapability Float64
+OpCapability Matrix
%ext_inst = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Fragment %main "main"
@@ -66,6 +67,12 @@ OpEntryPoint Fragment %main "main"
%f32vec4 = OpTypeVector %f32 4
%f64vec4 = OpTypeVector %f64 4
+%f32mat22 = OpTypeMatrix %f32vec2 2
+%f32mat23 = OpTypeMatrix %f32vec2 3
+%f32mat32 = OpTypeMatrix %f32vec3 2
+%f32mat33 = OpTypeMatrix %f32vec3 3
+%f64mat22 = OpTypeMatrix %f64vec2 2
+
%f32_0 = OpConstant %f32 0
%f32_1 = OpConstant %f32 1
%f32_2 = OpConstant %f32 2
@@ -133,6 +140,13 @@ OpEntryPoint Fragment %main "main"
%f64vec4_0123 = OpConstantComposite %f64vec4 %f64_0 %f64_1 %f64_2 %f64_3
%f64vec4_1234 = OpConstantComposite %f64vec4 %f64_1 %f64_2 %f64_3 %f64_4
+%f32mat22_1212 = OpConstantComposite %f32mat22 %f32vec2_12 %f32vec2_12
+%f32mat23_121212 = OpConstantComposite %f32mat23 %f32vec2_12 %f32vec2_12 %f32vec2_12
+%f32mat32_123123 = OpConstantComposite %f32mat32 %f32vec3_123 %f32vec3_123
+%f32mat33_123123123 = OpConstantComposite %f32mat33 %f32vec3_123 %f32vec3_123 %f32vec3_123
+
+%f64mat22_1212 = OpConstantComposite %f64mat22 %f64vec2_12 %f64vec2_12
+
%main = OpFunction %void None %func
%main_entry = OpLabel)";
@@ -536,4 +550,554 @@ TEST_F(ValidateArithmetics, UDivWrongOperand2) {
"UDiv operand index 3"));
}
+TEST_F(ValidateArithmetics, DotSuccess) {
+ const std::string body = R"(
+%val = OpDot %f32 %f32vec2_01 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, DotWrongTypeId) {
+ const std::string body = R"(
+%val = OpDot %u32 %u32vec2_01 %u32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float scalar type as type_id: Dot"));
+}
+
+TEST_F(ValidateArithmetics, DotNotVectorTypeOperand1) {
+ const std::string body = R"(
+%val = OpDot %f32 %f32 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector as operand: Dot operand index 2"));
+}
+
+TEST_F(ValidateArithmetics, DotNotVectorTypeOperand2) {
+ const std::string body = R"(
+%val = OpDot %f32 %f32vec3_012 %f32_1
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector as operand: Dot operand index 3"));
+}
+
+TEST_F(ValidateArithmetics, DotWrongComponentOperand1) {
+ const std::string body = R"(
+%val = OpDot %f64 %f32vec2_01 %f64vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component type to be equal to type_id: Dot operand index 2"));
+}
+
+TEST_F(ValidateArithmetics, DotWrongComponentOperand2) {
+ const std::string body = R"(
+%val = OpDot %f32 %f32vec2_01 %f64vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component type to be equal to type_id: Dot operand index 3"));
+}
+
+TEST_F(ValidateArithmetics, DotDifferentVectorSize) {
+ const std::string body = R"(
+%val = OpDot %f32 %f32vec2_01 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected operands to have the same number of componenets: Dot"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesScalarSuccess) {
+ const std::string body = R"(
+%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, VectorTimesScalarWrongTypeId) {
+ const std::string body = R"(
+%val = OpVectorTimesScalar %u32vec2 %f32vec2_01 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as type_id: "
+ "VectorTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesScalarWrongVector) {
+ const std::string body = R"(
+%val = OpVectorTimesScalar %f32vec2 %f32vec3_012 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected vector operand type to be equal to type_id: "
+ "VectorTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesScalarWrongScalar) {
+ const std::string body = R"(
+%val = OpVectorTimesScalar %f32vec2 %f32vec2_01 %f64_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected scalar operand type to be equal to the component "
+ "type of the vector operand: VectorTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesScalarSuccess) {
+ const std::string body = R"(
+%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesScalarWrongTypeId) {
+ const std::string body = R"(
+%val = OpMatrixTimesScalar %f32vec2 %f32mat22_1212 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as type_id: "
+ "MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesScalarWrongMatrix) {
+ const std::string body = R"(
+%val = OpMatrixTimesScalar %f32mat22 %f32vec2_01 %f32_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected matrix operand type to be equal to type_id: "
+ "MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesScalarWrongScalar) {
+ const std::string body = R"(
+%val = OpMatrixTimesScalar %f32mat22 %f32mat22_1212 %f64_2
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected scalar operand type to be equal to the component "
+ "type of the matrix operand: MatrixTimesScalar"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrix2x22Success) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrix3x32Success) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec3_123 %f32mat32_123123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrixWrongTypeId) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as type_id: "
+ "VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrixNotFloatVector) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %u32vec2_12 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as left operand: "
+ "VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrixWrongVectorComponent) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f64vec2_12 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component types of type_id and vector to be equal: "
+ "VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrix) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as right operand: "
+ "VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrixWrongMatrixComponent) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f64mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component types of type_id and matrix to be equal: "
+ "VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrix2eq2x23Fail) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat23_121212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of columns of the matrix to be equal to the type_id "
+ "vector size: VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, VectorTimesMatrix2x32Fail) {
+ const std::string body = R"(
+%val = OpVectorTimesMatrix %f32vec2 %f32vec2_12 %f32mat32_123123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of rows of the matrix to be equal to the vector "
+ "operand size: VectorTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVector22x2Success) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVector23x3Success) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec2 %f32mat23_121212 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVectorWrongTypeId) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32mat22 %f32mat22_1212 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as type_id: "
+ "MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrix) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec3 %f32vec3_123 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as left operand: "
+ "MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVectorWrongMatrixCol) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec3 %f32mat23_121212 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected column type of the matrix to be equal to type_id: "
+ "MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVectorWrongVector) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %u32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as right operand: "
+ "MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVectorDifferentComponents) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f64vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component types of the operands to be equal: "
+ "MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesVector22x3Fail) {
+ const std::string body = R"(
+%val = OpMatrixTimesVector %f32vec2 %f32mat22_1212 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of columns of the matrix to be equal to the vector "
+ "size: MatrixTimesVector"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix22x22Success) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix23x32Success) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat32_123123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix33x33Success) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat33 %f32mat33_123123123 %f32mat33_123123123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongTypeId) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32vec2 %f32mat22_1212 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as type_id: MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongLeftOperand) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32vec2_12 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as left operand: MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrixWrongRightOperand) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f32vec2_12
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as right operand: MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix32x23Fail) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat32_123123 %f32mat23_121212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected column types of type_id and left matrix to be equal: "
+ "MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrixDifferentComponents) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat22_1212 %f64mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component types of type_id and right matrix to be equal: "
+ "MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix23x23Fail) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat23_121212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of columns of type_id and right matrix to be equal: "
+ "MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, MatrixTimesMatrix23x22Fail) {
+ const std::string body = R"(
+%val = OpMatrixTimesMatrix %f32mat22 %f32mat23_121212 %f32mat22_1212
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of columns of left matrix and number of rows of right "
+ "matrix to be equal: MatrixTimesMatrix"));
+}
+
+TEST_F(ValidateArithmetics, OuterProduct2x2Success) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec2_01
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, OuterProduct3x2Success) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat32 %f32vec3_123 %f32vec2_01
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, OuterProduct2x3Success) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat23 %f32vec2_01 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateArithmetics, OuterProductWrongTypeId) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32vec2 %f32vec2_01 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float matrix type as type_id: "
+ "OuterProduct"));
+}
+
+TEST_F(ValidateArithmetics, OuterProductWrongLeftOperand) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat22 %f32vec3_123 %f32vec2_01
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected column type of the type_id to be equal to the type "
+ "of the left operand: OuterProduct"));
+}
+
+TEST_F(ValidateArithmetics, OuterProductRightOperandNotFloatVector) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat22 %f32vec2_12 %u32vec2_01
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected float vector type as right operand: OuterProduct"));
+}
+
+TEST_F(ValidateArithmetics, OuterProductRightOperandWrongComponent) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat22 %f32vec2_12 %f64vec2_01
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected component types of the operands to be equal: OuterProduct"));
+}
+
+TEST_F(ValidateArithmetics, OuterProductRightOperandWrongDimension) {
+ const std::string body = R"(
+%val = OpOuterProduct %f32mat22 %f32vec2_12 %f32vec3_123
+)";
+
+ CompileSuccessfully(GenerateCode(body).c_str());
+ ASSERT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(), HasSubstr(
+ "Expected number of columns of the matrix to be equal to the "
+ "vector size of the right operand: OuterProduct"));
+}
+
} // anonymous namespace