summaryrefslogtreecommitdiff
path: root/source/comp
diff options
context:
space:
mode:
authorAndrey Tuganov <andreyt@google.com>2017-05-31 13:07:51 -0400
committerDavid Neto <dneto@google.com>2017-06-30 12:22:48 -0400
commit73e8dac5b925f979d5c9e94770fb104fcc539ac2 (patch)
treef8d3ed449e0bdf8e43c4b6a8ff2f03ce49c3c0f6 /source/comp
parent8d3882a40807f0df5d7b5c09239f47f6c209eb46 (diff)
Added compression tool tools/spirv-markv. Work in progress.
Command line application is located at tools/spirv-markv. The API is at include/spirv-tools/markv.h. At the moment only very basic compression is implemented, mostly varint. The scope of supported SPIR-V opcodes is also limited. Uses a simple move-to-front implementation instead of encoding mapped ids. Work in progress: (1) does not cover all of SPIR-V; (2) does not promise compatibility of compression/decompression across different versions of the code.
Diffstat (limited to 'source/comp')
-rw-r--r--source/comp/CMakeLists.txt32
-rw-r--r--source/comp/markv_codec.cpp1518
2 files changed, 1550 insertions, 0 deletions
diff --git a/source/comp/CMakeLists.txt b/source/comp/CMakeLists.txt
new file mode 100644
index 00000000..1cf312fb
--- /dev/null
+++ b/source/comp/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Copyright (c) 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Build the MARK-V codec library from its single source file.
+add_library(SPIRV-Tools-comp markv_codec.cpp)
+
+# Apply the project-wide compile options helper to the new target
+# (the helper is defined elsewhere in the build tree).
+spvtools_default_compile_options(SPIRV-Tools-comp)
+# The public include directories are propagated to targets that link against
+# this library; the binary dir is only needed to compile the library itself.
+target_include_directories(SPIRV-Tools-comp
+ PUBLIC ${spirv-tools_SOURCE_DIR}/include
+ PUBLIC ${SPIRV_HEADER_INCLUDE_DIR}
+ PRIVATE ${spirv-tools_BINARY_DIR}
+)
+
+# The codec depends on the library named by ${SPIRV_TOOLS}.
+target_link_libraries(SPIRV-Tools-comp
+ PUBLIC ${SPIRV_TOOLS})
+
+# Group the target with the other libraries in IDE project views.
+set_property(TARGET SPIRV-Tools-comp PROPERTY FOLDER "SPIRV-Tools libraries")
+
+# Install rules: runtime artifacts go to bin, libraries and archives to lib.
+install(TARGETS SPIRV-Tools-comp
+ RUNTIME DESTINATION bin
+ LIBRARY DESTINATION lib
+ ARCHIVE DESTINATION lib)
diff --git a/source/comp/markv_codec.cpp b/source/comp/markv_codec.cpp
new file mode 100644
index 00000000..f621d3b3
--- /dev/null
+++ b/source/comp/markv_codec.cpp
@@ -0,0 +1,1518 @@
+// Copyright (c) 2017 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Contains
+// - SPIR-V to MARK-V encoder
+// - MARK-V to SPIR-V decoder
+//
+// MARK-V is a compression format for SPIR-V binaries. It strips away
+// non-essential information (such as result ids which can be regenerated) and
+// uses various bit reduction techniques to reduce the size of the binary.
+//
+// MarkvModel is a flatbuffers object containing a set of rules defining how
+// compression/decompression is done (coding schemes, dictionaries).
+
+#include <algorithm>
+#include <cassert>
+#include <cstring>
+#include <functional>
+#include <iostream>
+#include <list>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "binary.h"
+#include "diagnostic.h"
+#include "enum_string_mapping.h"
+#include "extensions.h"
+#include "instruction.h"
+#include "opcode.h"
+#include "operand.h"
+#include "spirv-tools/libspirv.h"
+#include "spirv-tools/markv.h"
+#include "spirv_endian.h"
+#include "spirv_validator_options.h"
+#include "util/bit_stream.h"
+#include "util/parse_number.h"
+#include "validate.h"
+#include "val/instruction.h"
+#include "val/validation_state.h"
+
+using libspirv::Instruction;
+using libspirv::ValidationState_t;
+using spvtools::ValidateInstructionAndUpdateValidationState;
+using spvutils::BitReaderWord64;
+using spvutils::BitWriterWord64;
+
+// Options for the MARK-V encoder. The struct currently has no members.
+struct spv_markv_encoder_options_t {
+};
+
+// Options for the MARK-V decoder. The struct currently has no members.
+struct spv_markv_decoder_options_t {
+};
+
+namespace {
+
+// Magic number written as the first word of the decoded SPIR-V module.
+const uint32_t kSpirvMagicNumber = SpvMagicNumber;
+// Magic number stored in the MARK-V header; the decoder rejects binaries
+// whose header carries a different value.
+const uint32_t kMarkvMagicNumber = 0x07230303;
+
+// Service opcodes which exist only in the MARK-V stream.
+// NOTE(review): 65536 is presumably chosen to lie above all SPIR-V opcode
+// values - confirm.
+enum {
+ kMarkvFirstOpcode = 65536,
+ // Written by the encoder in front of an instruction whose result id was
+ // already seen earlier in the module (forward declared).
+ kMarkvOpNextInstructionEncodesResultId = 65536,
+};
+
+// Number of spaces the encoder inserts between parts of a comment line.
+const size_t kCommentNumWhitespaces = 2;
+
+// TODO(atgoo@github.com): This is a placeholder for an autogenerated flatbuffer
+// containing MARK-V model for a specific dataset.
+class MarkvModel {
+ public:
+  // Chunk lengths used for variable width coding of instruction level
+  // fields: the opcode, the operand count and move-to-front id indices.
+  size_t opcode_chunk_length() const {
+    return 7;
+  }
+  size_t num_operands_chunk_length() const {
+    return 3;
+  }
+  size_t id_index_chunk_length() const {
+    return 3;
+  }
+
+  // Coding parameters for 16-bit literals.
+  size_t u16_chunk_length() const {
+    return 4;
+  }
+  size_t s16_chunk_length() const {
+    return 4;
+  }
+  size_t s16_block_exponent() const {
+    return 6;
+  }
+
+  // Coding parameters for 32-bit literals.
+  size_t u32_chunk_length() const {
+    return 8;
+  }
+  size_t s32_chunk_length() const {
+    return 8;
+  }
+  size_t s32_block_exponent() const {
+    return 10;
+  }
+
+  // Coding parameters for 64-bit literals.
+  size_t u64_chunk_length() const {
+    return 8;
+  }
+  size_t s64_chunk_length() const {
+    return 8;
+  }
+  size_t s64_block_exponent() const {
+    return 10;
+  }
+};
+
+// Returns a pointer to the default model. The object is a function-local
+// static, so the pointer stays valid and must not be deleted by the caller.
+const MarkvModel* GetDefaultModel() {
+  static MarkvModel default_model;
+  return &default_model;
+}
+
+// Returns chunk length used for variable length encoding of spirv operand
+// words. Returns zero if operand type corresponds to potentially multiple
+// words or a word which is not expected to profit from variable width encoding.
+// Chunk length is selected based on the size of expected value.
+// Most of these values will later be encoded with probability-based coding,
+// but variable width integer coding is a good quick solution.
+// TODO(atgoo@github.com): Put this in MarkvModel flatbuffer.
+size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) {
+  // Operand types are grouped by the chunk length they map to. Any type not
+  // listed here yields zero, i.e. the word is not variable width encoded.
+  switch (type) {
+    case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
+    case SPV_OPERAND_TYPE_MEMORY_MODEL:
+    case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
+    case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
+    case SPV_OPERAND_TYPE_LINKAGE_TYPE:
+    case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
+    case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
+    case SPV_OPERAND_TYPE_GROUP_OPERATION:
+    case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
+    case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO:
+      return 2;
+
+    case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
+    case SPV_OPERAND_TYPE_EXECUTION_MODEL:
+    case SPV_OPERAND_TYPE_DIMENSIONALITY:
+    case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
+    case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
+      return 3;
+
+    case SPV_OPERAND_TYPE_TYPE_ID:
+    case SPV_OPERAND_TYPE_STORAGE_CLASS:
+    case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
+    case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
+    case SPV_OPERAND_TYPE_LOOP_CONTROL:
+    case SPV_OPERAND_TYPE_IMAGE:
+    case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
+    case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+    case SPV_OPERAND_TYPE_SELECTION_CONTROL:
+      return 4;
+
+    case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+    case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER:
+    case SPV_OPERAND_TYPE_CAPABILITY:
+    case SPV_OPERAND_TYPE_EXECUTION_MODE:
+    case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
+    case SPV_OPERAND_TYPE_DECORATION:
+    case SPV_OPERAND_TYPE_BUILT_IN:
+      return 6;
+
+    case SPV_OPERAND_TYPE_RESULT_ID:
+    case SPV_OPERAND_TYPE_ID:
+    case SPV_OPERAND_TYPE_SCOPE_ID:
+    case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID:
+      return 8;
+
+    default:
+      return 0;
+  }
+}
+
+// Returns true if the opcode has a fixed number of operands. May return a
+// false negative.
+bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) {
+  // TODO(atgoo@github.com) This is not a complete list.
+  // Opcodes missing from the list are reported as not fixed, in which case
+  // the encoder writes the operand count explicitly.
+  switch (opcode) {
+    // Miscellaneous, debug and module setup.
+    case SpvOpNop:
+    case SpvOpName:
+    case SpvOpUndef:
+    case SpvOpSizeOf:
+    case SpvOpLine:
+    case SpvOpNoLine:
+    case SpvOpDecorationGroup:
+    case SpvOpExtension:
+    case SpvOpExtInstImport:
+    case SpvOpMemoryModel:
+    case SpvOpCapability:
+    // Type declarations.
+    case SpvOpTypeVoid:
+    case SpvOpTypeBool:
+    case SpvOpTypeInt:
+    case SpvOpTypeFloat:
+    case SpvOpTypeVector:
+    case SpvOpTypeMatrix:
+    case SpvOpTypeSampler:
+    case SpvOpTypeSampledImage:
+    case SpvOpTypeArray:
+    case SpvOpTypePointer:
+    // Boolean constants.
+    case SpvOpConstantTrue:
+    case SpvOpConstantFalse:
+    // Functions and control flow.
+    case SpvOpLabel:
+    case SpvOpBranch:
+    case SpvOpFunction:
+    case SpvOpFunctionParameter:
+    case SpvOpFunctionEnd:
+    // Conversion and copy.
+    case SpvOpBitcast:
+    case SpvOpCopyObject:
+    case SpvOpTranspose:
+    // Arithmetic.
+    case SpvOpSNegate:
+    case SpvOpFNegate:
+    case SpvOpIAdd:
+    case SpvOpFAdd:
+    case SpvOpISub:
+    case SpvOpFSub:
+    case SpvOpIMul:
+    case SpvOpFMul:
+    case SpvOpUDiv:
+    case SpvOpSDiv:
+    case SpvOpFDiv:
+    case SpvOpUMod:
+    case SpvOpSRem:
+    case SpvOpSMod:
+    case SpvOpFRem:
+    case SpvOpFMod:
+    case SpvOpVectorTimesScalar:
+    case SpvOpMatrixTimesScalar:
+    case SpvOpVectorTimesMatrix:
+    case SpvOpMatrixTimesVector:
+    case SpvOpMatrixTimesMatrix:
+    case SpvOpOuterProduct:
+    case SpvOpDot:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Returns the distance in bits from |bit_pos| to the next multiple of 8.
+// Returns zero if |bit_pos| is already a multiple of 8.
+size_t GetNumBitsToNextByte(size_t bit_pos) {
+  const size_t bits_into_byte = bit_pos % 8;
+  return bits_into_byte == 0 ? 0 : 8 - bits_into_byte;
+}
+
+// Returns true if the stream should be padded to a byte boundary at
+// |bit_pos|. Currently this is the case whenever |bit_pos| is not already
+// byte aligned; no upper limit on the amount of padding is applied.
+bool ShouldByteBreak(size_t bit_pos) {
+  return bit_pos % 8 != 0;
+}
+
+// Returns the current MARK-V version as a single word: the major version
+// occupies bits 16 and up, the minor version the low bits.
+uint32_t GetMarkvVersion() {
+  const uint32_t version_major = 1;
+  const uint32_t version_minor = 0;
+  return (version_major << 16) | version_minor;
+}
+
+// Collects human readable comments in an in-memory text buffer.
+//
+// Consecutive bit sequences are joined with '-', which produces output like
+// 1100-0 1110-1-1100-1-1111-0 110-0. Any other kind of append (text,
+// whitespace, new line) interrupts the run, so the next bit sequence starts
+// without a delimiter.
+class CommentLogger {
+ public:
+  // Appends |str| as is.
+  void AppendText(const std::string& str) {
+    Write(str);
+    delimiter_pending_ = false;
+  }
+
+  // Appends |str| followed by a line break.
+  void AppendTextNewLine(const std::string& str) {
+    Write(str);
+    Write("\n");
+    delimiter_pending_ = false;
+  }
+
+  // Appends bit sequence |str|, preceded by '-' if the previous append was
+  // also a bit sequence.
+  void AppendBitSequence(const std::string& str) {
+    if (delimiter_pending_) {
+      Write("-");
+    }
+    Write(str);
+    delimiter_pending_ = true;
+  }
+
+  // Appends |num| space characters.
+  void AppendWhitespaces(size_t num) {
+    Write(std::string(num, ' '));
+    delimiter_pending_ = false;
+  }
+
+  // Appends a line break.
+  void NewLine() {
+    Write("\n");
+    delimiter_pending_ = false;
+  }
+
+  // Returns everything appended so far.
+  std::string GetText() const {
+    return text_.str();
+  }
+
+ private:
+  // Single point through which all text enters the buffer.
+  void Write(const std::string& str) {
+    text_ << str;
+  }
+
+  std::stringstream text_;
+
+  // True if the last append was a bit sequence, i.e. a delimiter has to be
+  // written before the next one.
+  bool delimiter_pending_ = false;
+};
+
+// Allocates a new spv_text holding a null terminated copy of |str|.
+// Ownership passes to the caller, who has to release the object with
+// spvTextDestroy.
+spv_text CreateSpvText(const std::string& str) {
+  const size_t length = str.length();
+
+  spv_text text = new spv_text_t();
+  assert(text);
+
+  char* buffer = new char[length + 1];
+  assert(buffer);
+  // strncpy stops copying at the first null character in the source and
+  // fills the rest of the |length| bytes with zeros.
+  std::strncpy(buffer, str.c_str(), length);
+  buffer[length] = '\0';
+
+  text->str = buffer;
+  text->length = length;
+  return text;
+}
+
+// Common base of the MARK-V encoder and decoder. It holds the state both
+// sides need:
+// - validator options and the validation state of the module,
+// - the SPIR-V grammar,
+// - the MARK-V header and model,
+// - the move-to-front list of ids.
+class MarkvCodecBase {
+ public:
+  MarkvCodecBase() = delete;
+
+  // Releases the validator options handed over at construction.
+  virtual ~MarkvCodecBase() {
+    spvValidatorOptionsDestroy(validator_options_);
+  }
+
+  // Replaces the model used for coding. The model is not owned.
+  void SetModel(const MarkvModel* model) { model_ = model; }
+
+ protected:
+  // Fixed size header placed in front of the encoded instructions.
+  struct MarkvHeader {
+    MarkvHeader()
+        : magic_number(kMarkvMagicNumber),
+          markv_version(GetMarkvVersion()),
+          markv_model(0),
+          markv_length_in_bits(0),
+          spirv_version(0),
+          spirv_generator(0) {}
+
+    uint32_t magic_number;
+    uint32_t markv_version;
+    // Magic number to identify or verify MarkvModel used for encoding.
+    uint32_t markv_model;
+    uint32_t markv_length_in_bits;
+    uint32_t spirv_version;
+    uint32_t spirv_generator;
+  };
+
+  // Takes ownership of |validator_options|; they are destroyed together with
+  // the codec. The model defaults to GetDefaultModel().
+  explicit MarkvCodecBase(spv_const_context context,
+                          spv_validator_options validator_options)
+      : validator_options_(validator_options),
+        vstate_(context, validator_options_),
+        grammar_(context),
+        model_(GetDefaultModel()) {}
+
+  // Runs the validator on |inst| and folds the instruction into the
+  // validation state of the module. Returns the validator's result.
+  spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) {
+    return ValidateInstructionAndUpdateValidationState(&vstate_, &inst);
+  }
+
+  // Returns the instruction most recently processed by the validator.
+  const Instruction& GetCurrentInstruction() const {
+    return vstate_.ordered_instructions().back();
+  }
+
+  spv_validator_options validator_options_;
+  ValidationState_t vstate_;
+  const libspirv::AssemblyGrammar grammar_;
+  MarkvHeader header_;
+  const MarkvModel* model_;
+
+  // Move-to-front list of all ids.
+  // TODO(atgoo@github.com) Consider a better move-to-front implementation.
+  std::list<uint32_t> move_to_front_ids_;
+};
+
+// SPIR-V to MARK-V encoder.
+//
+// EncodeHeader() and EncodeInstruction() are intended to be used as the
+// callbacks of spvBinaryParse. Encoded instructions are appended to an
+// internal bit stream; once the last instruction has been encoded the
+// complete MARK-V binary is obtained with GetMarkvBinary().
+//
+// Internal state is kept by the SPIR-V validator, so the input module has to
+// pass validation. Calling CreateCommentsLogger() makes the encoder record
+// comments describing how the encoding was done; they can be read back with
+// GetComments().
+class MarkvEncoder : public MarkvCodecBase {
+ public:
+  MarkvEncoder(spv_const_context context,
+               spv_const_markv_encoder_options options)
+      : MarkvCodecBase(context, GetValidatorOptions(options)),
+        options_(options) {
+    (void) options_;
+  }
+
+  // Copies the relevant fields of the SPIR-V header into the MARK-V header
+  // and passes the id bound on to the validation state.
+  spv_result_t EncodeHeader(spv_endianness_t /* endian */,
+                            uint32_t /* magic */, uint32_t version,
+                            uint32_t generator, uint32_t id_bound,
+                            uint32_t /* schema */) {
+    header_.spirv_version = version;
+    header_.spirv_generator = generator;
+    vstate_.setIdBound(id_bound);
+    return SPV_SUCCESS;
+  }
+
+  // Encodes SPIR-V instruction to MARK-V and writes to bit stream.
+  // Operation can fail if the instruction fails to pass the validator or if
+  // the encoder stumbles on something unexpected.
+  spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst);
+
+  // Builds the final binary: the MARK-V header followed by the bytes of the
+  // bit stream. The result is owned by the caller and has to be released
+  // with spvMarkvBinaryDestroy().
+  spv_markv_binary GetMarkvBinary() {
+    const size_t header_size = sizeof(header_);
+    const size_t payload_size = writer_.GetDataSizeBytes();
+    const size_t total_size = header_size + payload_size;
+
+    // The stated length covers the header as well as the encoded bits.
+    header_.markv_length_in_bits =
+        static_cast<uint32_t>(header_size * 8 + writer_.GetNumBits());
+
+    spv_markv_binary result = new spv_markv_binary_t();
+    result->data = new uint8_t[total_size];
+    result->length = total_size;
+
+    assert(writer_.GetData());
+    std::memcpy(result->data, &header_, header_size);
+    std::memcpy(result->data + header_size, writer_.GetData(), payload_size);
+    return result;
+  }
+
+  // Installs a logger and hooks it up to the bit writer, so that every bit
+  // sequence reported by the writer ends up in the comments.
+  void CreateCommentsLogger() {
+    logger_.reset(new CommentLogger());
+    writer_.SetCallback([this](const std::string& bits) {
+      logger_->AppendBitSequence(bits);
+    });
+  }
+
+  // Optionally supplies the disassembly to be interleaved with the comments.
+  // |disassembly| should contain all instructions of the module separated by
+  // \n, and no header.
+  void SetDisassembly(std::string&& disassembly) {
+    disassembly_.reset(new std::stringstream(std::move(disassembly)));
+  }
+
+  // Takes the next line from the disassembly and logs it. Does nothing
+  // unless both the logger and the disassembly are present.
+  void LogDisassemblyInstruction() {
+    if (!logger_ || !disassembly_) return;
+
+    std::string line;
+    std::getline(*disassembly_, line);
+    logger_->AppendTextNewLine(line);
+  }
+
+  // Returns the comments collected so far, or an empty string if no logger
+  // was created.
+  std::string GetComments() const {
+    return logger_ ? logger_->GetText() : std::string();
+  }
+
+ private:
+  // Creates and returns validator options. Return value owned by the caller.
+  static spv_validator_options GetValidatorOptions(
+      spv_const_markv_encoder_options) {
+    return spvValidatorOptionsCreate();
+  }
+
+  // Writes one operand word to the bit stream. Depending on |type| the word
+  // is either variable width encoded or written as is.
+  void EncodeOperandWord(spv_operand_type_t type, uint32_t word) {
+    const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
+    if (chunk_length == 0) {
+      writer_.WriteUnencoded(word);
+      return;
+    }
+    writer_.WriteVariableWidthU32(word, chunk_length);
+  }
+
+  // Returns the position of |id| in the move-to-front list and moves the id
+  // to the front. An id seen for the first time is pushed to the front and
+  // reported with the index of the last list position.
+  // The index is returned as uint16_t; the list size is narrowed by a cast.
+  // NOTE(review): the original comment justifies uint16_t with a limit of
+  // 65535 instructions per module - confirm that this bound holds for ids.
+  uint16_t GetIdIndex(uint32_t id) {
+    if (all_known_ids_.insert(id).second) {
+      // First occurrence of the id.
+      move_to_front_ids_.push_front(id);
+      return static_cast<uint16_t>(move_to_front_ids_.size() - 1);
+    }
+
+    uint16_t index = 0;
+    for (auto it = move_to_front_ids_.begin(); it != move_to_front_ids_.end();
+         ++it, ++index) {
+      if (*it != id) continue;
+
+      if (index != 0) {
+        // Relink the node at the front of the list.
+        move_to_front_ids_.splice(move_to_front_ids_.begin(),
+                                  move_to_front_ids_, it);
+      }
+      return index;
+    }
+
+    assert(0 && "Id not found in move_to_front_ids_");
+    return 0;
+  }
+
+  // Pads the bit stream with zeros up to the next byte boundary if
+  // ShouldByteBreak() asks for it.
+  void AddByteBreakIfAgreed() {
+    const size_t num_bits = writer_.GetNumBits();
+    if (!ShouldByteBreak(num_bits)) return;
+
+    if (logger_) {
+      logger_->AppendWhitespaces(kCommentNumWhitespaces);
+      logger_->AppendText("ByteBreak:");
+    }
+
+    writer_.WriteBits(0, GetNumBitsToNextByte(num_bits));
+  }
+
+  // Encodes a literal number operand and writes it to the bit stream.
+  void EncodeLiteralNumber(const Instruction& instruction,
+                           const spv_parsed_operand_t& operand);
+
+  spv_const_markv_encoder_options options_;
+
+  // Bit stream where encoded instructions are written.
+  BitWriterWord64 writer_;
+
+  // If not nullptr, encoder will write comments.
+  std::unique_ptr<CommentLogger> logger_;
+
+  // If not nullptr, disassembled instruction lines will be written to
+  // comments. Format: \n separated instruction lines, no header.
+  std::unique_ptr<std::stringstream> disassembly_;
+
+  // All ids which were previously encountered in the module.
+  std::unordered_set<uint32_t> all_known_ids_;
+};
+
+// MARK-V to SPIR-V decoder. Reads buffers produced by MarkvEncoder.
+class MarkvDecoder : public MarkvCodecBase {
+ public:
+  MarkvDecoder(spv_const_context context,
+               const uint8_t* markv_data,
+               size_t markv_size_bytes,
+               spv_const_markv_decoder_options options)
+      : MarkvCodecBase(context, GetValidatorOptions(options)),
+        options_(options),
+        reader_(markv_data, markv_size_bytes) {
+    (void) options_;
+    // Ids are issued by the decoder itself, starting from 1.
+    vstate_.setIdBound(1);
+    parsed_operands_.reserve(25);
+  }
+
+  // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
+  // Can be called only once. Fails if the data has the wrong format or ends
+  // prematurely, or if validation fails.
+  spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
+
+ private:
+  // Describes the format of a typed literal number.
+  struct NumberType {
+    spv_number_kind_t type;
+    uint32_t bit_width;
+  };
+
+  // Creates and returns validator options. Return value owned by the caller.
+  static spv_validator_options GetValidatorOptions(
+      spv_const_markv_decoder_options) {
+    return spvValidatorOptionsCreate();
+  }
+
+  // Reads one operand word from the bit stream. Depending on |type| the word
+  // is either variable width decoded or read as is. Returns false if the
+  // read fails.
+  bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) {
+    const size_t chunk_length = GetOperandVariableWidthChunkLength(type);
+    if (chunk_length == 0) {
+      return reader_.ReadUnencoded(word);
+    }
+    return reader_.ReadVariableWidthU32(word, chunk_length);
+  }
+
+  // Returns the id stored at position |index| of the move-to-front list and
+  // moves it to the front. An index past the end of the list stands for an
+  // id not seen before: a new id is issued from the current id bound.
+  uint32_t GetIdAndMoveToFront(uint16_t index) {
+    if (index >= move_to_front_ids_.size()) {
+      const uint32_t new_id = vstate_.getIdBound();
+      move_to_front_ids_.push_front(new_id);
+      vstate_.setIdBound(new_id + 1);
+      return new_id;
+    }
+
+    // The id is already at the front, nothing to move.
+    if (index == 0) return move_to_front_ids_.front();
+
+    auto it = move_to_front_ids_.begin();
+    for (uint16_t remaining = index; remaining != 0; --remaining) {
+      ++it;
+    }
+
+    const uint32_t id = *it;
+    move_to_front_ids_.splice(move_to_front_ids_.begin(), move_to_front_ids_,
+                              it);
+    return id;
+  }
+
+  // Reads an id index from the bit stream and resolves it through the
+  // move-to-front list. Returns false if the read fails.
+  bool DecodeId(uint32_t* id) {
+    uint16_t index = 0;
+    const bool read_ok =
+        reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length());
+    if (read_ok) {
+      *id = GetIdAndMoveToFront(index);
+    }
+    return read_ok;
+  }
+
+  // Skips the padding up to the next byte boundary if ShouldByteBreak() asks
+  // for it. Returns false if the padding cannot be read or is not all zeros.
+  bool ReadToByteBreakIfAgreed() {
+    const size_t position = reader_.GetNumReadBits();
+    if (!ShouldByteBreak(position)) return true;
+
+    uint64_t padding = 0;
+    if (!reader_.ReadBits(&padding, GetNumBitsToNextByte(position))) {
+      return false;
+    }
+    return padding == 0;
+  }
+
+  // Reads a literal number as it is described in |operand| from the bit
+  // stream, decodes and writes it to spirv_.
+  spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
+
+  // Reads instruction from bit stream, decodes and validates it.
+  // Decoded instruction is valid until the next call of DecodeInstruction().
+  spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst);
+
+  // Reads an operand from the stream, decodes and validates it.
+  spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset,
+                             spv_parsed_instruction_t* inst,
+                             const spv_operand_type_t type,
+                             spv_operand_pattern_t* expected_operands,
+                             bool read_result_id);
+
+  // Records the numeric type for an operand according to the type information
+  // associated with the given non-zero type Id. This can fail if the type Id
+  // is not a type Id, or if the type Id does not reference a scalar numeric
+  // type. On success, returns SPV_SUCCESS and populates the num_words,
+  // number_kind, and number_bit_width fields of parsed_operand.
+  spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
+                                         uint32_t type_id);
+
+  // Records the number type for the given instruction, if that
+  // instruction generates a type. For types that aren't scalar numbers,
+  // record something with number kind SPV_NUMBER_NONE.
+  void RecordNumberType(const spv_parsed_instruction_t& inst);
+
+  spv_const_markv_decoder_options options_;
+
+  // Temporary sink where decoded SPIR-V words are written. Once it contains
+  // the entire module, the container is moved and returned.
+  std::vector<uint32_t> spirv_;
+
+  // Bit stream containing encoded data.
+  BitReaderWord64 reader_;
+
+  // Temporary storage for operands of the currently parsed instruction.
+  // Valid until next DecodeInstruction call.
+  std::vector<spv_parsed_operand_t> parsed_operands_;
+
+  // Maps a result ID to its type ID. By convention:
+  // - a result ID that is a type definition maps to itself.
+  // - a result ID without a type maps to 0. (E.g. for OpLabel)
+  std::unordered_map<uint32_t, uint32_t> id_to_type_id_;
+  // Maps a type ID to its number type description.
+  std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
+};
+
+// Writes the typed literal number described by |operand| to the bit stream.
+// The coding depends on the bit width (16, 32 or 64) and on the number kind:
+// integers are variable width encoded, floating point values are written
+// unencoded. A 64-bit value is assembled from two consecutive words, the
+// first word being the low half.
+void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction,
+                                       const spv_parsed_operand_t& operand) {
+  switch (operand.number_bit_width) {
+    case 32: {
+      const uint32_t word = instruction.word(operand.offset);
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT:
+          writer_.WriteVariableWidthU32(word, model_->u32_chunk_length());
+          break;
+        case SPV_NUMBER_SIGNED_INT: {
+          // Reinterpret the bits of the word as a signed value.
+          int32_t val = 0;
+          std::memcpy(&val, &word, 4);
+          writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(),
+                                        model_->s32_block_exponent());
+          break;
+        }
+        case SPV_NUMBER_FLOATING:
+          writer_.WriteUnencoded(word);
+          break;
+        default:
+          assert(0);
+          break;
+      }
+      break;
+    }
+    case 16: {
+      // Only the low 16 bits of the word are kept.
+      const uint16_t word =
+          static_cast<uint16_t>(instruction.word(operand.offset));
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT:
+          writer_.WriteVariableWidthU16(word, model_->u16_chunk_length());
+          break;
+        case SPV_NUMBER_SIGNED_INT: {
+          int16_t val = 0;
+          std::memcpy(&val, &word, 2);
+          writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(),
+                                        model_->s16_block_exponent());
+          break;
+        }
+        case SPV_NUMBER_FLOATING:
+          // TODO(atgoo@github.com) Write only 16 bits.
+          writer_.WriteUnencoded(word);
+          break;
+        default:
+          assert(0);
+          break;
+      }
+      break;
+    }
+    default: {
+      assert(operand.number_bit_width == 64);
+      const uint64_t low = uint64_t(instruction.word(operand.offset));
+      const uint64_t high = uint64_t(instruction.word(operand.offset + 1));
+      const uint64_t word = low | (high << 32);
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT:
+          writer_.WriteVariableWidthU64(word, model_->u64_chunk_length());
+          break;
+        case SPV_NUMBER_SIGNED_INT: {
+          int64_t val = 0;
+          std::memcpy(&val, &word, 8);
+          writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(),
+                                        model_->s64_block_exponent());
+          break;
+        }
+        case SPV_NUMBER_FLOATING:
+          writer_.WriteUnencoded(word);
+          break;
+        default:
+          assert(0);
+          break;
+      }
+      break;
+    }
+  }
+}
+
+// Validates |inst|, then encodes it to the bit stream in this order:
+// 1. a service opcode, only if the result id was seen before (forward
+//    declared),
+// 2. the opcode,
+// 3. the number of operands, only for opcodes not known to have a fixed
+//    number of operands,
+// 4. the operands; a result id which was not forward declared is registered
+//    in the move-to-front list but not written,
+// 5. padding to the next byte boundary, if agreed.
+// Returns the validator's error if validation fails, and
+// SPV_ERROR_INVALID_BINARY if a literal string has no terminating null.
+spv_result_t MarkvEncoder::EncodeInstruction(
+    const spv_parsed_instruction_t& inst) {
+  const spv_result_t validation_result = UpdateValidationState(inst);
+  if (validation_result != SPV_SUCCESS) return validation_result;
+
+  // A result id which is already known must have been referenced by an
+  // earlier instruction. Signal this to the decoder with a service opcode.
+  const bool result_id_was_forward_declared =
+      all_known_ids_.count(inst.result_id) != 0;
+  if (result_id_was_forward_declared) {
+    writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId,
+                                  model_->opcode_chunk_length());
+  }
+
+  const Instruction& instruction = GetCurrentInstruction();
+  const auto& operands = instruction.operands();
+
+  LogDisassemblyInstruction();
+
+  // Opcode.
+  writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length());
+
+  // Operand count, unless it is implied by the opcode.
+  if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) {
+    if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
+
+    writer_.WriteVariableWidthU16(inst.num_operands,
+                                  model_->num_operands_chunk_length());
+  }
+
+  // Operands.
+  for (const auto& operand : operands) {
+    const bool is_implicit_result_id =
+        operand.type == SPV_OPERAND_TYPE_RESULT_ID &&
+        !result_id_was_forward_declared;
+    if (is_implicit_result_id) {
+      // Register the id, but don't encode it.
+      GetIdIndex(instruction.word(operand.offset));
+      continue;
+    }
+
+    if (logger_) logger_->AppendWhitespaces(kCommentNumWhitespaces);
+
+    if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) {
+      EncodeLiteralNumber(instruction, operand);
+      continue;
+    }
+
+    if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) {
+      const char* src =
+          reinterpret_cast<const char*>(&instruction.words()[operand.offset]);
+      const size_t max_length = operand.num_words * 4;
+      const size_t length = spv_strnlen_s(src, max_length);
+      if (length == max_length)
+        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+               << "Failed to find terminal character of literal string";
+      // Write the characters including the terminating null.
+      for (size_t i = 0; i <= length; ++i) writer_.WriteUnencoded(src[i]);
+      continue;
+    }
+
+    if (spvIsIdType(operand.type)) {
+      const uint16_t id_index = GetIdIndex(instruction.word(operand.offset));
+      writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length());
+      continue;
+    }
+
+    // Every other operand is written word by word.
+    for (int i = 0; i < operand.num_words; ++i) {
+      EncodeOperandWord(operand.type, instruction.word(operand.offset + i));
+    }
+  }
+
+  AddByteBreakIfAgreed();
+
+  if (logger_) {
+    logger_->NewLine();
+    logger_->NewLine();
+  }
+
+  return SPV_SUCCESS;
+}
+
+// Reads the typed literal number described by |operand| from the bit stream
+// and appends it to spirv_: one word for 16 and 32 bit numbers, two words
+// (low half first) for 64 bit numbers. Returns SPV_ERROR_INVALID_BINARY if
+// the read fails.
+spv_result_t MarkvDecoder::DecodeLiteralNumber(
+    const spv_parsed_operand_t& operand) {
+  switch (operand.number_bit_width) {
+    case 32: {
+      uint32_t word = 0;
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT:
+          if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal U32";
+          break;
+        case SPV_NUMBER_SIGNED_INT: {
+          int32_t val = 0;
+          if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(),
+                                            model_->s32_block_exponent()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal S32";
+          std::memcpy(&word, &val, 4);
+          break;
+        }
+        case SPV_NUMBER_FLOATING:
+          if (!reader_.ReadUnencoded(&word))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal F32";
+          break;
+        default:
+          assert(0);
+          break;
+      }
+      spirv_.push_back(word);
+      break;
+    }
+    case 16: {
+      uint32_t word = 0;
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT: {
+          uint16_t val = 0;
+          if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal U16";
+          word = val;
+          break;
+        }
+        case SPV_NUMBER_SIGNED_INT: {
+          int16_t val = 0;
+          if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(),
+                                            model_->s16_block_exponent()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal S16";
+          // Int16 is stored as int32 in SPIR-V, not as bits.
+          const int32_t val32 = val;
+          std::memcpy(&word, &val32, 4);
+          break;
+        }
+        case SPV_NUMBER_FLOATING: {
+          uint16_t word16 = 0;
+          if (!reader_.ReadUnencoded(&word16))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal F16";
+          word = word16;
+          break;
+        }
+        default:
+          assert(0);
+          break;
+      }
+      spirv_.push_back(word);
+      break;
+    }
+    default: {
+      assert(operand.number_bit_width == 64);
+      uint64_t word = 0;
+      switch (operand.number_kind) {
+        case SPV_NUMBER_UNSIGNED_INT:
+          if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal U64";
+          break;
+        case SPV_NUMBER_SIGNED_INT: {
+          int64_t val = 0;
+          if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
+                                            model_->s64_block_exponent()))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal S64";
+          std::memcpy(&word, &val, 8);
+          break;
+        }
+        case SPV_NUMBER_FLOATING:
+          if (!reader_.ReadUnencoded(&word))
+            return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+                   << "Failed to read literal F64";
+          break;
+        default:
+          assert(0);
+          break;
+      }
+      // Low half first.
+      spirv_.push_back(static_cast<uint32_t>(word));
+      spirv_.push_back(static_cast<uint32_t>(word >> 32));
+      break;
+    }
+  }
+  return SPV_SUCCESS;
+}
+
+// Decodes the whole MARK-V buffer into a SPIR-V module.
+//
+// Reads and checks the MARK-V header, writes the five word SPIR-V header,
+// then decodes and validates instructions until the bit length stated in the
+// header is reached. On success the decoded words are moved into
+// |spirv_binary| and SPV_SUCCESS is returned.
+//
+// Returns SPV_ERROR_INVALID_BINARY if the header cannot be read, the magic
+// number or the version does not match, or the stated bit length disagrees
+// with the amount of data actually consumed. Errors from instruction decoding
+// and from the validator are returned unchanged.
+//
+// The input is untrusted data, so all header checks are regular runtime
+// checks which report an error. No assert() is used on header fields: it
+// would abort debug builds on a malformed binary instead of returning an
+// error code.
+spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
+  const bool header_read_success =
+      reader_.ReadUnencoded(&header_.magic_number) &&
+      reader_.ReadUnencoded(&header_.markv_version) &&
+      reader_.ReadUnencoded(&header_.markv_model) &&
+      reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
+      reader_.ReadUnencoded(&header_.spirv_version) &&
+      reader_.ReadUnencoded(&header_.spirv_generator);
+
+  if (!header_read_success)
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+           << "Unable to read MARK-V header";
+
+  if (header_.magic_number != kMarkvMagicNumber)
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+           << "MARK-V binary has incorrect magic number";
+
+  // TODO(atgoo@github.com): Print version strings.
+  if (header_.markv_version != GetMarkvVersion())
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+           << "MARK-V binary and the codec have different versions";
+
+  // A stated length of zero (or any length shorter than the header) needs no
+  // separate check: the loop below does not run and the length comparison
+  // after it reports the error.
+
+  // NOTE(review): the reservation is sized from an untrusted header field;
+  // consider bounding it by the size of the input buffer.
+  spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
+
+  // SPIR-V header. Word 3 (id bound) is filled in after decoding; word 4
+  // stays zero.
+  spirv_.resize(5, 0);
+  spirv_[0] = kSpirvMagicNumber;
+  spirv_[1] = header_.spirv_version;
+  spirv_[2] = header_.spirv_generator;
+
+  while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
+    spv_parsed_instruction_t inst = {};
+    const spv_result_t decode_result = DecodeInstruction(&inst);
+    if (decode_result != SPV_SUCCESS)
+      return decode_result;
+
+    const spv_result_t validation_result = UpdateValidationState(inst);
+    if (validation_result != SPV_SUCCESS)
+      return validation_result;
+  }
+
+  // Decoding has to stop exactly at the stated length and whatever is left
+  // in the buffer may only be zero padding.
+  if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
+      !reader_.OnlyZeroesLeft()) {
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+           << "MARK-V binary has wrong stated bit length "
+           << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
+  }
+
+  // Decoding of the module is finished, validation state should have correct
+  // id bound.
+  spirv_[3] = vstate_.getIdBound();
+
+  *spirv_binary = std::move(spirv_);
+  return SPV_SUCCESS;
+}
+
+// Decodes one operand of kind |type|, appends its words to spirv_, records it
+// in parsed_operands_ and may prepend follow-up operand types to
+// |expected_operands|. Returns SPV_SUCCESS or a diagnostic error code.
+// TODO(atgoo@github.com): Borrows heavily from Parser::parseOperand; consider
+// coupling the two once the MARK-V codec is more mature (kept apart for now).
+spv_result_t MarkvDecoder::DecodeOperand(
+ size_t instruction_offset, size_t operand_offset,
+ spv_parsed_instruction_t* inst, const spv_operand_type_t type,
+ spv_operand_pattern_t* expected_operands,
+ bool read_result_id) {
+ const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
+ spv_parsed_operand_t parsed_operand;
+ memset(&parsed_operand, 0, sizeof(parsed_operand));
+
+ assert((operand_offset >> 16) == 0);  // Offset must fit the 16-bit field.
+ parsed_operand.offset = static_cast<uint16_t>(operand_offset);
+ parsed_operand.type = type;
+
+ // Set default values, may be updated later.
+ parsed_operand.number_kind = SPV_NUMBER_NONE;
+ parsed_operand.number_bit_width = 0;
+
+ const size_t first_word_index = spirv_.size();  // Used below to compute num_words.
+
+ switch (type) {
+ case SPV_OPERAND_TYPE_TYPE_ID: {
+ if (!DecodeId(&inst->type_id)) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read type_id";
+ }
+
+ if (inst->type_id == 0)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0";
+
+ spirv_.push_back(inst->type_id);
+ vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1));  // Keep the bound above this id.
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_RESULT_ID: {
+ if (read_result_id) {
+ if (!DecodeId(&inst->result_id))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read result_id";
+ } else {  // No id in the stream: take the current id bound as the new id.
+ inst->result_id = vstate_.getIdBound();
+ vstate_.setIdBound(inst->result_id + 1);
+ move_to_front_ids_.push_front(inst->result_id);
+ }
+
+ spirv_.push_back(inst->result_id);
+
+ // Save the result ID to type ID mapping.
+ // In the grammar, type ID always appears before result ID.
+ // A regular value maps to its type. Some instructions (e.g. OpLabel)
+ // have no type Id, and will map to 0. The result Id for a
+ // type-generating instruction (e.g. OpTypeInt) maps to itself.
+ auto insertion_result = id_to_type_id_.emplace(
+ inst->result_id,
+ spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id);
+ if(!insertion_result.second) {  // The id was already present in the map.
+ return vstate_.diag(SPV_ERROR_INVALID_ID)
+ << "Unexpected behavior: id->type_id pair was already registered";
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_ID:
+ case SPV_OPERAND_TYPE_OPTIONAL_ID:
+ case SPV_OPERAND_TYPE_SCOPE_ID:
+ case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
+ uint32_t id = 0;
+ if (!DecodeId(&id))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id";
+
+ if (id == 0)
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
+
+ spirv_.push_back(id);
+ vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1));  // Keep the bound above this id.
+
+ if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
+
+ parsed_operand.type = SPV_OPERAND_TYPE_ID;  // Optional id is recorded as a plain id.
+
+ if (opcode == SpvOpExtInst && parsed_operand.offset == 3) {
+ // TODO(atgoo@github.com) Work in progress.
+ assert(0 && "Not implemented");
+ }
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
+ // TODO(atgoo@github.com) Work in progress.
+ assert(0 && "Not implemented");
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_LITERAL_INTEGER:
+ case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
+ // These are regular single-word literal integer operands.
+ // Post-parsing validation should check the range of the parsed value.
+ parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
+ // It turns out they are always unsigned integers!
+ parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT;
+ parsed_operand.number_bit_width = 32;
+
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal integer";
+
+ spirv_.push_back(word);
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
+ case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER:
+ parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
+ if (opcode == SpvOpSwitch) {
+ // The literal operands have the same type as the value
+ // referenced by the selector Id.
+ const uint32_t selector_id = spirv_.at(instruction_offset + 1);  // First word after the opcode word.
+ const auto type_id_iter = id_to_type_id_.find(selector_id);
+ if (type_id_iter == id_to_type_id_.end() ||
+ type_id_iter->second == 0) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " has no type";
+ }
+ uint32_t type_id = type_id_iter->second;
+
+ if (selector_id == type_id) {
+ // Recall that by convention, a result ID that is a type definition
+ // maps to itself.
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " is a type, not a value";
+ }
+ if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id))
+ return error;
+ if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT &&
+ parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid OpSwitch: selector id " << selector_id
+ << " is not a scalar integer";
+ }
+ } else {
+ assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
+ // The literal number type is determined by the type Id for the
+ // constant.
+ assert(inst->type_id);
+ if (auto error =
+ SetNumericTypeInfoForType(&parsed_operand, inst->type_id))
+ return error;
+ }
+
+ if (auto error = DecodeLiteralNumber(parsed_operand))
+ return error;
+
+ break;
+
+ case SPV_OPERAND_TYPE_LITERAL_STRING:
+ case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
+ parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING;
+ std::vector<char> str;
+ // The loop is expected to terminate once we encounter '\0' or exhaust
+ // the bit stream.
+ while (true) {
+ char ch = 0;
+ if (!reader_.ReadUnencoded(&ch))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read literal string";
+
+ str.push_back(ch);
+
+ if (ch == '\0')
+ break;
+ }
+
+ while (str.size() % 4 != 0)  // Zero-pad to a multiple of four bytes.
+ str.push_back('\0');
+
+ spirv_.resize(spirv_.size() + str.size() / 4);
+ std::memcpy(&spirv_[first_word_index], str.data(), str.size());  // Raw byte copy into the new words.
+
+ if (SpvOpExtInstImport == opcode) {
+ // TODO(atgoo@github.com) Work in progress.
+ assert(0 && "Not implemented");
+ }
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_CAPABILITY:
+ case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
+ case SPV_OPERAND_TYPE_EXECUTION_MODEL:
+ case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
+ case SPV_OPERAND_TYPE_MEMORY_MODEL:
+ case SPV_OPERAND_TYPE_EXECUTION_MODE:
+ case SPV_OPERAND_TYPE_STORAGE_CLASS:
+ case SPV_OPERAND_TYPE_DIMENSIONALITY:
+ case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
+ case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
+ case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
+ case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
+ case SPV_OPERAND_TYPE_LINKAGE_TYPE:
+ case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
+ case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
+ case SPV_OPERAND_TYPE_DECORATION:
+ case SPV_OPERAND_TYPE_BUILT_IN:
+ case SPV_OPERAND_TYPE_GROUP_OPERATION:
+ case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
+ case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
+ // A single word that is a plain enum value.
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read enum";
+
+ spirv_.push_back(word);
+
+ // Map an optional operand type to its corresponding concrete type.
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
+ parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
+
+ spv_operand_desc entry;
+ if (grammar_.lookupOperand(type, word, &entry)) {  // Non-zero result: lookup failed.
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid "
+ << spvOperandTypeStr(parsed_operand.type)
+ << " operand: " << word;
+ }
+
+ // Prepare to accept operands to this operand, if needed.
+ spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ break;
+ }
+
+ case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
+ case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
+ case SPV_OPERAND_TYPE_LOOP_CONTROL:
+ case SPV_OPERAND_TYPE_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
+ case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
+ case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
+ // This operand is a mask.
+ uint32_t word = 0;
+ if (!DecodeOperandWord(type, &word))
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Failed to read " << spvOperandTypeStr(type)
+ << " for " << spvOpcodeString(SpvOp(inst->opcode));
+
+ spirv_.push_back(word);
+
+ // Map an optional operand type to its corresponding concrete type.
+ if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
+ parsed_operand.type = SPV_OPERAND_TYPE_IMAGE;
+ else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
+ parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
+
+ // Check validity of set mask bits. Also prepare for operands for those
+ // masks if they have any. To get operand order correct, scan from
+ // MSB to LSB since we can only prepend operands to a pattern.
+ // The only case in the grammar where you have more than one mask bit
+ // having an operand is for image operands. See SPIR-V 3.14 Image
+ // Operands.
+ uint32_t remaining_word = word;
+ for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {  // Stops once every set bit is handled.
+ if (remaining_word & mask) {
+ spv_operand_desc entry;
+ if (grammar_.lookupOperand(type, mask, &entry)) {
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Invalid " << spvOperandTypeStr(parsed_operand.type)
+ << " operand: " << word << " has invalid mask component "
+ << mask;
+ }
+ remaining_word ^= mask;  // Clear the bit that was just processed.
+ spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ }
+ }
+ if (word == 0) {
+ // An all-zeroes mask *might* also be valid.
+ spv_operand_desc entry;
+ if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
+ // Prepare for its operands, if any.
+ spvPrependOperandTypes(entry->operandTypes, expected_operands);
+ }
+ }
+ break;
+ }
+ default:
+ return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+ << "Internal error: Unhandled operand type: " << type;
+ }
+
+ parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index);  // Words appended by this operand.
+
+ assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type));
+ assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type));
+
+ parsed_operands_.push_back(parsed_operand);
+
+ return SPV_SUCCESS;
+}
+
+// Decodes a single instruction from the MARK-V bit stream.
+//
+// Appends the instruction's words to spirv_ and fills |inst| to describe them.
+// |inst->words| and |inst->operands| point into spirv_ and parsed_operands_
+// and stay valid only while those containers remain unchanged.
+// Returns SPV_SUCCESS, or an error code after reporting through vstate_.diag().
+spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) {
+  parsed_operands_.clear();
+  const size_t instruction_offset = spirv_.size();
+
+  bool read_result_id = false;
+
+  // Values at or above kMarkvFirstOpcode are MARK-V specific prefixes; keep
+  // reading until an ordinary opcode shows up.
+  while (true) {
+    uint32_t word = 0;
+    if (!reader_.ReadVariableWidthU32(&word,
+                                      model_->opcode_chunk_length())) {
+      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+          << "Failed to read opcode of instruction";
+    }
+
+    if (word >= kMarkvFirstOpcode) {
+      if (word == kMarkvOpNextInstructionEncodesResultId) {
+        read_result_id = true;
+      } else {
+        return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+            << "Encountered unknown MARK-V opcode";
+      }
+    } else {
+      inst->opcode = static_cast<uint16_t>(word);
+      break;
+    }
+  }
+
+  const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
+
+  // Opcode/num_words placeholder, the word will be filled in later.
+  spirv_.push_back(0);
+
+  spv_opcode_desc opcode_desc;
+  if (grammar_.lookupOpcode(opcode, &opcode_desc)
+      != SPV_SUCCESS) {
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
+  }
+
+  spv_operand_pattern_t expected_operands(
+      opcode_desc->operandTypes,
+      opcode_desc->operandTypes + opcode_desc->numTypes);
+
+  // The operand count is stored in the stream only for opcodes whose operand
+  // count is not fixed; otherwise it is taken from the grammar pattern.
+  if (!OpcodeHasFixedNumberOfOperands(opcode)) {
+    if (!reader_.ReadVariableWidthU16(&inst->num_operands,
+                                      model_->num_operands_chunk_length()))
+      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+          << "Failed to read num_operands of instruction";
+  } else {
+    inst->num_operands = static_cast<uint16_t>(expected_operands.size());
+  }
+
+  for (size_t operand_index = 0;
+       operand_index < static_cast<size_t>(inst->num_operands);
+       ++operand_index) {
+    // num_operands may come straight from the input, so running out of
+    // expected operand types is a data error and must be reported in every
+    // build type rather than only asserted.
+    if (expected_operands.empty()) {
+      return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+          << "Instruction has more operands than its grammar allows";
+    }
+    const spv_operand_type_t type =
+        spvTakeFirstMatchableOperand(&expected_operands);
+
+    const size_t operand_offset = spirv_.size() - instruction_offset;
+
+    const spv_result_t decode_result =
+        DecodeOperand(instruction_offset, operand_offset, inst, type,
+                      &expected_operands, read_result_id);
+
+    if (decode_result != SPV_SUCCESS)
+      return decode_result;
+  }
+
+  assert(inst->num_operands == parsed_operands_.size());
+
+  // Only valid while spirv_ and parsed_operands_ remain unchanged.
+  inst->words = &spirv_[instruction_offset];
+  inst->operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
+  inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset);
+  // Now that the word count is known, overwrite the placeholder pushed above.
+  spirv_[instruction_offset] =
+      spvOpcodeMake(inst->num_words, SpvOp(inst->opcode));
+
+  // One opcode word plus the words of every operand.
+  assert(inst->num_words == std::accumulate(
+      parsed_operands_.begin(), parsed_operands_.end(), 1,
+      [](size_t num_words, const spv_parsed_operand_t& operand) {
+        return num_words += operand.num_words;
+      }) &&
+      "num_words in instruction doesn't correspond to the sum of num_words "
+      "in the operands");
+
+  RecordNumberType(*inst);
+
+  if (!ReadToByteBreakIfAgreed())
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+        << "Failed to read to byte break";
+
+  return SPV_SUCCESS;
+}
+
+// Copies the numeric kind and bit width recorded for |type_id| into
+// |parsed_operand| and derives the operand's word count from the bit width.
+// Returns SPV_ERROR_INVALID_BINARY with a diagnostic when |type_id| has no
+// record, or when its record is not a scalar numeric type.
+spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
+    spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
+  assert(type_id != 0);
+
+  const auto found = type_id_to_number_type_info_.find(type_id);
+  if (found == type_id_to_number_type_info_.end()) {
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+        << "Type Id " << type_id << " is not a type";
+  }
+
+  const NumberType& number_type = found->second;
+  if (number_type.type == SPV_NUMBER_NONE) {
+    // Known type, but not a scalar number.
+    return vstate_.diag(SPV_ERROR_INVALID_BINARY)
+        << "Type Id " << type_id << " is not a scalar numeric type";
+  }
+
+  // Number of 32-bit words needed for the bit width, rounded up.
+  const auto words_needed = (number_type.bit_width + 31) / 32;
+
+  parsed_operand->number_kind = number_type.type;
+  parsed_operand->number_bit_width = number_type.bit_width;
+  parsed_operand->num_words = static_cast<uint16_t>(words_needed);
+  return SPV_SUCCESS;
+}
+
+// Remembers the numeric properties of a type-generating instruction, keyed by
+// the instruction's result id. OpTypeInt and OpTypeFloat record their kind and
+// bit width; every other type-generating opcode records {SPV_NUMBER_NONE, 0}.
+// Instructions that do not generate a type are ignored.
+void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) {
+  const SpvOp opcode = static_cast<SpvOp>(inst.opcode);
+  if (!spvOpcodeGeneratesType(opcode)) return;
+
+  NumberType number_type = {SPV_NUMBER_NONE, 0};
+  if (opcode == SpvOpTypeInt) {
+    // Operand 1 holds the width, operand 2 the signedness flag.
+    number_type.bit_width = inst.words[inst.operands[1].offset];
+    const bool is_signed = inst.words[inst.operands[2].offset] != 0;
+    number_type.type =
+        is_signed ? SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT;
+  } else if (opcode == SpvOpTypeFloat) {
+    number_type.bit_width = inst.words[inst.operands[1].offset];
+    number_type.type = SPV_NUMBER_FLOATING;
+  }
+
+  // The *result* Id of a type generating instruction is the type Id.
+  type_id_to_number_type_info_[inst.result_id] = number_type;
+}
+
+// Header callback passed to spvBinaryParse. |user_data| is the MarkvEncoder
+// that receives the parsed SPIR-V header fields.
+spv_result_t EncodeHeader(
+    void* user_data, spv_endianness_t endian, uint32_t magic,
+    uint32_t version, uint32_t generator, uint32_t id_bound,
+    uint32_t schema) {
+  return reinterpret_cast<MarkvEncoder*>(user_data)->EncodeHeader(
+      endian, magic, version, generator, id_bound, schema);
+}
+
+// Instruction callback passed to spvBinaryParse. |user_data| is the
+// MarkvEncoder that encodes the parsed instruction |inst|.
+spv_result_t EncodeInstruction(
+    void* user_data, const spv_parsed_instruction_t* inst) {
+  return reinterpret_cast<MarkvEncoder*>(user_data)->EncodeInstruction(*inst);
+}
+
+} // namespace
+
+// Encodes the SPIR-V module in |spirv_words| as MARK-V.
+//
+// On success stores the encoded binary in |*markv_binary| and, when |comments|
+// is non-null, the encoder's comment text in |*comments|, then returns
+// SPV_SUCCESS. Failures return SPV_ERROR_INVALID_BINARY with a message sent to
+// the context's consumer (or to |*diagnostic| when |diagnostic| is non-null).
+spv_result_t spvSpirvToMarkv(spv_const_context context,
+                             const uint32_t* spirv_words,
+                             const size_t spirv_num_words,
+                             spv_const_markv_encoder_options options,
+                             spv_markv_binary* markv_binary,
+                             spv_text* comments, spv_diagnostic* diagnostic) {
+  // Work on a copy of the context so a diagnostic consumer can be installed
+  // on the copy.
+  spv_context_t local_context = *context;
+  if (diagnostic != nullptr) {
+    *diagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&local_context, diagnostic);
+  }
+
+  spv_position_t position = {};
+  spv_const_binary_t input = {spirv_words, spirv_num_words};
+
+  // Validate the magic number and header before doing any encoding work.
+  spv_endianness_t endianness;
+  if (spvBinaryEndianness(&input, &endianness) != SPV_SUCCESS) {
+    return libspirv::DiagnosticStream(position, local_context.consumer,
+                                      SPV_ERROR_INVALID_BINARY)
+        << "Invalid SPIR-V magic number.";
+  }
+
+  spv_header_t spirv_header;
+  if (spvBinaryHeaderGet(&input, endianness, &spirv_header) != SPV_SUCCESS) {
+    return libspirv::DiagnosticStream(position, local_context.consumer,
+                                      SPV_ERROR_INVALID_BINARY)
+        << "Invalid SPIR-V header.";
+  }
+
+  MarkvEncoder encoder(&local_context, options);
+
+  const bool want_comments = comments != nullptr;
+  if (want_comments) {
+    encoder.CreateCommentsLogger();
+
+    // The encoder is given the module's disassembly when comments are wanted.
+    spv_text disassembly = nullptr;
+    const spv_result_t text_result = spvBinaryToText(
+        &local_context, spirv_words, spirv_num_words,
+        SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &disassembly, nullptr);
+    if (text_result != SPV_SUCCESS) {
+      return libspirv::DiagnosticStream(position, local_context.consumer,
+                                        SPV_ERROR_INVALID_BINARY)
+          << "Failed to disassemble SPIR-V binary.";
+    }
+    assert(disassembly);
+    encoder.SetDisassembly(std::string(disassembly->str, disassembly->length));
+    spvTextDestroy(disassembly);
+  }
+
+  // Run the binary parser with the encoder as user data; the two callbacks
+  // forward the header and every instruction to it.
+  const spv_result_t parse_result =
+      spvBinaryParse(&local_context, &encoder, spirv_words, spirv_num_words,
+                     EncodeHeader, EncodeInstruction, diagnostic);
+  if (parse_result != SPV_SUCCESS) {
+    return libspirv::DiagnosticStream(position, local_context.consumer,
+                                      SPV_ERROR_INVALID_BINARY)
+        << "Unable to encode to MARK-V.";
+  }
+
+  if (want_comments) {
+    *comments = CreateSpvText(encoder.GetComments());
+  }
+
+  *markv_binary = encoder.GetMarkvBinary();
+  return SPV_SUCCESS;
+}
+
+// Decodes the MARK-V data in |markv_data| back into a SPIR-V binary.
+//
+// On success allocates a new spv_binary_t holding a copy of the decoded words,
+// stores it in |*spirv_binary| and returns SPV_SUCCESS. On failure returns
+// SPV_ERROR_INVALID_BINARY with a message sent to the context's consumer (or
+// to |*diagnostic| when |diagnostic| is non-null). The comments argument is
+// unused.
+spv_result_t spvMarkvToSpirv(spv_const_context context,
+                             const uint8_t* markv_data,
+                             size_t markv_size_bytes,
+                             spv_const_markv_decoder_options options,
+                             spv_binary* spirv_binary,
+                             spv_text* /* comments */, spv_diagnostic* diagnostic) {
+  spv_position_t position = {};
+
+  // Work on a copy of the context so a diagnostic consumer can be installed
+  // on the copy.
+  spv_context_t local_context = *context;
+  if (diagnostic != nullptr) {
+    *diagnostic = nullptr;
+    libspirv::UseDiagnosticAsMessageConsumer(&local_context, diagnostic);
+  }
+
+  MarkvDecoder decoder(&local_context, markv_data, markv_size_bytes, options);
+
+  std::vector<uint32_t> decoded;
+  const spv_result_t decode_status = decoder.DecodeModule(&decoded);
+  if (decode_status != SPV_SUCCESS) {
+    return libspirv::DiagnosticStream(position, local_context.consumer,
+                                      SPV_ERROR_INVALID_BINARY)
+        << "Unable to decode MARK-V.";
+  }
+
+  assert(!decoded.empty());
+
+  // Copy the decoded words into a freshly allocated output binary.
+  *spirv_binary = new spv_binary_t();
+  spv_binary_t* const output = *spirv_binary;
+  output->code = new uint32_t[decoded.size()];
+  output->wordCount = decoded.size();
+  std::memcpy(output->code, decoded.data(), 4 * decoded.size());
+
+  return SPV_SUCCESS;
+}
+
+// Frees a MARK-V binary together with its data array. A null |binary| is
+// accepted and ignored.
+void spvMarkvBinaryDestroy(spv_markv_binary binary) {
+  if (binary != nullptr) {
+    delete[] binary->data;
+    delete binary;
+  }
+}
+
+spv_markv_encoder_options spvMarkvEncoderOptionsCreate() {  // Creates a new encoder options object.
+ return new spv_markv_encoder_options_t;  // Heap-allocated; spvMarkvEncoderOptionsDestroy() deletes it.
+}
+
+void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) {  // Frees an encoder options object.
+ delete options;  // Deleting a null pointer is a no-op, so null is accepted.
+}
+
+spv_markv_decoder_options spvMarkvDecoderOptionsCreate() {  // Creates a new decoder options object.
+ return new spv_markv_decoder_options_t;  // Heap-allocated; spvMarkvDecoderOptionsDestroy() deletes it.
+}
+
+void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) {  // Frees a decoder options object.
+ delete options;  // Deleting a null pointer is a no-op, so null is accepted.
+}