summaryrefslogtreecommitdiff
path: root/source/validate_arithmetics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/validate_arithmetics.cpp')
-rw-r--r--source/validate_arithmetics.cpp281
1 files changed, 281 insertions, 0 deletions
diff --git a/source/validate_arithmetics.cpp b/source/validate_arithmetics.cpp
index f3f83186..3ac5c222 100644
--- a/source/validate_arithmetics.cpp
+++ b/source/validate_arithmetics.cpp
@@ -126,6 +126,287 @@ spv_result_t ArithmeticsPass(ValidationState_t& _,
break;
}
+ case SpvOpDot: {
+ if (!_.IsFloatScalarType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float scalar type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t first_vector_num_components = 0;
+
+ for (size_t operand_index = 2; operand_index < inst->num_operands;
+ ++operand_index) {
+ const uint32_t type_id =
+ _.GetTypeId(GetOperandWord(inst, operand_index));
+
+ if (!type_id || !_.IsFloatVectorType(type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector as operand: "
+ << spvOpcodeString(opcode) << " operand index " << operand_index;
+
+
+ const uint32_t component_type = _.GetComponentType(type_id);
+ if (component_type != inst->type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component type to be equal to type_id: "
+ << spvOpcodeString(opcode) << " operand index " << operand_index;
+
+ const uint32_t num_components = _.GetDimension(type_id);
+ if (operand_index == 2) {
+ first_vector_num_components = num_components;
+ } else if (num_components != first_vector_num_components) {
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected operands to have the same number of componenets: "
+ << spvOpcodeString(opcode);
+ }
+ }
+ break;
+ }
+
+ case SpvOpVectorTimesScalar: {
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ if (inst->type_id != vector_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected vector operand type to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t component_type = _.GetComponentType(vector_type_id);
+
+ const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+ if (component_type != scalar_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected scalar operand type to be equal to the component "
+ << "type of the vector operand: "
+ << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesScalar: {
+ if (!_.IsFloatMatrixType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ if (inst->type_id != matrix_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected matrix operand type to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t component_type = _.GetComponentType(matrix_type_id);
+
+ const uint32_t scalar_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+ if (component_type != scalar_type_id)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected scalar operand type to be equal to the component "
+ << "type of the matrix operand: "
+ << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpVectorTimesMatrix: {
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ const uint32_t res_component_type = _.GetComponentType(inst->type_id);
+
+ if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as left operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != _.GetComponentType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and vector to be equal: "
+ << spvOpcodeString(opcode);
+
+ uint32_t matrix_num_rows = 0;
+ uint32_t matrix_num_cols = 0;
+ uint32_t matrix_col_type = 0;
+ uint32_t matrix_component_type = 0;
+ if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+ &matrix_num_cols, &matrix_col_type,
+ &matrix_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != matrix_component_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and matrix to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_num_cols != _.GetDimension(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "type_id vector size: " << spvOpcodeString(opcode);
+
+ if (matrix_num_rows != _.GetDimension(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of rows of the matrix to be equal to the "
+ << "vector operand size: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesVector: {
+ const uint32_t matrix_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t vector_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ if (!_.IsFloatVectorType(inst->type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t matrix_num_rows = 0;
+ uint32_t matrix_num_cols = 0;
+ uint32_t matrix_col_type = 0;
+ uint32_t matrix_component_type = 0;
+ if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
+ &matrix_num_cols, &matrix_col_type,
+ &matrix_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as left operand: "
+ << spvOpcodeString(opcode);
+
+ if (inst->type_id != matrix_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column type of the matrix to be equal to type_id: "
+ << spvOpcodeString(opcode);
+
+ if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_component_type != _.GetComponentType(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of the operands to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (matrix_num_cols != _.GetDimension(vector_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "vector size: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ case SpvOpMatrixTimesMatrix: {
+ const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ uint32_t res_num_rows = 0;
+ uint32_t res_num_cols = 0;
+ uint32_t res_col_type = 0;
+ uint32_t res_component_type = 0;
+ if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
+ &res_col_type, &res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ uint32_t left_num_rows = 0;
+ uint32_t left_num_cols = 0;
+ uint32_t left_col_type = 0;
+ uint32_t left_component_type = 0;
+ if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
+ &left_col_type, &left_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as left operand: "
+ << spvOpcodeString(opcode);
+
+ uint32_t right_num_rows = 0;
+ uint32_t right_num_cols = 0;
+ uint32_t right_col_type = 0;
+ uint32_t right_component_type = 0;
+ if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
+ &right_col_type, &right_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (!_.IsFloatScalarType(res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ if (res_col_type != left_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column types of type_id and left matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (res_component_type != right_component_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of type_id and right matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (res_num_cols != right_num_cols)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of type_id and right matrix to be "
+ << "equal: " << spvOpcodeString(opcode);
+
+ if (left_num_cols != right_num_rows)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of left matrix and number of rows "
+ << "of right matrix to be equal: " << spvOpcodeString(opcode);
+
+ assert(left_num_rows == res_num_rows);
+ break;
+ }
+
+ case SpvOpOuterProduct: {
+ const uint32_t left_type_id = _.GetTypeId(GetOperandWord(inst, 2));
+ const uint32_t right_type_id = _.GetTypeId(GetOperandWord(inst, 3));
+
+ uint32_t res_num_rows = 0;
+ uint32_t res_num_cols = 0;
+ uint32_t res_col_type = 0;
+ uint32_t res_component_type = 0;
+ if (!_.GetMatrixTypeInfo(inst->type_id, &res_num_rows, &res_num_cols,
+ &res_col_type, &res_component_type))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float matrix type as type_id: "
+ << spvOpcodeString(opcode);
+
+ if (left_type_id != res_col_type)
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected column type of the type_id to be equal to the type "
+ << "of the left operand: "
+ << spvOpcodeString(opcode);
+
+ if (!right_type_id || !_.IsFloatVectorType(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected float vector type as right operand: "
+ << spvOpcodeString(opcode);
+
+ if (res_component_type != _.GetComponentType(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected component types of the operands to be equal: "
+ << spvOpcodeString(opcode);
+
+ if (res_num_cols != _.GetDimension(right_type_id))
+ return _.diag(SPV_ERROR_INVALID_DATA)
+ << "Expected number of columns of the matrix to be equal to the "
+ << "vector size of the right operand: " << spvOpcodeString(opcode);
+
+ break;
+ }
+
+ // TODO(atgoo@github.com): Support other operations.
+
default:
break;
}