Skip to content

Commit

Permalink
Fix generic methods returning references with where clauses
Browse files Browse the repository at this point in the history
  • Loading branch information
asomers committed Aug 9, 2020
1 parent 2ff9e31 commit 57e5304
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
### Added
### Changed
### Fixed
- Fixed mocking generic methods with where clauses returning references.
([#166](https://github.com/asomers/mockall/pull/166))

- Fixed mocking generic methods returning mutable references.
([#165](https://github.com/asomers/mockall/pull/165))

Expand Down
19 changes: 19 additions & 0 deletions mockall/tests/mock_generic_method_with_where_clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ mock! {
Foo {
fn foo<T>(&self, t: T) -> G<T> where T: Copy + 'static;
fn bar<T>(&self, g: G<T>) -> T where T: Copy + 'static;
fn baz<T>(&self) -> &G<T> where T: Copy + 'static;
fn bean<T>(&mut self) -> &mut G<T> where T: Copy + 'static;
}
}

Expand All @@ -21,3 +23,20 @@ fn returning() {
.returning(|g| g.t);
assert_eq!(42u32, mock.bar(G{t: 42}));
}

#[test]
fn return_const() {
let mut mock = MockFoo::new();
mock.expect_baz::<u32>()
.return_const(G{t: 42});
assert_eq!(42u32, mock.baz().t);
}

#[test]
fn return_var() {
let mut mock = MockFoo::new();
mock.expect_bean::<u32>()
.return_var(G{t: 42});
mock.bean::<u32>().t += 1;
assert_eq!(43u32, mock.bean().t);
}
58 changes: 49 additions & 9 deletions mockall_derive/src/mock_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,40 @@ fn ownify(ty: &Type) -> Type {
}
}

/// Add Send + Sync to a where clause
fn send_syncify(wc: &mut Option<WhereClause>, bounded_ty: Type) {
let mut bounds = Punctuated::new();
bounds.push(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: TraitBoundModifier::None,
lifetimes: None,
path: Path::from(format_ident!("Send"))
}));
bounds.push(TypeParamBound::Trait(TraitBound {
paren_token: None,
modifier: TraitBoundModifier::None,
lifetimes: None,
path: Path::from(format_ident!("Sync"))
}));
if wc.is_none() {
*wc = Some(WhereClause {
where_token: <Token![where]>::default(),
predicates: Punctuated::new()
});
}
wc.as_mut().unwrap()
.predicates.push(
WherePredicate::Type(
PredicateType {
lifetimes: None,
bounded_ty,
colon_token: Default::default(),
bounds
}
)
);
}

/// Extract just the type generics from a Generics object
fn type_generics(generics: &Generics) -> Generics {
let params = generics.type_params()
Expand Down Expand Up @@ -159,6 +193,10 @@ impl<'a> Builder<'a> {
}
}
};
if is_static && (return_ref || return_refmut) {
compile_error(self.sig.span(),
"Mockall cannot mock static methods that return non-'static references. It's unclear what the return value's lifetime should be.");
}
let merged_generics = if let Some(g) = self.struct_generics {
merge_generics(g, &declosured_generics)
} else {
Expand Down Expand Up @@ -482,6 +520,12 @@ impl MockFunction {
&self.call_generics
}.split_for_impl();
let (ig, _, wc) = self.call_generics.split_for_impl();
let mut wc = wc.cloned();
if self.is_method_generic() && (self.return_ref || self.return_refmut) {
// Add Senc + Sync, required for downcast, since Expectation
// stores an Option<#owned_output>
send_syncify(&mut wc, self.owned_output.clone());
}
let tbf = tg.as_turbofish();
let vis = &self.call_vis;

Expand Down Expand Up @@ -2204,16 +2248,12 @@ impl<'a> ToTokens for StaticGenericExpectations<'a> {
let argnames = &self.f.argnames;
let argty = &self.f.argty;
let (ig, tg, wc) = self.f.egenerics.split_for_impl();
let owned_output = &self.f.owned_output;
// TODO: test a generic static methd that returns a reference and
// has a where clause
let any_wc = if self.f.return_ref || self.f.return_refmut {
// The Senc + Sync are required for downcast, since Expectation
let mut any_wc = wc.cloned();
if self.f.return_ref || self.f.return_refmut {
// Add Senc + Sync, required for downcast, since Expectation
// stores an Option<#owned_output>
quote!(where #owned_output: Send + Sync)
} else {
quote!(#wc)
};
send_syncify(&mut any_wc, self.f.owned_output.clone());
}
let tbf = tg.as_turbofish();
let output = &self.f.output;
let v = &self.f.privmod_vis;
Expand Down

0 comments on commit 57e5304

Please sign in to comment.