Skip to content

Commit

Permalink
Start work on the type tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
khyperia committed Aug 19, 2020
1 parent f7ba9a7 commit db97a8e
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 60 deletions.
101 changes: 61 additions & 40 deletions rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,35 @@ use rspirv::spirv::Word;
use rustc_middle::ty::layout::TyAndLayout;
use rustc_target::abi::{Abi, FieldsShape, Primitive, Scalar, Size};

/*
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SpirvType {
Bool(Word),
Integer(Word, u32, bool),
Float(Word, u32),
Bool,
Integer(u32, bool),
Float(u32),
// TODO: Do we fold this into Adt?
/// Zero Sized Type
ZST(Word),
/// This variant is kind of useless, but it lets us recognize Pointer(Slice(T)), etc.
// TODO: Actually recognize Pointer(Slice(T)) and generate a wide pointer.
Slice(Box<SpirvType>),
ZST,
/// This uses the rustc definition of "adt", i.e. a struct, enum, or union
Adt {
def: Word,
// TODO: enums/unions
field_types: Vec<SpirvType>,
field_types: Vec<Word>,
},
Pointer {
def: Word,
pointee: Box<SpirvType>,
Vector {
element: Word,
count: u32,
},
Array {
element: Word,
count: u32,
},
}

impl SpirvType {
pub fn def(&self) -> Word {
match *self {
SpirvType::Bool(def) => def,
SpirvType::Integer(def, _, _) => def,
SpirvType::Float(def, _) => def,
SpirvType::ZST(def) => def,
SpirvType::Slice(ref element) => element.def(),
SpirvType::Adt { def, .. } => def,
SpirvType::Pointer { def, .. } => def,
}
}
}
*/

pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word {
if ty.is_zst() {
return cx.emit_global().type_struct(&[]);
let def = SpirvType::ZST;
let result = cx.emit_global().type_struct(&[]);
cx.def_type(result, def);
return result;
}

// Note: ty.abi is orthogonal to ty.variants and ty.fields, e.g. `ManuallyDrop<Result<isize, isize>>`
Expand All @@ -58,30 +45,53 @@ pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>)
Abi::ScalarPair(ref one, ref two) => {
let one_spirv = trans_scalar(cx, one);
let two_spirv = trans_scalar(cx, two);
cx.emit_global().type_struct([one_spirv, two_spirv])
let result = cx.emit_global().type_struct([one_spirv, two_spirv]);
let def = SpirvType::Adt {
field_types: vec![one_spirv, two_spirv],
};
cx.def_type(result, def);
result
}
Abi::Vector { ref element, count } => {
let elem_spirv = trans_scalar(cx, element);
cx.emit_global().type_vector(elem_spirv, count as u32)
let result = cx.emit_global().type_vector(elem_spirv, count as u32);
let def = SpirvType::Vector {
element: elem_spirv,
count: count as u32,
};
cx.def_type(result, def);
result
}
Abi::Aggregate { sized: _ } => trans_aggregate(cx, ty),
}
}

fn trans_scalar<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, scalar: &Scalar) -> Word {
match scalar.value {
Primitive::Int(width, signedness) => cx
.emit_global()
.type_int(width.size().bits() as u32, if signedness { 1 } else { 0 }),
Primitive::F32 => cx.emit_global().type_float(32),
Primitive::F64 => cx.emit_global().type_float(64),
let (ty, def) = match scalar.value {
Primitive::Int(width, signedness) => {
if scalar.valid_range == (0..=1) {
(cx.emit_global().type_bool(), SpirvType::Bool)
} else if scalar.valid_range != (0..=((1 << (width.size().bits() as u128)) - 1)) {
panic!("TODO: Unimplemented valid_range that's not the size of the int (width={:?}, range={:?}): {:?}", width, scalar.valid_range, scalar)
} else {
(
cx.emit_global()
.type_int(width.size().bits() as u32, if signedness { 1 } else { 0 }),
SpirvType::Integer(width as u32, signedness),
)
}
}
Primitive::F32 => (cx.emit_global().type_float(32), SpirvType::Float(32)),
Primitive::F64 => (cx.emit_global().type_float(64), SpirvType::Float(64)),
Primitive::Pointer => {
panic!(
"TODO: Scalar(Pointer) not supported yet in trans_type: {:?}",
scalar
);
}
}
};
cx.def_type(ty, def);
ty
}

fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word {
Expand All @@ -97,7 +107,13 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>
FieldsShape::Array { stride: _, count } => {
// TODO: Assert stride is same as spirv's stride?
let element_type = trans_type(cx, ty.field(cx, 0));
cx.emit_global().type_array(element_type, count as u32)
let result = cx.emit_global().type_array(element_type, count as u32);
let def = SpirvType::Array {
element: element_type,
count: count as u32,
};
cx.def_type(result, def);
result
}
FieldsShape::Arbitrary {
offsets: _,
Expand Down Expand Up @@ -134,5 +150,10 @@ fn trans_struct<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -
offset = target_offset + field.size;
prev_effective_align = effective_field_align;
}
cx.emit_global().type_struct(&result)
let result_ty = cx.emit_global().type_struct(&result);
let def = SpirvType::Adt {
field_types: result,
};
cx.def_type(result_ty, def);
result_ty
}
12 changes: 8 additions & 4 deletions rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::Builder;
use crate::builder_spirv::SpirvValueExt;
use rustc_codegen_ssa::common::{
AtomicOrdering, AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope,
};
Expand Down Expand Up @@ -62,7 +63,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
}

fn ret(&mut self, value: Self::Value) {
self.emit().ret_value(value).unwrap();
self.emit().ret_value(value.def).unwrap();
}

fn br(&mut self, _dest: Self::BasicBlock) {
Expand Down Expand Up @@ -219,9 +220,12 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
}

fn or(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
// TODO: implement result_type
let result_type = Default::default();
self.emit().bitwise_or(result_type, None, lhs, rhs).unwrap()
assert_eq!(lhs.ty, rhs.ty);
let result_type = lhs.ty;
self.emit()
.bitwise_or(result_type, None, lhs.def, rhs.def)
.unwrap()
.with_type(result_type)
}

fn xor(&mut self, _lhs: Self::Value, _rhs: Self::Value) -> Self::Value {
Expand Down
16 changes: 16 additions & 0 deletions rustc_codegen_spirv/src/builder_spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ impl ModuleSpirv {
}
}

#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
pub struct SpirvValue {
pub def: Word,
pub ty: Word,
}

pub trait SpirvValueExt {
fn with_type(self, ty: Word) -> SpirvValue;
}

impl SpirvValueExt for Word {
fn with_type(self, ty: Word) -> SpirvValue {
SpirvValue { def: self, ty }
}
}

#[derive(Default, Copy, Clone)]
#[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"]
pub struct BuilderCursor {
Expand Down
68 changes: 53 additions & 15 deletions rustc_codegen_spirv/src/codegen_cx.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, ModuleSpirv};
use crate::abi::SpirvType;
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, ModuleSpirv, SpirvValue, SpirvValueExt};
use rspirv::spirv::{FunctionControl, Word};
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind};
Expand All @@ -21,10 +22,9 @@ use rustc_span::def_id::{CrateNum, DefId};
use rustc_span::source_map::{Span, DUMMY_SP};
use rustc_span::symbol::Symbol;
use rustc_span::SourceFile;
use rustc_target::abi;
use rustc_target::abi::call::{CastTarget, FnAbi, Reg};
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg};
use rustc_target::abi::{
Abi, AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Size, TargetDataLayout,
self, Abi, AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Size, TargetDataLayout,
};
use rustc_target::spec::{HasTargetSpec, Target};
use std::cell::RefCell;
Expand All @@ -36,7 +36,8 @@ pub struct CodegenCx<'spv, 'tcx> {
pub spirv_module: &'spv ModuleSpirv,
pub builder: BuilderSpirv,
pub function_defs: RefCell<HashMap<Instance<'tcx>, Word>>,
pub function_parameter_values: RefCell<HashMap<Word, Vec<Word>>>,
pub function_parameter_values: RefCell<HashMap<Word, Vec<SpirvValue>>>,
pub type_defs: RefCell<HashMap<Word, SpirvType>>,
}

impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
Expand All @@ -52,6 +53,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
builder: BuilderSpirv::new(),
function_defs: RefCell::new(HashMap::new()),
function_parameter_values: RefCell::new(HashMap::new()),
type_defs: RefCell::new(HashMap::new()),
}
}

Expand All @@ -74,6 +76,18 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
trans_type(self, ty)
}

pub fn def_type(&self, ty: Word, def: SpirvType) {
self.type_defs.borrow_mut().insert(ty, def);
}

pub fn lookup_type(&self, ty: Word) -> SpirvType {
self.type_defs
.borrow()
.get(&ty)
.expect("Tried to lookup value that wasn't a type, or has no definition")
.clone()
}

pub fn finalize_module(self) {
let result = self.builder.finalize();
let mut output = self.spirv_module.module.lock().unwrap();
Expand All @@ -85,7 +99,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
}

impl<'spv, 'tcx> BackendTypes for CodegenCx<'spv, 'tcx> {
type Value = Word;
type Value = SpirvValue;
type Function = Word;

type BasicBlock = Word;
Expand Down Expand Up @@ -346,14 +360,30 @@ impl<'spv, 'tcx> PreDefineMethods<'tcx> for CodegenCx<'spv, 'tcx> {
_visibility: Visibility,
_symbol_name: &str,
) {
let mut emit = self.emit_global();
fn assert_mode(mode: PassMode) {
if let PassMode::Direct(_) = mode {
} else {
panic!("PassMode not supported yet: {:?}", mode)
}
}
let fn_abi = FnAbi::of_instance(self, instance, &[]);
let argument_types = fn_abi
.args
.iter()
.map(|arg| self.trans_type(arg.layout))
.map(|arg| {
assert_mode(arg.mode);
self.trans_type(arg.layout)
})
.collect::<Vec<_>>();
let return_type = self.trans_type(fn_abi.ret.layout);
// TODO: Do we register types created here in the type tracker?
// TODO: Other modes
let return_type = if fn_abi.ret.mode == PassMode::Ignore {
self.emit_global().type_void()
} else {
assert_mode(fn_abi.ret.mode);
self.trans_type(fn_abi.ret.layout)
};
let mut emit = self.emit_global();
let control = FunctionControl::NONE;
let function_id = None;
let function_type = emit.type_function(return_type, &argument_types);
Expand All @@ -363,7 +393,7 @@ impl<'spv, 'tcx> PreDefineMethods<'tcx> for CodegenCx<'spv, 'tcx> {
.unwrap();
let parameter_values = argument_types
.iter()
.map(|&ty| emit.function_parameter(ty).unwrap())
.map(|&ty| emit.function_parameter(ty).unwrap().with_type(ty))
.collect::<Vec<_>>();
emit.end_function().unwrap();

Expand Down Expand Up @@ -457,7 +487,7 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> {
todo!()
}
fn const_undef(&self, ty: Self::Type) -> Self::Value {
self.emit_global().undef(ty, None)
self.emit_global().undef(ty, None).with_type(ty)
}
fn const_int(&self, _t: Self::Type, _i: i64) -> Self::Value {
todo!()
Expand Down Expand Up @@ -514,19 +544,27 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> {
match scalar {
Scalar::Raw { data, size } => match layout.value {
Primitive::Int(_size, _signedness) => match size {
4 => self.emit_global().constant_u32(ty, data as u32),
8 => self.emit_global().constant_u64(ty, data as u64),
4 => self
.emit_global()
.constant_u32(ty, data as u32)
.with_type(ty),
8 => self
.emit_global()
.constant_u64(ty, data as u64)
.with_type(ty),
size => panic!(
"TODO: scalar_to_backend int size {} not implemented yet",
size
),
},
Primitive::F32 => self
.emit_global()
.constant_f32(ty, f32::from_bits(data as u32)),
.constant_f32(ty, f32::from_bits(data as u32))
.with_type(ty),
Primitive::F64 => self
.emit_global()
.constant_f64(ty, f64::from_bits(data as u64)),
.constant_f64(ty, f64::from_bits(data as u64))
.with_type(ty),
Primitive::Pointer => {
panic!("TODO: scalar_to_backend Primitive::Ptr not implemented yet")
}
Expand Down
1 change: 0 additions & 1 deletion rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#![feature(rustc_private)]
#![feature(or_insert_with_key)]

extern crate rustc_ast;
extern crate rustc_codegen_ssa;
Expand Down

0 comments on commit db97a8e

Please sign in to comment.