Skip to content

Commit

Permalink
Implement functions with multiple return values. (move-language#105)
Browse files Browse the repository at this point in the history
Functions in Move use a second-class tuple-like expression to bind, return, and destructure multiple values.

On exit from a function, we generate LLVM IR to wrap them up into a struct, which is returned as a single IR value. Similarly, when a callee that returns such a value is used in an expression, we generate IR to extract each actual value from the struct.

Also deduplicated load_call and load_call_store, as the former is just and instance the latter with no return values passed.

Added a move-ir-test and a runtime rbpf test to cover above.
  • Loading branch information
nvjle authored Apr 26, 2023
1 parent 699db66 commit 18dd944
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 54 deletions.
93 changes: 63 additions & 30 deletions language/tools/move-mv-llvm-compiler/src/stackless/llvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ impl Context {
unsafe { Builder(LLVMCreateBuilderInContext(self.0)) }
}

pub fn get_anonymous_struct_type(&self, field_tys: &[Type]) -> Type {
unsafe {
let mut field_tys: Vec<_> = field_tys.iter().map(|f| f.0).collect();
Type(LLVMStructTypeInContext(
self.0,
field_tys.as_mut_ptr(),
field_tys.len() as u32,
0, /* !packed */
))
}
}

pub fn void_type(&self) -> Type {
unsafe { Type(LLVMVoidTypeInContext(self.0)) }
}
Expand Down Expand Up @@ -266,6 +278,26 @@ impl Builder {
}
}

pub fn load_multi_return(&self, return_ty: Type, vals: &[(Type, Alloca)]) {
unsafe {
let loads = vals
.iter()
.enumerate()
.map(|(i, (ty, val))| {
let name = format!("rv.{i}");
LLVMBuildLoad2(self.0, ty.0, val.0, name.cstr())
})
.collect::<Vec<_>>();

let mut agg_val = LLVMGetUndef(return_ty.0);
for i in 0..loads.len() {
let s = format!("insert_{i}").cstr();
agg_val = LLVMBuildInsertValue(self.0, agg_val, loads[i], i as libc::c_uint, s);
}
LLVMBuildRet(self.0, agg_val);
}
}

pub fn store_const(&self, src: Constant, dst: Alloca) {
unsafe {
LLVMBuildStore(self.0, src.0, dst.0);
Expand Down Expand Up @@ -312,13 +344,11 @@ impl Builder {

let mut tys = types
.iter()
.enumerate()
.map(|(_i, ty)| ty.0)
.map(|ty| ty.0)
.collect::<Vec<_>>();
let mut args = args
.iter()
.enumerate()
.map(|(_i, val)| *val)
.map(|val| *val)
.collect::<Vec<_>>();

unsafe {
Expand All @@ -339,30 +369,7 @@ impl Builder {
}
}

pub fn load_call(&self, fnval: Function, args: &[(Type, Alloca)]) {
let fnty = fnval.llvm_type();

unsafe {
let mut args = args
.iter()
.enumerate()
.map(|(i, (ty, val))| {
let name = format!("call_arg_{i}");
LLVMBuildLoad2(self.0, ty.0, val.0, name.cstr())
})
.collect::<Vec<_>>();
LLVMBuildCall2(
self.0,
fnty.0,
fnval.0,
args.as_mut_ptr(),
args.len() as libc::c_uint,
"".cstr(),
);
}
}

pub fn load_call_store(&self, fnval: Function, args: &[(Type, Alloca)], dst: (Type, Alloca)) {
pub fn load_call_store(&self, fnval: Function, args: &[(Type, Alloca)], dst: &[(Type, Alloca)]) {
let fnty = fnval.llvm_type();

unsafe {
Expand All @@ -380,10 +387,30 @@ impl Builder {
fnval.0,
args.as_mut_ptr(),
args.len() as libc::c_uint,
"retval".cstr(),
(if dst.len() == 0 { "" } else { "retval" }).cstr(),
);

LLVMBuildStore(self.0, ret, dst.1 .0);
if dst.len() == 0 {
// No return values.
return;
} else if dst.len() == 1 {
// Single return value.
LLVMBuildStore(self.0, ret, dst[0].1.0);
} else {
// Multiple return values-- unwrap the struct.
let extracts = dst
.iter()
.enumerate()
.map(|(i, (_ty, dval))| {
let name = format!("extract_{i}");
let ev = LLVMBuildExtractValue(self.0, ret, i as libc::c_uint, name.cstr());
(ev, dval)
})
.collect::<Vec<_>>();
for (ev, dval) in extracts {
LLVMBuildStore(self.0, ev, dval.0);
}
}
}
}

Expand Down Expand Up @@ -531,6 +558,12 @@ impl Function {
unsafe { FunctionType(LLVMGlobalGetValueType(self.0)) }
}

pub fn llvm_return_type(&self) -> Type {
unsafe {
Type(LLVMGetReturnType(LLVMGlobalGetValueType(self.0)))
}
}

pub fn verify(&self) {
use llvm_sys::analysis::*;
unsafe {
Expand Down
50 changes: 26 additions & 24 deletions language/tools/move-mv-llvm-compiler/src/stackless/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ impl<'mm, 'up> ModuleContext<'mm, 'up> {
0 => self.llvm_cx.void_type(),
1 => self.llvm_type(&fn_data.return_types[0]),
_ => {
todo!()
// Wrap multiple return values in a struct.
let tys: Vec<_> =
fn_data.return_types.iter().map(|f| self.llvm_type(f)).collect();
let rty = self.llvm_cx.get_anonymous_struct_type(&tys);
rty
}
};

Expand Down Expand Up @@ -458,7 +462,17 @@ impl<'mm, 'up> FunctionContext<'mm, 'up> {
let llty = self.locals[idx].llty;
self.llvm_builder.load_return(llty, llval);
}
_ => todo!(),
_ => {
// Multiple return values are wrapped in a struct.
let nvals = vals
.iter()
.map(|i| (self.locals[*i].llty, self.locals[*i].llval))
.collect::<Vec<_>>();

let ll_fn = &self.fn_decls[&self.env.get_qualified_id()];
let ret_ty = ll_fn.llvm_return_type();
self.llvm_builder.load_multi_return(ret_ty, &nvals);
},
},
sbc::Bytecode::Load(_, idx, val) => {
let local_llval = self.locals[*idx].llval;
Expand Down Expand Up @@ -968,29 +982,17 @@ impl<'mm, 'up> FunctionContext<'mm, 'up> {

let ll_fn = self.fn_decls[&fun_id.qualified(mod_id)];

if dst_locals.len() > 1 {
todo!()
}
let src = src_locals
.iter()
.map(|l| (l.llty, l.llval))
.collect::<Vec<_>>();

let dst = dst_locals.get(0);
let dst = dst_locals
.iter()
.map(|l| (l.llty, l.llval))
.collect::<Vec<_>>();

match dst {
None => {
let src = src_locals
.iter()
.map(|l| (l.llty, l.llval))
.collect::<Vec<_>>();
self.llvm_builder.load_call(ll_fn, &src);
}
Some(dst) => {
let dst = (dst.llty, dst.llval);
let src = src_locals
.iter()
.map(|l| (l.llty, l.llval))
.collect::<Vec<_>>();
self.llvm_builder.load_call_store(ll_fn, &src, dst);
}
}
self.llvm_builder.load_call_store(ll_fn, &src, &dst);
}

fn constant(&self, mc: &sbc::Constant) -> llvm::Constant {
Expand Down Expand Up @@ -1037,7 +1039,7 @@ impl<'mm, 'up> FunctionContext<'mm, 'up> {
let local_llval = self.locals[*local_idx].llval;
let local_llty = self.locals[*local_idx].llty;
self.llvm_builder
.load_call(llfn, &[(local_llty, local_llval)]);
.load_call_store(llfn, &[(local_llty, local_llval)], &[]);
self.llvm_builder.build_unreachable();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
; ModuleID = '0x100__Test'
source_filename = "<unknown>"

define { i1, i1 } @Test__ret_2vals() {
entry:
%local_0 = alloca i1, align 1
%local_1 = alloca i1, align 1
store i1 true, ptr %local_0, align 1
store i1 false, ptr %local_1, align 1
%rv.0 = load i1, ptr %local_0, align 1
%rv.1 = load i1, ptr %local_1, align 1
%insert_0 = insertvalue { i1, i1 } undef, i1 %rv.0, 0
%insert_1 = insertvalue { i1, i1 } %insert_0, i1 %rv.1, 1
ret { i1, i1 } %insert_1
}

define { ptr, i8, i128, i32 } @Test__ret_4vals(ptr %0) {
entry:
%local_0 = alloca ptr, align 8
%local_1 = alloca ptr, align 8
%local_2 = alloca i8, align 1
%local_3 = alloca i128, align 8
%local_4 = alloca i32, align 4
store ptr %0, ptr %local_0, align 8
%load_store_tmp = load ptr, ptr %local_0, align 8
store ptr %load_store_tmp, ptr %local_1, align 8
store i8 8, ptr %local_2, align 1
store i128 128, ptr %local_3, align 4
store i32 32, ptr %local_4, align 4
%rv.0 = load ptr, ptr %local_1, align 8
%rv.1 = load i8, ptr %local_2, align 1
%rv.2 = load i128, ptr %local_3, align 4
%rv.3 = load i32, ptr %local_4, align 4
%insert_0 = insertvalue { ptr, i8, i128, i32 } undef, ptr %rv.0, 0
%insert_1 = insertvalue { ptr, i8, i128, i32 } %insert_0, i8 %rv.1, 1
%insert_2 = insertvalue { ptr, i8, i128, i32 } %insert_1, i128 %rv.2, 2
%insert_3 = insertvalue { ptr, i8, i128, i32 } %insert_2, i32 %rv.3, 3
ret { ptr, i8, i128, i32 } %insert_3
}

define void @Test__use_2val_call_result() {
entry:
%local_0 = alloca i1, align 1
%local_1 = alloca i1, align 1
%local_2 = alloca i1, align 1
%retval = call { i1, i1 } @Test__ret_2vals()
%extract_0 = extractvalue { i1, i1 } %retval, 0
%extract_1 = extractvalue { i1, i1 } %retval, 1
store i1 %extract_0, ptr %local_0, align 1
store i1 %extract_1, ptr %local_1, align 1
%or_src_0 = load i1, ptr %local_0, align 1
%or_src_1 = load i1, ptr %local_1, align 1
%or_dst = or i1 %or_src_0, %or_src_1
store i1 %or_dst, ptr %local_2, align 1
ret void
}

define void @Test__use_4val_call_result() {
entry:
%local_0 = alloca i64, align 8
%local_1 = alloca i8, align 1
%local_2 = alloca i128, align 8
%local_3 = alloca i32, align 4
%local_4 = alloca i64, align 8
%local_5 = alloca ptr, align 8
%local_6 = alloca ptr, align 8
%local_7 = alloca i8, align 1
%local_8 = alloca i128, align 8
%local_9 = alloca i32, align 4
%local_10 = alloca i64, align 8
%local_11 = alloca i8, align 1
%local_12 = alloca i128, align 8
%local_13 = alloca i32, align 4
store i64 0, ptr %local_4, align 4
%load_store_tmp = load i64, ptr %local_4, align 4
store i64 %load_store_tmp, ptr %local_0, align 4
store ptr %local_0, ptr %local_5, align 8
%call_arg_0 = load ptr, ptr %local_5, align 8
%retval = call { ptr, i8, i128, i32 } @Test__ret_4vals(ptr %call_arg_0)
%extract_0 = extractvalue { ptr, i8, i128, i32 } %retval, 0
%extract_1 = extractvalue { ptr, i8, i128, i32 } %retval, 1
%extract_2 = extractvalue { ptr, i8, i128, i32 } %retval, 2
%extract_3 = extractvalue { ptr, i8, i128, i32 } %retval, 3
store ptr %extract_0, ptr %local_6, align 8
store i8 %extract_1, ptr %local_7, align 1
store i128 %extract_2, ptr %local_8, align 4
store i32 %extract_3, ptr %local_9, align 4
%load_store_tmp1 = load i32, ptr %local_9, align 4
store i32 %load_store_tmp1, ptr %local_3, align 4
%load_store_tmp2 = load i128, ptr %local_8, align 4
store i128 %load_store_tmp2, ptr %local_2, align 4
%load_store_tmp3 = load i8, ptr %local_7, align 1
store i8 %load_store_tmp3, ptr %local_1, align 1
%load_deref_store_tmp1 = load ptr, ptr %local_6, align 8
%load_deref_store_tmp2 = load i64, ptr %load_deref_store_tmp1, align 4
store i64 %load_deref_store_tmp2, ptr %local_10, align 4
%load_store_tmp4 = load i8, ptr %local_1, align 1
store i8 %load_store_tmp4, ptr %local_11, align 1
%load_store_tmp5 = load i128, ptr %local_2, align 4
store i128 %load_store_tmp5, ptr %local_12, align 4
%load_store_tmp6 = load i32, ptr %local_3, align 4
store i32 %load_store_tmp6, ptr %local_13, align 4
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module 0x100::Test {
fun ret_2vals(): (bool, bool) { (true, false) }
fun ret_4vals(x: &u64): (&u64, u8, u128, u32) { (x, 8, 128, 32) }

fun use_2val_call_result() {
let (x, y): (bool, bool) = ret_2vals();
let _t = x || y;
}
fun use_4val_call_result() {
let (a, b, c, d) = ret_4vals(&0);
let _t1 = *a;
let _t2 = b;
let _t3 = c;
let _t4 = d;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Multiple return value similar to an example from the Move Book.

module 0x1::Math {
public fun max(a: u8, b: u8): (u8, bool) {
if (a > b) {
(a, false)
} else if (a < b) {
(b, false)
} else {
(a, true)
}
}
}

script {
use 0x1::Math;

fun main() {
let (maxval, is_equal) = Math::max(99, 100);
assert!(maxval == 100, 0xf00);
assert!(!is_equal, 0xf01);

let (maxval, is_equal) = Math::max(5, 0);
assert!(maxval == 5, 0xf02);
assert!(!is_equal, 0xf03);

let (maxval, is_equal) = Math::max(123, 123);
assert!(maxval == 123, 0xf04);
assert!(is_equal, 0xf05);
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module 0x100::Test {
public fun ret_6vals(a: u8, b: u16, c: u32, d: u64, e: u128, f: u256): (u8, u16, u32, u64, u128, u256) {
(a, b, c, d, e, f)
}
}

script {
fun main() {
let (x1, x2, x3, x4, x5, x6) = 0x100::Test::ret_6vals(1, 2, 3, 4, 5, 6);
assert!(x1 == 1, 0xf00);
assert!(x2 == 2, 0xf01);
assert!(x3 == 3, 0xf02);
assert!(x4 == 4, 0xf03);
assert!(x5 == 5, 0xf04);
assert!(x6 == 6, 0xf05);
}
}

0 comments on commit 18dd944

Please sign in to comment.