summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordan sinclair <dj2@everburning.com>2018-08-02 12:01:26 -0400
committerGitHub <noreply@github.com>2018-08-02 12:01:26 -0400
commit1946fb4ddb73ec39297d73b560d973a56af94c3d (patch)
tree3b51376b69fe57084d8269fc930b4b0011eafddf
parentce644d4a2484fe66e53f5b744ebc4d0d5d49e1ca (diff)
Remove ValidateInstructionAndUpdateValidationState (#1784)
This CL changes the stats aggregator to use ValidateBinaryAndKeepValidationState to process the binary. This means we can remove ValidateInstructionAndUpdateValidationState which expects to be able to call ProcessInstruction in the validate anonymous namespace. This decouples the stats aggregator from how validation processes the binary.
-rw-r--r--source/spirv_stats.cpp222
-rw-r--r--source/val/validate.cpp17
-rw-r--r--source/val/validate.h5
-rw-r--r--source/val/validation_state.h18
4 files changed, 103 insertions, 159 deletions
diff --git a/source/spirv_stats.cpp b/source/spirv_stats.cpp
index a9200f37..9720d550 100644
--- a/source/spirv_stats.cpp
+++ b/source/spirv_stats.cpp
@@ -19,9 +19,7 @@
#include <algorithm>
#include <memory>
#include <string>
-#include <vector>
-#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
@@ -30,8 +28,6 @@
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
-#include "spirv_endian.h"
-#include "spirv_validator_options.h"
#include "val/instruction.h"
#include "val/validate.h"
#include "val/validation_state.h"
@@ -44,68 +40,54 @@ namespace {
// instruction.
class StatsAggregator {
public:
- StatsAggregator(SpirvStats* in_out_stats, const spv_const_context context,
- const uint32_t* words, size_t num_words) {
- stats_ = in_out_stats;
- vstate_.reset(new val::ValidationState_t(context, &validator_options_,
- words, num_words));
- }
-
- // Collects header statistics and sets correct id_bound.
- spv_result_t ProcessHeader(spv_endianness_t /* endian */,
- uint32_t /* magic */, uint32_t version,
- uint32_t generator, uint32_t id_bound,
- uint32_t /* schema */) {
- vstate_->setIdBound(id_bound);
- ++stats_->version_hist[version];
- ++stats_->generator_hist[generator];
- return SPV_SUCCESS;
- }
-
- // Runs validator to validate the instruction and update vstate_,
- // then procession the instruction to collect stats.
- spv_result_t ProcessInstruction(const spv_parsed_instruction_t* inst) {
- const spv_result_t validation_result =
- ValidateInstructionAndUpdateValidationState(vstate_.get(), inst);
- if (validation_result != SPV_SUCCESS) return validation_result;
-
- ProcessOpcode();
- ProcessCapability();
- ProcessExtension();
- ProcessConstant();
- ProcessEnums();
- ProcessLiteralStrings();
- ProcessNonIdWords();
- ProcessIdDescriptors();
-
- return SPV_SUCCESS;
+ StatsAggregator(SpirvStats* in_out_stats, const val::ValidationState_t* state)
+ : stats_(in_out_stats), vstate_(state) {}
+
+ // Processes the instructions to collect stats.
+ void aggregate() {
+ const auto& instructions = vstate_->ordered_instructions();
+
+ ++stats_->version_hist[vstate_->version()];
+ ++stats_->generator_hist[vstate_->generator()];
+
+ for (size_t i = 0; i < instructions.size(); ++i) {
+ const auto& inst = instructions[i];
+
+ ProcessOpcode(&inst, i);
+ ProcessCapability(&inst);
+ ProcessExtension(&inst);
+ ProcessConstant(&inst);
+ ProcessEnums(&inst);
+ ProcessLiteralStrings(&inst);
+ ProcessNonIdWords(&inst);
+ ProcessIdDescriptors(&inst);
+ }
}
// Collects statistics of descriptors generated by IdDescriptorCollection.
- void ProcessIdDescriptors() {
- const val::Instruction& inst = GetCurrentInstruction();
+ void ProcessIdDescriptors(const val::Instruction* inst) {
const uint32_t new_descriptor =
- id_descriptors_.ProcessInstruction(inst.c_inst());
+ id_descriptors_.ProcessInstruction(inst->c_inst());
if (new_descriptor) {
std::stringstream ss;
- ss << spvOpcodeString(inst.opcode());
- for (size_t i = 1; i < inst.words().size(); ++i) {
- ss << " " << inst.word(i);
+ ss << spvOpcodeString(inst->opcode());
+ for (size_t i = 1; i < inst->words().size(); ++i) {
+ ss << " " << inst->word(i);
}
stats_->id_descriptor_labels.emplace(new_descriptor, ss.str());
}
uint32_t index = 0;
- for (const auto& operand : inst.operands()) {
+ for (const auto& operand : inst->operands()) {
if (spvIsIdType(operand.type)) {
const uint32_t descriptor =
- id_descriptors_.GetDescriptor(inst.word(operand.offset));
+ id_descriptors_.GetDescriptor(inst->word(operand.offset));
if (descriptor) {
++stats_->id_descriptor_hist[descriptor];
++stats_
->operand_slot_id_descriptor_hist[std::pair<uint32_t, uint32_t>(
- inst.opcode(), index)][descriptor];
+ inst->opcode(), index)][descriptor];
}
}
++index;
@@ -113,9 +95,8 @@ class StatsAggregator {
}
// Collects statistics of enum words for operands of specific types.
- void ProcessEnums() {
- const val::Instruction& inst = GetCurrentInstruction();
- for (const auto& operand : inst.operands()) {
+ void ProcessEnums(const val::Instruction* inst) {
+ for (const auto& operand : inst->operands()) {
switch (operand.type) {
case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
case SPV_OPERAND_TYPE_EXECUTION_MODEL:
@@ -139,7 +120,7 @@ class StatsAggregator {
case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
case SPV_OPERAND_TYPE_CAPABILITY: {
- ++stats_->enum_hist[operand.type][inst.word(operand.offset)];
+ ++stats_->enum_hist[operand.type][inst->word(operand.offset)];
break;
}
default:
@@ -149,79 +130,74 @@ class StatsAggregator {
}
// Collects statistics of literal strings used by opcodes.
- void ProcessLiteralStrings() {
- const val::Instruction& inst = GetCurrentInstruction();
- for (const auto& operand : inst.operands()) {
+ void ProcessLiteralStrings(const val::Instruction* inst) {
+ for (const auto& operand : inst->operands()) {
if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
const std::string str =
- reinterpret_cast<const char*>(&inst.words()[operand.offset]);
- ++stats_->literal_strings_hist[inst.opcode()][str];
+ reinterpret_cast<const char*>(&inst->words()[operand.offset]);
+ ++stats_->literal_strings_hist[inst->opcode()][str];
}
}
}
// Collects statistics of all single word non-id operand slots.
- void ProcessNonIdWords() {
- const val::Instruction& inst = GetCurrentInstruction();
+ void ProcessNonIdWords(const val::Instruction* inst) {
uint32_t index = 0;
- for (const auto& operand : inst.operands()) {
+ for (const auto& operand : inst->operands()) {
if (operand.num_words == 1 && !spvIsIdType(operand.type)) {
++stats_->operand_slot_non_id_words_hist[std::pair<uint32_t, uint32_t>(
- inst.opcode(), index)][inst.word(operand.offset)];
+ inst->opcode(), index)][inst->word(operand.offset)];
}
++index;
}
}
// Collects OpCapability statistics.
- void ProcessCapability() {
- const val::Instruction& inst = GetCurrentInstruction();
- if (inst.opcode() != SpvOpCapability) return;
- const uint32_t capability = inst.word(inst.operands()[0].offset);
+ void ProcessCapability(const val::Instruction* inst) {
+ if (inst->opcode() != SpvOpCapability) return;
+ const uint32_t capability = inst->word(inst->operands()[0].offset);
++stats_->capability_hist[capability];
}
// Collects OpExtension statistics.
- void ProcessExtension() {
- const val::Instruction& inst = GetCurrentInstruction();
- if (inst.opcode() != SpvOpExtension) return;
- const std::string extension = GetExtensionString(&inst.c_inst());
+ void ProcessExtension(const val::Instruction* inst) {
+ if (inst->opcode() != SpvOpExtension) return;
+ const std::string extension = GetExtensionString(&inst->c_inst());
++stats_->extension_hist[extension];
}
// Collects OpCode statistics.
- void ProcessOpcode() {
- auto inst_it = vstate_->ordered_instructions().rbegin();
- const SpvOp opcode = inst_it->opcode();
+ void ProcessOpcode(const val::Instruction* inst, size_t idx) {
+ const SpvOp opcode = inst->opcode();
++stats_->opcode_hist[opcode];
const uint32_t opcode_and_num_operands =
- (uint32_t(inst_it->operands().size()) << 16) | uint32_t(opcode);
+ (uint32_t(inst->operands().size()) << 16) | uint32_t(opcode);
++stats_->opcode_and_num_operands_hist[opcode_and_num_operands];
- ++inst_it;
+ if (idx == 0) return;
- if (inst_it != vstate_->ordered_instructions().rend()) {
- const SpvOp prev_opcode = inst_it->opcode();
- ++stats_->opcode_and_num_operands_markov_hist[prev_opcode]
- [opcode_and_num_operands];
- }
+ --idx;
+
+ const auto& instructions = vstate_->ordered_instructions();
+ const SpvOp prev_opcode = instructions[idx].opcode();
+ ++stats_->opcode_and_num_operands_markov_hist[prev_opcode]
+ [opcode_and_num_operands];
auto step_it = stats_->opcode_markov_hist.begin();
- for (; inst_it != vstate_->ordered_instructions().rend() &&
- step_it != stats_->opcode_markov_hist.end();
- ++inst_it, ++step_it) {
- auto& hist = (*step_it)[inst_it->opcode()];
+ for (; step_it != stats_->opcode_markov_hist.end(); --idx, ++step_it) {
+ auto& hist = (*step_it)[instructions[idx].opcode()];
++hist[opcode];
+
+ if (idx == 0) break;
}
}
// Collects OpConstant statistics.
- void ProcessConstant() {
- const val::Instruction& inst = GetCurrentInstruction();
- if (inst.opcode() != SpvOpConstant) return;
+ void ProcessConstant(const val::Instruction* inst) {
+ if (inst->opcode() != SpvOpConstant) return;
- const uint32_t type_id = inst.GetOperandAs<uint32_t>(0);
+ const uint32_t type_id = inst->GetOperandAs<uint32_t>(0);
const auto type_decl_it = vstate_->all_definitions().find(type_id);
assert(type_decl_it != vstate_->all_definitions().end());
@@ -233,90 +209,54 @@ class StatsAggregator {
assert(is_signed == 0 || is_signed == 1);
if (bit_width == 16) {
if (is_signed)
- ++stats_->s16_constant_hist[inst.GetOperandAs<int16_t>(2)];
+ ++stats_->s16_constant_hist[inst->GetOperandAs<int16_t>(2)];
else
- ++stats_->u16_constant_hist[inst.GetOperandAs<uint16_t>(2)];
+ ++stats_->u16_constant_hist[inst->GetOperandAs<uint16_t>(2)];
} else if (bit_width == 32) {
if (is_signed)
- ++stats_->s32_constant_hist[inst.GetOperandAs<int32_t>(2)];
+ ++stats_->s32_constant_hist[inst->GetOperandAs<int32_t>(2)];
else
- ++stats_->u32_constant_hist[inst.GetOperandAs<uint32_t>(2)];
+ ++stats_->u32_constant_hist[inst->GetOperandAs<uint32_t>(2)];
} else if (bit_width == 64) {
if (is_signed)
- ++stats_->s64_constant_hist[inst.GetOperandAs<int64_t>(2)];
+ ++stats_->s64_constant_hist[inst->GetOperandAs<int64_t>(2)];
else
- ++stats_->u64_constant_hist[inst.GetOperandAs<uint64_t>(2)];
+ ++stats_->u64_constant_hist[inst->GetOperandAs<uint64_t>(2)];
} else {
assert(false && "TypeInt bit width is not 16, 32 or 64");
}
} else if (type_op == SpvOpTypeFloat) {
const uint32_t bit_width = type_decl_inst.GetOperandAs<uint32_t>(1);
if (bit_width == 32) {
- ++stats_->f32_constant_hist[inst.GetOperandAs<float>(2)];
+ ++stats_->f32_constant_hist[inst->GetOperandAs<float>(2)];
} else if (bit_width == 64) {
- ++stats_->f64_constant_hist[inst.GetOperandAs<double>(2)];
+ ++stats_->f64_constant_hist[inst->GetOperandAs<double>(2)];
} else {
assert(bit_width == 16);
}
}
}
- SpirvStats* stats() { return stats_; }
-
private:
- // Returns the current instruction (the one last processed by the validator).
- const val::Instruction& GetCurrentInstruction() const {
- return vstate_->ordered_instructions().back();
- }
-
SpirvStats* stats_;
- spv_validator_options_t validator_options_;
- std::unique_ptr<val::ValidationState_t> vstate_;
+ const val::ValidationState_t* vstate_;
IdDescriptorCollection id_descriptors_;
};
-spv_result_t ProcessHeader(void* user_data, spv_endianness_t endian,
- uint32_t magic, uint32_t version, uint32_t generator,
- uint32_t id_bound, uint32_t schema) {
- StatsAggregator* stats_aggregator =
- reinterpret_cast<StatsAggregator*>(user_data);
- return stats_aggregator->ProcessHeader(endian, magic, version, generator,
- id_bound, schema);
-}
-
-spv_result_t ProcessInstruction(void* user_data,
- const spv_parsed_instruction_t* inst) {
- StatsAggregator* stats_aggregator =
- reinterpret_cast<StatsAggregator*>(user_data);
- return stats_aggregator->ProcessInstruction(inst);
-}
-
} // namespace
spv_result_t AggregateStats(const spv_context_t& context, const uint32_t* words,
const size_t num_words, spv_diagnostic* pDiagnostic,
SpirvStats* stats) {
- spv_const_binary_t binary = {words, num_words};
-
- spv_endianness_t endian;
- spv_position_t position = {};
- if (spvBinaryEndianness(&binary, &endian)) {
- return DiagnosticStream(position, context.consumer, "",
- SPV_ERROR_INVALID_BINARY)
- << "Invalid SPIR-V magic number.";
- }
-
- spv_header_t header;
- if (spvBinaryHeaderGet(&binary, endian, &header)) {
- return DiagnosticStream(position, context.consumer, "",
- SPV_ERROR_INVALID_BINARY)
- << "Invalid SPIR-V header.";
- }
-
- StatsAggregator stats_aggregator(stats, &context, words, num_words);
-
- return spvBinaryParse(&context, &stats_aggregator, words, num_words,
- ProcessHeader, ProcessInstruction, pDiagnostic);
+ std::unique_ptr<val::ValidationState_t> vstate;
+ spv_validator_options_t options;
+ spv_result_t result = ValidateBinaryAndKeepValidationState(
+ &context, &options, words, num_words, pDiagnostic, &vstate);
+ if (result != SPV_SUCCESS) return result;
+
+ StatsAggregator stats_aggregator(stats, vstate.get());
+ stats_aggregator.aggregate();
+ return SPV_SUCCESS;
}
} // namespace spvtools
diff --git a/source/val/validate.cpp b/source/val/validate.cpp
index a56019cb..b1587c2b 100644
--- a/source/val/validate.cpp
+++ b/source/val/validate.cpp
@@ -59,19 +59,15 @@ spv_result_t spvValidateIDs(const spv_instruction_t* pInsts,
// TODO(umar): Validate header
// TODO(umar): The binary parser validates the magic word, and the length of the
// header, but nothing else.
-spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic,
+spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
uint32_t version, uint32_t generator, uint32_t id_bound,
- uint32_t reserved) {
+ uint32_t) {
// Record the ID bound so that the validator can ensure no ID is out of bound.
ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
_.setIdBound(id_bound);
+ _.setGenerator(generator);
+ _.setVersion(version);
- (void)endian;
- (void)magic;
- (void)version;
- (void)generator;
- (void)id_bound;
- (void)reserved;
return SPV_SUCCESS;
}
@@ -354,11 +350,6 @@ spv_result_t ValidateBinaryAndKeepValidationState(
hijack_context, words, num_words, pDiagnostic, vstate->get());
}
-spv_result_t ValidateInstructionAndUpdateValidationState(
- ValidationState_t* vstate, const spv_parsed_instruction_t* inst) {
- return ProcessInstruction(vstate, inst);
-}
-
} // namespace val
} // namespace spvtools
diff --git a/source/val/validate.h b/source/val/validate.h
index 654e87f3..45709c9c 100644
--- a/source/val/validate.h
+++ b/source/val/validate.h
@@ -219,11 +219,6 @@ spv_result_t ValidateBinaryAndKeepValidationState(
const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
std::unique_ptr<ValidationState_t>* vstate);
-// Performs validation for a single instruction and updates given validation
-// state.
-spv_result_t ValidateInstructionAndUpdateValidationState(
- ValidationState_t* vstate, const spv_parsed_instruction_t* inst);
-
} // namespace val
} // namespace spvtools
diff --git a/source/val/validation_state.h b/source/val/validation_state.h
index 72c563bf..3f63047f 100644
--- a/source/val/validation_state.h
+++ b/source/val/validation_state.h
@@ -101,6 +101,18 @@ class ValidationState_t {
/// Returns the command line options
spv_const_validator_options options() const { return options_; }
+ /// Sets the ID of the generator for this module.
+ void setGenerator(uint32_t gen) { generator_ = gen; }
+
+ /// Returns the ID of the generator for this module.
+ uint32_t generator() const { return generator_; }
+
+ /// Sets the SPIR-V version of this module.
+ void setVersion(uint32_t ver) { version_ = ver; }
+
+ /// Gets the SPIR-V version of this module.
+ uint32_t version() const { return version_; }
+
/// Forward declares the id in the module
spv_result_t ForwardDeclareId(uint32_t id);
@@ -523,6 +535,12 @@ class ValidationState_t {
const uint32_t* words_;
const size_t num_words_;
+ /// The generator of the SPIR-V.
+ uint32_t generator_ = 0;
+
+ /// The version of the SPIR-V.
+ uint32_t version_ = 0;
+
/// The total number of instructions in the binary.
size_t total_instructions_ = 0;
/// The total number of functions in the binary.