Use variant instead of creating an object for literals

This commit is contained in:
ReinUsesLisp 2018-10-28 13:44:12 -03:00
parent 8f8115d397
commit 00fc8daf56
12 changed files with 146 additions and 144 deletions

View file

@ -11,6 +11,7 @@
#include <optional> #include <optional>
#include <set> #include <set>
#include <spirv/unified1/spirv.hpp11> #include <spirv/unified1/spirv.hpp11>
#include <variant>
#include <vector> #include <vector>
namespace Sirit { namespace Sirit {
@ -20,7 +21,9 @@ constexpr std::uint32_t GENERATOR_MAGIC_NUMBER = 0;
class Op; class Op;
class Operand; class Operand;
typedef const Op* Ref; using Literal = std::variant<std::uint32_t, std::uint64_t, std::int32_t,
std::int64_t, float, double>;
using Ref = const Op*;
class Module { class Module {
public: public:
@ -135,7 +138,7 @@ class Module {
Ref ConstantFalse(Ref result_type); Ref ConstantFalse(Ref result_type);
/// Returns a numeric scalar constant. /// Returns a numeric scalar constant.
Ref Constant(Ref result_type, Operand* literal); Ref Constant(Ref result_type, const Literal& literal);
/// Returns a numeric scalar constant. /// Returns a numeric scalar constant.
Ref ConstantComposite(Ref result_type, Ref ConstantComposite(Ref result_type,
@ -201,18 +204,11 @@ class Module {
/// Add a decoration to target. /// Add a decoration to target.
Ref Decorate(Ref target, spv::Decoration decoration, Ref Decorate(Ref target, spv::Decoration decoration,
const std::vector<Operand*>& literals = {}); const std::vector<Literal>& literals = {});
Ref MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration, Ref MemberDecorate(Ref structure_type, Literal member,
const std::vector<Operand*>& literals = {}); spv::Decoration decoration,
const std::vector<Literal>& literals = {});
// Literals
static Operand* Literal(std::uint32_t value);
static Operand* Literal(std::uint64_t value);
static Operand* Literal(std::int32_t value);
static Operand* Literal(std::int64_t value);
static Operand* Literal(float value);
static Operand* Literal(double value);
private: private:
Ref AddCode(Op* op); Ref AddCode(Op* op);

View file

@ -7,7 +7,6 @@ add_library(sirit
stream.h stream.h
operand.cpp operand.cpp
operand.h operand.h
literal.cpp
literal-number.cpp literal-number.cpp
literal-number.h literal-number.h
literal-string.cpp literal-string.cpp

View file

@ -10,21 +10,22 @@
namespace Sirit { namespace Sirit {
Ref Module::Decorate(Ref target, spv::Decoration decoration, Ref Module::Decorate(Ref target, spv::Decoration decoration,
const std::vector<Operand*>& literals) { const std::vector<Literal>& literals) {
auto op{new Op(spv::Op::OpDecorate)}; auto op{new Op(spv::Op::OpDecorate)};
op->Add(target); op->Add(target);
AddEnum(op, decoration); AddEnum(op, decoration);
op->Sink(literals); op->Add(literals);
return AddAnnotation(op); return AddAnnotation(op);
} }
Ref Module::MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration, Ref Module::MemberDecorate(Ref structure_type, Literal member,
const std::vector<Operand*>& literals) { spv::Decoration decoration,
const std::vector<Literal>& literals) {
auto op{new Op(spv::Op::OpMemberDecorate)}; auto op{new Op(spv::Op::OpMemberDecorate)};
op->Add(structure_type); op->Add(structure_type);
op->Sink(member); op->Add(member);
AddEnum(op, decoration); AddEnum(op, decoration);
op->Sink(literals); op->Add(literals);
return AddAnnotation(op); return AddAnnotation(op);
} }

View file

@ -4,9 +4,9 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include <cassert>
#include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
#include "sirit/sirit.h"
#include <cassert>
namespace Sirit { namespace Sirit {
@ -18,20 +18,23 @@ Ref Module::ConstantFalse(Ref result_type) {
return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type)); return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type));
} }
Ref Module::Constant(Ref result_type, Operand* literal) { Ref Module::Constant(Ref result_type, const Literal& literal) {
auto op{new Op(spv::Op::OpConstant, bound, result_type)}; auto op{new Op(spv::Op::OpConstant, bound, result_type)};
op->Add(literal); op->Add(literal);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::ConstantComposite(Ref result_type, const std::vector<Ref>& constituents) { Ref Module::ConstantComposite(Ref result_type,
const std::vector<Ref>& constituents) {
auto op{new Op(spv::Op::OpConstantComposite, bound, result_type)}; auto op{new Op(spv::Op::OpConstantComposite, bound, result_type)};
op->Add(constituents); op->Add(constituents);
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressing_mode, Ref Module::ConstantSampler(Ref result_type,
bool normalized, spv::SamplerFilterMode filter_mode) { spv::SamplerAddressingMode addressing_mode,
bool normalized,
spv::SamplerFilterMode filter_mode) {
AddCapability(spv::Capability::LiteralSampler); AddCapability(spv::Capability::LiteralSampler);
AddCapability(spv::Capability::Kernel); AddCapability(spv::Capability::Kernel);
auto op{new Op(spv::Op::OpConstantSampler, bound, result_type)}; auto op{new Op(spv::Op::OpConstantSampler, bound, result_type)};

View file

@ -4,8 +4,8 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
#include "sirit/sirit.h"
namespace Sirit { namespace Sirit {

View file

@ -4,12 +4,13 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
#include "sirit/sirit.h"
namespace Sirit { namespace Sirit {
Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control, Ref Module::LoopMerge(Ref merge_block, Ref continue_target,
spv::LoopControlMask loop_control,
const std::vector<Ref>& literals) { const std::vector<Ref>& literals) {
auto op{new Op(spv::Op::OpLoopMerge)}; auto op{new Op(spv::Op::OpLoopMerge)};
op->Add(merge_block); op->Add(merge_block);
@ -19,16 +20,15 @@ Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask
return AddCode(op); return AddCode(op);
} }
Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) { Ref Module::SelectionMerge(Ref merge_block,
spv::SelectionControlMask selection_control) {
auto op{new Op(spv::Op::OpSelectionMerge)}; auto op{new Op(spv::Op::OpSelectionMerge)};
op->Add(merge_block); op->Add(merge_block);
AddEnum(op, selection_control); AddEnum(op, selection_control);
return AddCode(op); return AddCode(op);
} }
Ref Module::Label() { Ref Module::Label() { return AddCode(spv::Op::OpLabel, bound++); }
return AddCode(spv::Op::OpLabel, bound++);
}
Ref Module::Branch(Ref target_label) { Ref Module::Branch(Ref target_label) {
auto op{new Op(spv::Op::OpBranch)}; auto op{new Op(spv::Op::OpBranch)};
@ -37,20 +37,19 @@ Ref Module::Branch(Ref target_label) {
} }
Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label, Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label,
std::uint32_t true_weight, std::uint32_t false_weight) { std::uint32_t true_weight,
std::uint32_t false_weight) {
auto op{new Op(spv::Op::OpBranchConditional)}; auto op{new Op(spv::Op::OpBranchConditional)};
op->Add(condition); op->Add(condition);
op->Add(true_label); op->Add(true_label);
op->Add(false_label); op->Add(false_label);
if (true_weight != 0 || false_weight != 0) { if (true_weight != 0 || false_weight != 0) {
op->Add(Literal(true_weight)); op->Add(true_weight);
op->Add(Literal(false_weight)); op->Add(false_weight);
} }
return AddCode(op); return AddCode(op);
} }
Ref Module::Return() { Ref Module::Return() { return AddCode(spv::Op::OpReturn); }
return AddCode(spv::Op::OpReturn);
}
} // namespace Sirit } // namespace Sirit

View file

@ -4,20 +4,19 @@
* Lesser General Public License version 2.1 or any later version. * Lesser General Public License version 2.1 or any later version.
*/ */
#include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
#include "sirit/sirit.h"
namespace Sirit { namespace Sirit {
Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) { Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control,
Ref function_type) {
auto op{new Op{spv::Op::OpFunction, bound++, result_type}}; auto op{new Op{spv::Op::OpFunction, bound++, result_type}};
op->Add(static_cast<u32>(function_control)); op->Add(static_cast<u32>(function_control));
op->Add(function_type); op->Add(function_type);
return AddCode(op); return AddCode(op);
} }
Ref Module::FunctionEnd() { Ref Module::FunctionEnd() { return AddCode(spv::Op::OpFunctionEnd); }
return AddCode(spv::Op::OpFunctionEnd);
}
} // namespace Sirit } // namespace Sirit

View file

@ -7,8 +7,8 @@
#include <cassert> #include <cassert>
#include <optional> #include <optional>
#include "sirit/sirit.h"
#include "insts.h" #include "insts.h"
#include "sirit/sirit.h"
namespace Sirit { namespace Sirit {
@ -62,68 +62,68 @@ Ref Module::TypeMatrix(Ref column_type, int column_count) {
return AddDeclaration(op); return AddDeclaration(op);
} }
Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed,
int sampled, spv::ImageFormat image_format, bool ms, int sampled, spv::ImageFormat image_format,
std::optional<spv::AccessQualifier> access_qualifier) { std::optional<spv::AccessQualifier> access_qualifier) {
switch (dim) { switch (dim) {
case spv::Dim::Dim1D: case spv::Dim::Dim1D:
AddCapability(spv::Capability::Sampled1D); AddCapability(spv::Capability::Sampled1D);
break; break;
case spv::Dim::Cube: case spv::Dim::Cube:
AddCapability(spv::Capability::Shader); AddCapability(spv::Capability::Shader);
break; break;
case spv::Dim::Rect: case spv::Dim::Rect:
AddCapability(spv::Capability::SampledRect); AddCapability(spv::Capability::SampledRect);
break; break;
case spv::Dim::Buffer: case spv::Dim::Buffer:
AddCapability(spv::Capability::SampledBuffer); AddCapability(spv::Capability::SampledBuffer);
break; break;
case spv::Dim::SubpassData: case spv::Dim::SubpassData:
AddCapability(spv::Capability::InputAttachment); AddCapability(spv::Capability::InputAttachment);
break; break;
} }
switch (image_format) { switch (image_format) {
case spv::ImageFormat::Rgba32f: case spv::ImageFormat::Rgba32f:
case spv::ImageFormat::Rgba16f: case spv::ImageFormat::Rgba16f:
case spv::ImageFormat::R32f: case spv::ImageFormat::R32f:
case spv::ImageFormat::Rgba8: case spv::ImageFormat::Rgba8:
case spv::ImageFormat::Rgba8Snorm: case spv::ImageFormat::Rgba8Snorm:
case spv::ImageFormat::Rgba32i: case spv::ImageFormat::Rgba32i:
case spv::ImageFormat::Rgba16i: case spv::ImageFormat::Rgba16i:
case spv::ImageFormat::Rgba8i: case spv::ImageFormat::Rgba8i:
case spv::ImageFormat::R32i: case spv::ImageFormat::R32i:
case spv::ImageFormat::Rgba32ui: case spv::ImageFormat::Rgba32ui:
case spv::ImageFormat::Rgba16ui: case spv::ImageFormat::Rgba16ui:
case spv::ImageFormat::Rgba8ui: case spv::ImageFormat::Rgba8ui:
case spv::ImageFormat::R32ui: case spv::ImageFormat::R32ui:
AddCapability(spv::Capability::Shader); AddCapability(spv::Capability::Shader);
break; break;
case spv::ImageFormat::Rg32f: case spv::ImageFormat::Rg32f:
case spv::ImageFormat::Rg16f: case spv::ImageFormat::Rg16f:
case spv::ImageFormat::R11fG11fB10f: case spv::ImageFormat::R11fG11fB10f:
case spv::ImageFormat::R16f: case spv::ImageFormat::R16f:
case spv::ImageFormat::Rgba16: case spv::ImageFormat::Rgba16:
case spv::ImageFormat::Rgb10A2: case spv::ImageFormat::Rgb10A2:
case spv::ImageFormat::Rg16: case spv::ImageFormat::Rg16:
case spv::ImageFormat::Rg8: case spv::ImageFormat::Rg8:
case spv::ImageFormat::R16: case spv::ImageFormat::R16:
case spv::ImageFormat::R8: case spv::ImageFormat::R8:
case spv::ImageFormat::Rgba16Snorm: case spv::ImageFormat::Rgba16Snorm:
case spv::ImageFormat::Rg16Snorm: case spv::ImageFormat::Rg16Snorm:
case spv::ImageFormat::Rg8Snorm: case spv::ImageFormat::Rg8Snorm:
case spv::ImageFormat::Rg32i: case spv::ImageFormat::Rg32i:
case spv::ImageFormat::Rg16i: case spv::ImageFormat::Rg16i:
case spv::ImageFormat::Rg8i: case spv::ImageFormat::Rg8i:
case spv::ImageFormat::R16i: case spv::ImageFormat::R16i:
case spv::ImageFormat::R8i: case spv::ImageFormat::R8i:
case spv::ImageFormat::Rgb10a2ui: case spv::ImageFormat::Rgb10a2ui:
case spv::ImageFormat::Rg32ui: case spv::ImageFormat::Rg32ui:
case spv::ImageFormat::Rg16ui: case spv::ImageFormat::Rg16ui:
case spv::ImageFormat::Rg8ui: case spv::ImageFormat::Rg8ui:
case spv::ImageFormat::R16ui: case spv::ImageFormat::R16ui:
case spv::ImageFormat::R8ui: case spv::ImageFormat::R8ui:
AddCapability(spv::Capability::StorageImageExtendedFormats); AddCapability(spv::Capability::StorageImageExtendedFormats);
break; break;
} }
auto op{new Op(spv::Op::OpTypeImage, bound)}; auto op{new Op(spv::Op::OpTypeImage, bound)};
op->Add(sampled_type); op->Add(sampled_type);
@ -179,19 +179,19 @@ Ref Module::TypeOpaque(const std::string& name) {
Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) { Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) {
switch (storage_class) { switch (storage_class) {
case spv::StorageClass::Uniform: case spv::StorageClass::Uniform:
case spv::StorageClass::Output: case spv::StorageClass::Output:
case spv::StorageClass::Private: case spv::StorageClass::Private:
case spv::StorageClass::PushConstant: case spv::StorageClass::PushConstant:
case spv::StorageClass::StorageBuffer: case spv::StorageClass::StorageBuffer:
AddCapability(spv::Capability::Shader); AddCapability(spv::Capability::Shader);
break; break;
case spv::StorageClass::Generic: case spv::StorageClass::Generic:
AddCapability(spv::Capability::GenericPointer); AddCapability(spv::Capability::GenericPointer);
break; break;
case spv::StorageClass::AtomicCounter: case spv::StorageClass::AtomicCounter:
AddCapability(spv::Capability::AtomicStorage); AddCapability(spv::Capability::AtomicStorage);
break; break;
} }
auto op{new Op(spv::Op::OpTypePointer, bound)}; auto op{new Op(spv::Op::OpTypePointer, bound)};
op->Add(static_cast<u32>(storage_class)); op->Add(static_cast<u32>(storage_class));

View file

@ -1,26 +0,0 @@
/* This file is part of the sirit project.
* Copyright (c) 2018 ReinUsesLisp
* This software may be used and distributed according to the terms of the GNU
* Lesser General Public License version 2.1 or any later version.
*/
#include "common_types.h"
#include "literal-number.h"
#include "operand.h"
#include "sirit/sirit.h"
namespace Sirit {
#define DEFINE_LITERAL(type) \
Operand* Module::Literal(type value) { \
return LiteralNumber::Create<type>(value); \
}
DEFINE_LITERAL(u32)
DEFINE_LITERAL(u64)
DEFINE_LITERAL(s32)
DEFINE_LITERAL(s64)
DEFINE_LITERAL(f32)
DEFINE_LITERAL(f64)
} // namespace Sirit

View file

@ -71,6 +71,34 @@ void Op::Sink(const std::vector<Operand*>& operands) {
} }
} }
void Op::Add(const Literal& literal) {
Operand* operand = [&]() {
switch (literal.index()) {
case 0:
return LiteralNumber::Create(std::get<0>(literal));
case 1:
return LiteralNumber::Create(std::get<1>(literal));
case 2:
return LiteralNumber::Create(std::get<2>(literal));
case 3:
return LiteralNumber::Create(std::get<3>(literal));
case 4:
return LiteralNumber::Create(std::get<4>(literal));
case 5:
return LiteralNumber::Create(std::get<5>(literal));
default:
assert(!"invalid literal type");
}
}();
Sink(operand);
}
void Op::Add(const std::vector<Literal>& literals) {
for (const auto& literal : literals) {
Add(literal);
}
}
void Op::Add(const Operand* operand) { operands.push_back(operand); } void Op::Add(const Operand* operand) { operands.push_back(operand); }
void Op::Add(u32 integer) { Sink(LiteralNumber::Create<u32>(integer)); } void Op::Add(u32 integer) { Sink(LiteralNumber::Create<u32>(integer)); }

View file

@ -31,6 +31,10 @@ class Op : public Operand {
void Sink(const std::vector<Operand*>& operands); void Sink(const std::vector<Operand*>& operands);
void Add(const Literal& literal);
void Add(const std::vector<Literal>& literals);
void Add(const Operand* operand); void Add(const Operand* operand);
void Add(u32 integer); void Add(u32 integer);

View file

@ -20,8 +20,7 @@ static void WriteEnum(Stream& stream, spv::Op opcode, T value) {
op.Write(stream); op.Write(stream);
} }
template <typename T> template <typename T> static void WriteSet(Stream& stream, const T& set) {
static void WriteSet(Stream& stream, const T& set) {
for (const auto& item : set) { for (const auto& item : set) {
item->Write(stream); item->Write(stream);
} }