Skip to content

Commit

Permalink
fix(interactive): fix unexpect result of operator aggregate() + itera…
Browse files Browse the repository at this point in the history
…te() (#3391)

Fixes #3177
  • Loading branch information
lnfjpt authored Dec 1, 2023
1 parent dbf2803 commit 901b743
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cmp::max;

use crate::api::{IterCondition, Iteration};
use crate::macros::filter::*;
use crate::stream::Stream;
Expand Down Expand Up @@ -53,7 +55,7 @@ where
F: FnOnce(Stream<D>) -> Result<Stream<D>, BuildJobError>,
{
let max_iters = until.max_iters;
let (leave, enter) = stream
let (mut leave, enter) = stream
.enter()?
.binary_branch_notify("switch", |info| {
SwitchOperator::<D>::new(info.scope_level, emit_kind, until)
Expand All @@ -65,7 +67,10 @@ where
.transform_notify("feedback", move |info| {
FeedbackOperator::<D>::new(info.scope_level, max_iters)
})?;
let feedback_partitions = feedback.get_partitions();
feedback.feedback_to(index)?;
let partition_update = max(feedback_partitions, leave.get_partitions());
leave.set_partitions(partition_update);
leave.leave()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ impl<D: Data> Stream<D> {
pub fn get_partitions(&self) -> usize {
self.partitions
}

pub fn set_partitions(&mut self, partitions: usize) {
self.partitions = partitions;
}
}

impl<D: Data> Stream<D> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
//! See the License for the specific language governing permissions and
//! limitations under the License.
//
use pegasus::api::{CorrelatedSubTask, Count, EmitKind, Fold, IterCondition, Iteration, Map, Reduce, Sink};
use pegasus::api::{
CorrelatedSubTask, Count, EmitKind, Fold, IterCondition, Iteration, Limit, Map, Reduce, Sink,
};
use pegasus::JobConf;

#[test]
Expand Down Expand Up @@ -119,6 +121,71 @@ fn iterate_emit_before_x_r_map_x_test() {
assert_eq!(results, expected);
}

#[test]
fn aggregate_iterate_count_test() {
let mut conf = JobConf::new("aggregate_iterate_count_test");
conf.set_workers(2);
let mut result_stream = pegasus::run(conf, || {
|input, output| {
input
.input_from(0..10000u32)?
.limit(1000)?
.filter_map(|i| Ok(Some(i)))?
.iterate_emit_until(IterCondition::max_iters(1), EmitKind::Before, |start| {
Ok(start
.flat_map(|x| Ok(vec![x + 1].into_iter()))?
.repartition(|x| Ok((x % 32) as u64)))
})?
.flat_map(|x| Ok(vec![x + 1].into_iter()))?
.count()?
.sink_into(output)
}
})
.expect("build job failure");

let mut count = 0;
let mut value = 0;
while let Some(Ok(d)) = result_stream.next() {
count += 1;
value = d;
}
assert_eq!(count, 1);
assert_eq!(value, 2000);
}

#[test]
fn iterate_aggregate_count_test() {
let mut conf = JobConf::new("aggregate_iterate_count_test");
conf.set_workers(2);
let mut result_stream = pegasus::run(conf, || {
|input, output| {
input
.input_from(0..10000u32)?
.limit(1000)?
.repartition(|x| Ok((x % 32) as u64))
.filter_map(|i| Ok(Some(i)))?
.iterate_emit_until(IterCondition::max_iters(1), EmitKind::Before, |start| {
Ok(start
.flat_map(|x| Ok(vec![x + 1].into_iter()))?
.aggregate())
})?
.flat_map(|x| Ok(vec![x + 1].into_iter()))?
.count()?
.sink_into(output)
}
})
.expect("build job failure");

let mut count = 0;
let mut value = 0;
while let Some(Ok(d)) = result_stream.next() {
count += 1;
value = d;
}
assert_eq!(count, 1);
assert_eq!(value, 2000);
}

#[test]
fn ping_pong_test_01() {
let mut conf = JobConf::new("ping_pong_test_01");
Expand Down

0 comments on commit 901b743

Please sign in to comment.