diff --git a/rustc_codegen_spirv/src/abi.rs b/rustc_codegen_spirv/src/abi.rs index 923d84ad38..d2c1bb60f1 100644 --- a/rustc_codegen_spirv/src/abi.rs +++ b/rustc_codegen_spirv/src/abi.rs @@ -3,48 +3,35 @@ use rspirv::spirv::Word; use rustc_middle::ty::layout::TyAndLayout; use rustc_target::abi::{Abi, FieldsShape, Primitive, Scalar, Size}; -/* #[derive(Clone, Debug, PartialEq, Eq)] pub enum SpirvType { - Bool(Word), - Integer(Word, u32, bool), - Float(Word, u32), + Bool, + Integer(u32, bool), + Float(u32), // TODO: Do we fold this into Adt? /// Zero Sized Type - ZST(Word), - /// This variant is kind of useless, but it lets us recognize Pointer(Slice(T)), etc. - // TODO: Actually recognize Pointer(Slice(T)) and generate a wide pointer. - Slice(Box), + ZST, /// This uses the rustc definition of "adt", i.e. a struct, enum, or union Adt { - def: Word, // TODO: enums/unions - field_types: Vec, + field_types: Vec, }, - Pointer { - def: Word, - pointee: Box, + Vector { + element: Word, + count: u32, + }, + Array { + element: Word, + count: u32, }, } -impl SpirvType { - pub fn def(&self) -> Word { - match *self { - SpirvType::Bool(def) => def, - SpirvType::Integer(def, _, _) => def, - SpirvType::Float(def, _) => def, - SpirvType::ZST(def) => def, - SpirvType::Slice(ref element) => element.def(), - SpirvType::Adt { def, .. } => def, - SpirvType::Pointer { def, .. } => def, - } - } -} -*/ - pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word { if ty.is_zst() { - return cx.emit_global().type_struct(&[]); + let def = SpirvType::ZST; + let result = cx.emit_global().type_struct(&[]); + cx.def_type(result, def); + return result; } // Note: ty.abi is orthogonal to ty.variants and ty.fields, e.g. `ManuallyDrop>` @@ -58,30 +45,53 @@ pub fn trans_type<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) Abi::ScalarPair(ref one, ref two) => { let one_spirv = trans_scalar(cx, one); let two_spirv = trans_scalar(cx, two); - cx.emit_global().type_struct([one_spirv, two_spirv]) + let result = cx.emit_global().type_struct([one_spirv, two_spirv]); + let def = SpirvType::Adt { + field_types: vec![one_spirv, two_spirv], + }; + cx.def_type(result, def); + result } Abi::Vector { ref element, count } => { let elem_spirv = trans_scalar(cx, element); - cx.emit_global().type_vector(elem_spirv, count as u32) + let result = cx.emit_global().type_vector(elem_spirv, count as u32); + let def = SpirvType::Vector { + element: elem_spirv, + count: count as u32, + }; + cx.def_type(result, def); + result } Abi::Aggregate { sized: _ } => trans_aggregate(cx, ty), } } fn trans_scalar<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, scalar: &Scalar) -> Word { - match scalar.value { - Primitive::Int(width, signedness) => cx - .emit_global() - .type_int(width.size().bits() as u32, if signedness { 1 } else { 0 }), - Primitive::F32 => cx.emit_global().type_float(32), - Primitive::F64 => cx.emit_global().type_float(64), + let (ty, def) = match scalar.value { + Primitive::Int(width, signedness) => { + if scalar.valid_range == (0..=1) { + (cx.emit_global().type_bool(), SpirvType::Bool) + } else if scalar.valid_range != (0..=((1 << (width.size().bits() as u128)) - 1)) { + panic!("TODO: Unimplemented valid_range that's not the size of the int (width={:?}, range={:?}): {:?}", width, scalar.valid_range, scalar) + } else { + ( + cx.emit_global() + .type_int(width.size().bits() as u32, if signedness { 1 } else { 0 }), + SpirvType::Integer(width as u32, signedness), + ) + } + } + Primitive::F32 => (cx.emit_global().type_float(32), SpirvType::Float(32)), + Primitive::F64 => (cx.emit_global().type_float(64), SpirvType::Float(64)), Primitive::Pointer => { panic!( "TODO: Scalar(Pointer) not supported yet in trans_type: {:?}", scalar ); } - } + }; + cx.def_type(ty, def); + ty } fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) -> Word { @@ -97,7 +107,13 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx> FieldsShape::Array { stride: _, count } => { // TODO: Assert stride is same as spirv's stride? let element_type = trans_type(cx, ty.field(cx, 0)); - cx.emit_global().type_array(element_type, count as u32) + let result = cx.emit_global().type_array(element_type, count as u32); + let def = SpirvType::Array { + element: element_type, + count: count as u32, + }; + cx.def_type(result, def); + result } FieldsShape::Arbitrary { offsets: _, @@ -134,5 +150,10 @@ fn trans_struct<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>) - offset = target_offset + field.size; prev_effective_align = effective_field_align; } - cx.emit_global().type_struct(&result) + let result_ty = cx.emit_global().type_struct(&result); + let def = SpirvType::Adt { + field_types: result, + }; + cx.def_type(result_ty, def); + result_ty } diff --git a/rustc_codegen_spirv/src/builder/builder_methods.rs b/rustc_codegen_spirv/src/builder/builder_methods.rs index 0e2373b1a5..b55c9af8ac 100644 --- a/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -1,4 +1,5 @@ use super::Builder; +use crate::builder_spirv::SpirvValueExt; use rustc_codegen_ssa::common::{ AtomicOrdering, AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, }; @@ -62,7 +63,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { } fn ret(&mut self, value: Self::Value) { - self.emit().ret_value(value).unwrap(); + self.emit().ret_value(value.def).unwrap(); } fn br(&mut self, _dest: Self::BasicBlock) { @@ -219,9 +220,12 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> { } fn or(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value { - // TODO: implement result_type - let result_type = Default::default(); - self.emit().bitwise_or(result_type, None, lhs, rhs).unwrap() + assert_eq!(lhs.ty, rhs.ty); + let result_type = lhs.ty; + self.emit() + .bitwise_or(result_type, None, lhs.def, rhs.def) + .unwrap() + .with_type(result_type) } fn xor(&mut self, _lhs: Self::Value, _rhs: Self::Value) -> Self::Value { diff --git a/rustc_codegen_spirv/src/builder_spirv.rs b/rustc_codegen_spirv/src/builder_spirv.rs index 31ba25c39a..f103a357a8 100644 --- a/rustc_codegen_spirv/src/builder_spirv.rs +++ b/rustc_codegen_spirv/src/builder_spirv.rs @@ -19,6 +19,22 @@ impl ModuleSpirv { } } +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)] +pub struct SpirvValue { + pub def: Word, + pub ty: Word, +} + +pub trait SpirvValueExt { + fn with_type(self, ty: Word) -> SpirvValue; +} + +impl SpirvValueExt for Word { + fn with_type(self, ty: Word) -> SpirvValue { + SpirvValue { def: self, ty } + } +} + #[derive(Default, Copy, Clone)] #[must_use = "BuilderCursor should usually be assigned to the Builder.cursor field"] pub struct BuilderCursor { diff --git a/rustc_codegen_spirv/src/codegen_cx.rs b/rustc_codegen_spirv/src/codegen_cx.rs index 689619c674..25ca90d7cf 100644 --- a/rustc_codegen_spirv/src/codegen_cx.rs +++ b/rustc_codegen_spirv/src/codegen_cx.rs @@ -1,4 +1,5 @@ -use crate::builder_spirv::{BuilderCursor, BuilderSpirv, ModuleSpirv}; +use crate::abi::SpirvType; +use crate::builder_spirv::{BuilderCursor, BuilderSpirv, ModuleSpirv, SpirvValue, SpirvValueExt}; use rspirv::spirv::{FunctionControl, Word}; use rustc_codegen_ssa::common::TypeKind; use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind}; @@ -21,10 +22,9 @@ use rustc_span::def_id::{CrateNum, DefId}; use rustc_span::source_map::{Span, DUMMY_SP}; use rustc_span::symbol::Symbol; use rustc_span::SourceFile; -use rustc_target::abi; -use rustc_target::abi::call::{CastTarget, FnAbi, Reg}; +use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg}; use rustc_target::abi::{ - Abi, AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Size, TargetDataLayout, + self, Abi, AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Size, TargetDataLayout, }; use rustc_target::spec::{HasTargetSpec, Target}; use std::cell::RefCell; @@ -36,7 +36,8 @@ pub struct CodegenCx<'spv, 'tcx> { pub spirv_module: &'spv ModuleSpirv, pub builder: BuilderSpirv, pub function_defs: RefCell, Word>>, - pub function_parameter_values: RefCell>>, + pub function_parameter_values: RefCell>>, + pub type_defs: RefCell>, } impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { @@ -52,6 +53,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { builder: BuilderSpirv::new(), function_defs: RefCell::new(HashMap::new()), function_parameter_values: RefCell::new(HashMap::new()), + type_defs: RefCell::new(HashMap::new()), } } @@ -74,6 +76,18 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { trans_type(self, ty) } + pub fn def_type(&self, ty: Word, def: SpirvType) { + self.type_defs.borrow_mut().insert(ty, def); + } + + pub fn lookup_type(&self, ty: Word) -> SpirvType { + self.type_defs + .borrow() + .get(&ty) + .expect("Tried to lookup value that wasn't a type, or has no definition") + .clone() + } + pub fn finalize_module(self) { let result = self.builder.finalize(); let mut output = self.spirv_module.module.lock().unwrap(); @@ -85,7 +99,7 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> { } impl<'spv, 'tcx> BackendTypes for CodegenCx<'spv, 'tcx> { - type Value = Word; + type Value = SpirvValue; type Function = Word; type BasicBlock = Word; @@ -346,14 +360,30 @@ impl<'spv, 'tcx> PreDefineMethods<'tcx> for CodegenCx<'spv, 'tcx> { _visibility: Visibility, _symbol_name: &str, ) { - let mut emit = self.emit_global(); + fn assert_mode(mode: PassMode) { + if let PassMode::Direct(_) = mode { + } else { + panic!("PassMode not supported yet: {:?}", mode) + } + } let fn_abi = FnAbi::of_instance(self, instance, &[]); let argument_types = fn_abi .args .iter() - .map(|arg| self.trans_type(arg.layout)) + .map(|arg| { + assert_mode(arg.mode); + self.trans_type(arg.layout) + }) .collect::>(); - let return_type = self.trans_type(fn_abi.ret.layout); + // TODO: Do we register types created here in the type tracker? + // TODO: Other modes + let return_type = if fn_abi.ret.mode == PassMode::Ignore { + self.emit_global().type_void() + } else { + assert_mode(fn_abi.ret.mode); + self.trans_type(fn_abi.ret.layout) + }; + let mut emit = self.emit_global(); let control = FunctionControl::NONE; let function_id = None; let function_type = emit.type_function(return_type, &argument_types); @@ -363,7 +393,7 @@ impl<'spv, 'tcx> PreDefineMethods<'tcx> for CodegenCx<'spv, 'tcx> { .unwrap(); let parameter_values = argument_types .iter() - .map(|&ty| emit.function_parameter(ty).unwrap()) + .map(|&ty| emit.function_parameter(ty).unwrap().with_type(ty)) .collect::>(); emit.end_function().unwrap(); @@ -457,7 +487,7 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> { todo!() } fn const_undef(&self, ty: Self::Type) -> Self::Value { - self.emit_global().undef(ty, None) + self.emit_global().undef(ty, None).with_type(ty) } fn const_int(&self, _t: Self::Type, _i: i64) -> Self::Value { todo!() @@ -514,8 +544,14 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> { match scalar { Scalar::Raw { data, size } => match layout.value { Primitive::Int(_size, _signedness) => match size { - 4 => self.emit_global().constant_u32(ty, data as u32), - 8 => self.emit_global().constant_u64(ty, data as u64), + 4 => self + .emit_global() + .constant_u32(ty, data as u32) + .with_type(ty), + 8 => self + .emit_global() + .constant_u64(ty, data as u64) + .with_type(ty), size => panic!( "TODO: scalar_to_backend int size {} not implemented yet", size @@ -523,10 +559,12 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> { }, Primitive::F32 => self .emit_global() - .constant_f32(ty, f32::from_bits(data as u32)), + .constant_f32(ty, f32::from_bits(data as u32)) + .with_type(ty), Primitive::F64 => self .emit_global() - .constant_f64(ty, f64::from_bits(data as u64)), + .constant_f64(ty, f64::from_bits(data as u64)) + .with_type(ty), Primitive::Pointer => { panic!("TODO: scalar_to_backend Primitive::Ptr not implemented yet") } diff --git a/rustc_codegen_spirv/src/lib.rs b/rustc_codegen_spirv/src/lib.rs index 9bdcd90aae..02d0d9025d 100644 --- a/rustc_codegen_spirv/src/lib.rs +++ b/rustc_codegen_spirv/src/lib.rs @@ -1,5 +1,4 @@ #![feature(rustc_private)] -#![feature(or_insert_with_key)] extern crate rustc_ast; extern crate rustc_codegen_ssa;