From 34d215d3d8c4a9071715d9eccebb0f285f1c8434 Mon Sep 17 00:00:00 2001 From: ReinUsesLisp Date: Sat, 25 Aug 2018 20:16:37 -0300 Subject: [PATCH] Implement stuff --- CMakeLists.txt | 9 --- include/sirit/sirit.h | 94 +++++++++++++++++++++++++++ src/CMakeLists.txt | 13 +++- src/common_types.h | 29 +++++++++ src/impl.h | 15 +++++ src/operand.cpp | 86 ++++++++++++++++++++++++ src/operand.h | 65 +++++++++++++++++++ src/ref.cpp | 99 ++++++++++++++++++++++++++++ src/ref.h | 52 +++++++++++++++ src/sirit.cpp | 148 +++++++++++++++++++++++++++++++++++++++++- src/stream.cpp | 49 ++++++++++++++ src/stream.h | 34 ++++++++++ src/type.h | 33 ++++++++++ tests/main.cpp | 34 +++++++++- 14 files changed, 747 insertions(+), 13 deletions(-) create mode 100644 src/common_types.h create mode 100644 src/impl.h create mode 100644 src/operand.cpp create mode 100644 src/operand.h create mode 100644 src/ref.cpp create mode 100644 src/ref.h create mode 100644 src/stream.cpp create mode 100644 src/stream.h create mode 100644 src/type.h diff --git a/CMakeLists.txt b/CMakeLists.txt index cbc91b7..ee6144d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -82,15 +82,6 @@ else() endif() endif() -# Include Boost -if (NOT TARGET boost) - if (NOT Boost_INCLUDE_DIRS) - find_package(Boost 1.57.0 REQUIRED) - endif() - add_library(boost INTERFACE) - target_include_directories(boost SYSTEM INTERFACE ${Boost_INCLUDE_DIRS}) -endif() - # Enable unit-testing. enable_testing(true) diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index 83d1359..cc549b6 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -6,6 +6,100 @@ #pragma once +#include +#include +#include +#include +#include + namespace Sirit { +static const std::uint32_t GeneratorMagicNumber = 0; + +class Ref; + +class Module { +public: + explicit Module(); + ~Module(); + + /** + * Assembles current module into a SPIR-V stream. + * It can be called multiple times but it's recommended to copy code externally. + * @return A stream of bytes representing a SPIR-V module. + */ + std::vector Assembly() const; + + /** + * Optimizes module's IR. + * All returned references become invalid. + * @param level Level of optimization. + */ + void Optimize(int level); + + /// Adds a module capability. + void AddCapability(spv::Capability capability); + + /// Sets module memory model. + void SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryModel memory_model); + + /// Adds an entry point. + void AddEntryPoint(spv::ExecutionModel execution_model, const Ref* entry_point, + const std::string& name, const std::vector& interfaces = {}); + + /// Returns type void. + const Ref* TypeVoid(); + + /// Returns a function type. + const Ref* TypeFunction(const Ref* return_type, const std::vector& arguments = {}); + + /// Adds a reference to code block + void Add(const Ref* ref); + + /// Emits a function. + const Ref* EmitFunction(const Ref* result_type, spv::FunctionControlMask function_control, + const Ref* function_type); + + /// Emits a label. It starts a block. + const Ref* EmitLabel(); + + /// Emits a return. It ends a block. + const Ref* EmitReturn(); + + /// Emits a function end. + const Ref* EmitFunctionEnd(); + +private: + const Ref* AddCode(Ref* ref); + + const Ref* AddCode(spv::Op opcode, std::uint32_t id = UINT32_MAX); + + const Ref* AddDeclaration(Ref* ref); + + std::uint32_t bound{1}; + + std::set capabilities; + + std::set extensions; + + std::set> ext_inst_import; + + spv::AddressingModel addressing_model{spv::AddressingModel::Logical}; + spv::MemoryModel memory_model{spv::MemoryModel::GLSL450}; + + std::vector> entry_points; + + std::vector> execution_mode; + + std::vector> debug; + + std::vector> annotations; + + std::vector> declarations; + + std::vector code; + + std::vector> code_store; +}; + } // namespace Sirit diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 44696b1..5e22f29 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,6 +1,15 @@ add_library(sirit + ../include/sirit/sirit.h sirit.cpp + ref.cpp + ref.h + stream.cpp + stream.h + operand.cpp + operand.h + common_types.h ) target_include_directories(sirit - PUBLIC ../include ${SPIRV-Headers_SOURCE_DIR}/include} - PRIVATE .) + PUBLIC ../include + PRIVATE . ${SPIRV-Headers_SOURCE_DIR}/include + INTERFACE ${SPIRV-Headers_SOURCE_DIR}/include) diff --git a/src/common_types.h b/src/common_types.h new file mode 100644 index 0000000..9238409 --- /dev/null +++ b/src/common_types.h @@ -0,0 +1,29 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +#include +#include + +using u8 = std::uint8_t; +using u16 = std::uint16_t; +using u32 = std::uint32_t; +using u64 = std::uint64_t; +using uptr = std::uintptr_t; + +using s8 = std::int8_t; +using s16 = std::int16_t; +using s32 = std::int32_t; +using s64 = std::int64_t; +using sptr = std::intptr_t; + +using size_t = std::size_t; + +using f32 = float; +using f64 = double; +static_assert(sizeof(f32) == sizeof(u32), "f32 must be 32 bits wide"); +static_assert(sizeof(f64) == sizeof(u64), "f64 must be 64 bits wide"); diff --git a/src/impl.h b/src/impl.h new file mode 100644 index 0000000..cb53e1e --- /dev/null +++ b/src/impl.h @@ -0,0 +1,15 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +struct Impl final { +public: + explicit Impl(); + ~Impl(); + +private: +}; diff --git a/src/operand.cpp b/src/operand.cpp new file mode 100644 index 0000000..8103640 --- /dev/null +++ b/src/operand.cpp @@ -0,0 +1,86 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#include +#include "operand.h" + +namespace Sirit { + +Operand::Operand() {} + +Operand::~Operand() = default; + +void Operand::Fetch(Stream& stream) const { + assert(!"Fetching unimplemented operand"); +} + +u16 Operand::GetWordCount() const { + assert(!"Fetching unimplemented operand"); + return 0; +} + +bool Operand::operator==(const Operand& other) const { + return false; +} + +bool Operand::operator!=(const Operand& other) const { + return !(*this == other); +} + +OperandType Operand::GetType() const { + return operand_type; +} + +LiteralInteger::LiteralInteger(u32 integer_) + : integer(integer_) { + operand_type = OperandType::Integer; +} + +LiteralInteger::~LiteralInteger() = default; + +void LiteralInteger::Fetch(Stream& stream) const { + stream.Write(integer); +} + +u16 LiteralInteger::GetWordCount() const { + return 1; +} + +bool LiteralInteger::operator==(const Operand& other) const { + if (operand_type == other.GetType()) { + return dynamic_cast(other).integer == integer; + } + return false; +} + +LiteralString::LiteralString(const std::string& string_) + : string(string_) { + operand_type = OperandType::String; +} + +LiteralString::~LiteralString() = default; + +void LiteralString::Fetch(Stream& stream) const { + for (std::size_t i{}; i < string.size(); i++) { + stream.Write(static_cast(string[i])); + } + for (std::size_t i{}; i < 4 - (string.size() % 4); i++) { + stream.Write(static_cast(0)); + } +} + +u16 LiteralString::GetWordCount() const { + return static_cast(string.size() / 4 + 1); +} + +bool LiteralString::operator==(const Operand& other) const { + if (operand_type == other.GetType()) { + return dynamic_cast(other).string == string; + } + return false; +} + +} // namespace Sirit diff --git a/src/operand.h b/src/operand.h new file mode 100644 index 0000000..9decf5d --- /dev/null +++ b/src/operand.h @@ -0,0 +1,65 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +#include "stream.h" + +namespace Sirit { + +enum class OperandType { + Invalid, + Ref, + Integer, + String +}; + +class Operand { +public: + Operand(); + virtual ~Operand(); + + virtual void Fetch(Stream& stream) const; + virtual u16 GetWordCount() const; + + virtual bool operator==(const Operand& other) const; + bool operator!=(const Operand& other) const; + + OperandType GetType() const; + +protected: + OperandType operand_type{}; +}; + +class LiteralInteger : public Operand { +public: + LiteralInteger(u32 integer); + ~LiteralInteger(); + + virtual void Fetch(Stream& stream) const; + virtual u16 GetWordCount() const; + + virtual bool operator==(const Operand& other) const; + +private: + u32 integer; +}; + +class LiteralString : public Operand { +public: + LiteralString(const std::string& string); + ~LiteralString(); + + virtual void Fetch(Stream& stream) const; + virtual u16 GetWordCount() const; + + virtual bool operator==(const Operand& other) const; + +private: + std::string string; +}; + +} // namespace Sirit diff --git a/src/ref.cpp b/src/ref.cpp new file mode 100644 index 0000000..22a9403 --- /dev/null +++ b/src/ref.cpp @@ -0,0 +1,99 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#include +#include "common_types.h" +#include "operand.h" +#include "ref.h" + +namespace Sirit { + +Ref::Ref(spv::Op opcode_, u32 id_, const Ref* result_type_) + : opcode(opcode_), id(id_), result_type(result_type_) { + operand_type = OperandType::Ref; +} + +Ref::~Ref() = default; + +void Ref::Fetch(Stream& stream) const { + assert(id != UINT32_MAX); + stream.Write(id); +} + +u16 Ref::GetWordCount() const { + return 1; +} + +bool Ref::operator==(const Operand& other) const { + if (operand_type != other.GetType()) { + return false; + } + const Ref& ref = dynamic_cast(other); + if (ref.opcode == opcode && result_type == ref.result_type && + operands.size() == ref.operands.size()) { + for (std::size_t i{}; i < operands.size(); i++) { + if (*operands[i] != *ref.operands[i]) { + return false; + } + } + return true; + } + return false; +} + +void Ref::Write(Stream& stream) const { + stream.Write(static_cast(opcode)); + stream.Write(WordCount()); + + if (result_type) { + result_type->Fetch(stream); + } + if (id != UINT32_MAX) { + stream.Write(id); + } + for (const Operand* operand : operands) { + operand->Fetch(stream); + } +} + +void Ref::Add(Operand* operand) { + Add(static_cast(operand)); + operand_store.push_back(std::unique_ptr(operand)); +} + +void Ref::Add(const Operand* operand) { + operands.push_back(operand); +} + +void Ref::Add(u32 integer) { + Add(new LiteralInteger(integer)); +} + +void Ref::Add(const std::string& string) { + Add(new LiteralString(string)); +} + +void Ref::Add(const std::vector& ids) { + for (const Ref* ref : ids) { + Add(ref); + } +} + +u16 Ref::WordCount() const { + u16 count{1}; + if (result_type) { + count++; + } + if (id != UINT32_MAX) { + count++; + } + for (const Operand* operand : operands) { + count += operand->GetWordCount(); + } + return count; +} + +} // namespace Sirit diff --git a/src/ref.h b/src/ref.h new file mode 100644 index 0000000..5b6f6b6 --- /dev/null +++ b/src/ref.h @@ -0,0 +1,52 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +#include "sirit/sirit.h" +#include "common_types.h" +#include "operand.h" +#include "stream.h" + +namespace Sirit { + +class Ref : public Operand { +public: + explicit Ref(spv::Op opcode, u32 id = UINT32_MAX, const Ref* result_type = nullptr); + ~Ref(); + + virtual void Fetch(Stream& stream) const; + virtual u16 GetWordCount() const; + + virtual bool operator==(const Operand& other) const; + + void Write(Stream& stream) const; + + void Add(Operand* operand); + + void Add(const Operand* operand); + + void Add(u32 integer); + + void Add(const std::string& string); + + void Add(const std::vector& ids); + +private: + u16 WordCount() const; + + spv::Op opcode; + + const Ref* result_type; + + u32 id; + + std::vector operands; + + std::vector> operand_store; +}; + +} // namespace Sirit diff --git a/src/sirit.cpp b/src/sirit.cpp index 52abc93..faacfcf 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -4,4 +4,150 @@ * General Public License version 2 or any later version. */ -#include +#include +#include +#include "sirit/sirit.h" +#include "common_types.h" +#include "ref.h" +#include "stream.h" + +namespace Sirit { + +template +static void WriteEnum(Stream& stream, spv::Op op, T value) { + Ref ref{op}; + ref.Add(static_cast(value)); + ref.Write(stream); +} + +Module::Module() {} + +Module::~Module() = default; + +std::vector Module::Assembly() const { + std::vector bytes; + Stream stream{bytes}; + + stream.Write(spv::MagicNumber); + stream.Write(spv::Version); + stream.Write(GeneratorMagicNumber); + stream.Write(bound); + stream.Write(static_cast(0)); + + for (auto capability : capabilities) { + WriteEnum(stream, spv::Op::OpCapability, capability); + } + + // TODO write extensions + + // TODO write ext inst imports + + Ref memory_model_ref{spv::Op::OpMemoryModel}; + memory_model_ref.Add(static_cast(addressing_model)); + memory_model_ref.Add(static_cast(memory_model)); + memory_model_ref.Write(stream); + + for (const auto& entry_point : entry_points) { + entry_point->Write(stream); + } + + // TODO write execution mode + + // TODO write debug symbols + + // TODO write annotations + + for (const auto& decl : declarations) { + decl->Write(stream); + } + for (const auto& line : code) { + line->Write(stream); + } + + return bytes; +} + +void Module::Optimize(int level) { +} + +void Module::AddCapability(spv::Capability capability) { + capabilities.insert(capability); +} + +void Module::SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryModel memory_model) { + this->addressing_model = addressing_model; + this->memory_model = memory_model; +} + +void Module::AddEntryPoint(spv::ExecutionModel execution_model, const Ref* entry_point, + const std::string& name, const std::vector& interfaces) { + Ref* op{new Ref(spv::Op::OpEntryPoint)}; + op->Add(static_cast(execution_model)); + op->Add(entry_point); + op->Add(name); + op->Add(interfaces); + entry_points.push_back(std::unique_ptr(op)); +} + +const Ref* Module::TypeVoid() { + return AddDeclaration(new Ref(spv::Op::OpTypeVoid, bound)); +} + +const Ref* Module::TypeFunction(const Ref* return_type, const std::vector& arguments) { + Ref* type_func{new Ref(spv::Op::OpTypeFunction, bound)}; + type_func->Add(return_type); + for (const Ref* arg : arguments) { + type_func->Add(arg); + } + return AddDeclaration(type_func); +} + +void Module::Add(const Ref* ref) { + assert(ref); + code.push_back(ref); +} + +const Ref* Module::EmitFunction(const Ref* result_type, spv::FunctionControlMask function_control, + const Ref* function_type) { + Ref* op{new Ref{spv::Op::OpFunction, bound++, result_type}}; + op->Add(static_cast(function_control)); + op->Add(function_type); + return AddCode(op); +} + +const Ref* Module::EmitLabel() { + return AddCode(spv::Op::OpLabel, bound++); +} + +const Ref* Module::EmitReturn() { + return AddCode(spv::Op::OpReturn); +} + +const Ref* Module::EmitFunctionEnd() { + return AddCode(spv::Op::OpFunctionEnd); +} + +const Ref* Module::AddCode(Ref* ref) { + code_store.push_back(std::unique_ptr(ref)); + return ref; +} + +const Ref* Module::AddCode(spv::Op opcode, u32 id) { + return AddCode(new Ref{opcode, id}); +} + +const Ref* Module::AddDeclaration(Ref* ref) { + const auto& found{std::find_if(declarations.begin(), declarations.end(), [=](const auto& other) { + return *other == *ref; + })}; + if (found != declarations.end()) { + delete ref; + return found->get(); + } else { + declarations.push_back(std::unique_ptr(ref)); + bound++; + return ref; + } +} + +} // namespace Sirit diff --git a/src/stream.cpp b/src/stream.cpp new file mode 100644 index 0000000..e4dcd19 --- /dev/null +++ b/src/stream.cpp @@ -0,0 +1,49 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#include "stream.h" + +namespace Sirit { + +Stream::Stream(std::vector& bytes_) + : bytes(bytes_) {} + +Stream::~Stream() = default; + +void Stream::Write(std::string string) { + std::size_t size{string.size()}; + u8* data{reinterpret_cast(string.data())}; + for (std::size_t i{}; i < size; i++) { + Write(data[i]); + } + for (std::size_t i{}; i < 4 - size % 4; i++) { + Write(static_cast(0)); + } +} + +void Stream::Write(u64 value) { + u32* mem{reinterpret_cast(&value)}; + Write(mem[0]); + Write(mem[1]); +} + +void Stream::Write(u32 value) { + u16* mem{reinterpret_cast(&value)}; + Write(mem[0]); + Write(mem[1]); +} + +void Stream::Write(u16 value) { + u8* mem{reinterpret_cast(&value)}; + Write(mem[0]); + Write(mem[1]); +} + +void Stream::Write(u8 value) { + bytes.push_back(value); +} + +} // namespace Sirit diff --git a/src/stream.h b/src/stream.h new file mode 100644 index 0000000..44135c5 --- /dev/null +++ b/src/stream.h @@ -0,0 +1,34 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +#include +#include +#include "common_types.h" + +namespace Sirit { + +class Stream { +public: + explicit Stream(std::vector& bytes); + ~Stream(); + + void Write(std::string string); + + void Write(u64 value); + + void Write(u32 value); + + void Write(u16 value); + + void Write(u8 value); + +private: + std::vector& bytes; +}; + +} // namespace Sirit diff --git a/src/type.h b/src/type.h new file mode 100644 index 0000000..b7faf86 --- /dev/null +++ b/src/type.h @@ -0,0 +1,33 @@ +/* This file is part of the sirit project. + * Copyright (c) 2018 ReinUsesLisp + * This software may be used and distributed according to the terms of the GNU + * General Public License version 2 or any later version. + */ + +#pragma once + +#include +#include "sirit/sirit.h" +#include "ref.h" + +namespace Sirit { + +class TypeConstant : public Ref { +public: + Type(spv::Op opcode, u32 id, std::); + ~Type(); + + bool operator==(const Type& other) const; + +private: + /// Arguments can be type references or constants + std::vector args; + + friend Type; +}; + +using Type = TypeConstant; + +using Constant = TypeConstant; + +} // namespace Sirit diff --git a/tests/main.cpp b/tests/main.cpp index c5c520f..ab8f68d 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -5,7 +5,39 @@ */ #include +#include +#include + +class MyModule : public Sirit::Module { +public: + MyModule() {} + ~MyModule() = default; + + void Generate() { + AddCapability(spv::Capability::Shader); + SetMemoryModel(spv::AddressingModel::Logical, spv::MemoryModel::GLSL450); + + auto main_type{TypeFunction(TypeVoid())}; + auto main_func{EmitFunction(TypeVoid(), spv::FunctionControlMask::MaskNone, main_type)}; + Add(main_func); + Add(EmitLabel()); + Add(EmitReturn()); + Add(EmitFunctionEnd()); + + AddEntryPoint(spv::ExecutionModel::Vertex, main_func, "main"); + } +}; + +int main(int argc, char** argv) { + MyModule module; + module.Generate(); + + module.Optimize(2); + std::vector code{module.Assembly()}; + + FILE* file = fopen("sirit.spv", "wb"); + fwrite(code.data(), 1, code.size(), file); + fclose(file); -int main(int argc, char **argv) { return 0; }