
[interfaces] add pareto interfaces #1052

Open · wants to merge 6 commits into base: robynpy_release

Conversation


@alxlyj commented Sep 13, 2024

Project Robyn

Depends on #1050

Summary: This PR adds the interfaces (methods and data classes) for the pareto subcomponent in Python. The interface is inferred from pareto.R.

What is new:

  • feat: New Python interfaces for pareto subcomponent.

Test Plan

Class diagram

classDiagram
    class ParetoOptimizer {
        -MMMData mmm_data
        -ModelOutputs model_outputs
        -ResponseCurveCalculator response_calculator
        -ImmediateCarryoverCalculator carryover_calculator
        -ParetoUtils pareto_utils
        +optimize(pareto_fronts: str, min_candidates: int, calibration_constraint: float, calibrated: bool) ParetoResult
        -_aggregate_model_data(calibrated: bool) Dict
        -_compute_pareto_fronts(data: Dict, pareto_fronts: str, min_candidates: int) DataFrame
        -_compute_response_curves(pareto_fronts_df: DataFrame) Dict
        -_generate_plot_data(pareto_fronts_df: DataFrame, response_curves: Dict) Dict
    }
    class ParetoResult {
        +List[str] pareto_solutions
        +int pareto_fronts
        +DataFrame result_hyp_param
        +DataFrame x_decomp_agg
        +DataFrame result_calibration
        +DataFrame media_vec_collect
        +DataFrame x_decomp_vec_collect
        +Dict[str, DataFrame] plot_data_collect
        +DataFrame df_caov_pct_all
    }
    class ResponseCurveCalculator {
        -MMMData mmm_data
        -ModelOutputs model_outputs
        +calculate_response(model_id: str, metric_name: str, channel: str, date_range: Tuple[str, str]) ResponseCurveData
        +calculate_all_responses(model_id: str) List[ResponseCurveData]
        -_get_hill_params(model_id: str, channel: str) HillParameters
        -_get_spend_values(channel: str, date_range: Tuple[str, str]) ndarray
        -_calculate_response_values(hill_params: HillParameters, spend_values: ndarray) ndarray
        -_get_model_channels(model_id: str) List[str]
    }
    class ImmediateCarryoverCalculator {
        -MMMData mmm_data
        -ModelOutputs model_outputs
        +calculate(sol_id: Optional[str], start_date: Optional[str], end_date: Optional[str]) List[EffectDecomposition]
        +calculate_all() DataFrame
        -_get_date_range(start_date: Optional[str], end_date: Optional[str]) DateRange
        -_calculate_saturated_dataframes(sol_id: str) Dict[str, DataFrame]
        -_calculate_decomposition(sol_id: str, saturated_dfs: Dict[str, DataFrame], date_range: DateRange) Dict[str, Dict[str, ndarray]]
        -_aggregate_effects(decomp_data: Dict[str, Dict[str, ndarray]], date_range: DateRange) List[EffectDecomposition]
        -_get_default_solution_id() str
        -_get_all_solution_ids() List[str]
    }
    class ParetoUtils {
        +calculate_pareto_front(x: ndarray, y: ndarray, max_fronts: int) DataFrame
        +calculate_error_scores(result_hyp_param: DataFrame, ts_validation: bool) ndarray
        +calculate_nrmse(y_true: ndarray, y_pred: ndarray) float
        +calculate_mape(y_true: ndarray, y_pred: ndarray) float
        +calculate_decomp_rssd(decomp_values: ndarray) float
        +find_knee_point(x: ndarray, y: ndarray) Tuple[float, float]
        +calculate_hypervolume(points: ndarray) float
        +normalize_objectives(objectives: ndarray) ndarray
        +calculate_crowding_distance(points: ndarray) ndarray
    }
    class DateRange {
        +datetime start_date
        +datetime end_date
        +int start_index
        +int end_index
    }
    class EffectDecomposition {
        +str channel
        +float immediate_effect
        +float carryover_effect
        +float total_effect
    }
    class HillParameters {
        +float alpha
        +float gamma
    }
    class ResponseCurveData {
        +str model_id
        +str metric_name
        +str channel
        +ndarray spend_values
        +ndarray response_values
        +HillParameters hill_params
    }
    ParetoOptimizer --> ParetoResult : produces
    ParetoOptimizer --> ResponseCurveCalculator : uses
    ParetoOptimizer --> ImmediateCarryoverCalculator : uses
    ParetoOptimizer --> ParetoUtils : uses
    ParetoOptimizer ..> MMMData : uses
    ParetoOptimizer ..> ModelOutputs : uses
    ResponseCurveCalculator ..> MMMData : uses
    ResponseCurveCalculator ..> ModelOutputs : uses
    ResponseCurveCalculator --> ResponseCurveData : produces
    ResponseCurveCalculator --> HillParameters : uses
    ImmediateCarryoverCalculator ..> MMMData : uses
    ImmediateCarryoverCalculator ..> ModelOutputs : uses
    ImmediateCarryoverCalculator --> EffectDecomposition : produces
    ImmediateCarryoverCalculator --> DateRange : uses
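The `ParetoUtils.calculate_pareto_front(x, y, max_fronts)` interface above lends itself to a simple non-dominated sort. A minimal sketch, assuming both objectives (e.g. NRMSE and DECOMP.RSSD) are minimized; the column names are illustrative, not taken from this PR:

```python
import numpy as np
import pandas as pd

def calculate_pareto_front(x: np.ndarray, y: np.ndarray, max_fronts: int = 1) -> pd.DataFrame:
    """Assign each (x, y) point a Pareto front index, both objectives minimized.

    Points beyond max_fronts keep front 0 (unassigned).
    """
    front = np.zeros(len(x), dtype=int)
    remaining = np.arange(len(x))
    current = 1
    while remaining.size and current <= max_fronts:
        dominated = np.zeros(remaining.size, dtype=bool)
        for i, idx in enumerate(remaining):
            others = remaining[remaining != idx]
            # idx is dominated if another point is <= in both objectives and < in one
            dominated[i] = np.any(
                (x[others] <= x[idx]) & (y[others] <= y[idx])
                & ((x[others] < x[idx]) | (y[others] < y[idx]))
            )
        front[remaining[~dominated]] = current
        remaining = remaining[dominated]
        current += 1
    return pd.DataFrame({"x": x, "y": y, "pareto_front": front})
```

The O(n²) pairwise check is fine for the model counts typical here; a sort-based sweep would be the next step if it ever became a bottleneck.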

Parity:

graph TB
    subgraph R["R Implementation"]
        R1[robyn_pareto]
        R2[pareto_front]
        R3[robyn_immcarr]
    end
    subgraph Python["Python Implementation"]
        P1[ParetoOptimizer.optimize]
        P2[ParetoUtils.calculate_pareto_front]
        P3[ImmediateCarryoverCalculator.calculate]
    end
    %% Connections
    R1 ---- P1
    R2 ---- P2
    R3 ---- P3
    %% New classes and data structures
    P4[ParetoResult]
    P5[ResponseCurveCalculator]
    P6[DateRange]
    P7[EffectDecomposition]
    P8[HillParameters]
    P9[ResponseCurveData]
    P10[ParetoUtils]
    %% Connections to new classes
    P1 --> P4
    P1 --> P5
    P1 --> P10
    P3 --> P6
    P3 --> P7
    P5 --> P8
    P5 --> P9
    %% Legend
    classDef implemented fill:#90EE90,stroke:#006400,color:black,stroke-width:2px;
    classDef new fill:#ADD8E6,stroke:#00008B,color:black,stroke-width:2px;
    class R1,R2,R3,P1,P2,P3 implemented;
    class P4,P5,P6,P7,P8,P9,P10 new;
    legendA[Implemented]
    legendB[New Python Classes/Structures]
    class legendA implemented;
    class legendB new;
  • No implementations are included in this PR, so it should be lightweight to land; more rigorous testing will follow in the implementation PR.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Sep 13, 2024
@alxlyj marked this pull request as ready for review September 13, 2024 05:54
@sumane81 left a comment


Mostly nits, but I strongly suggest addressing them.


alxlyj commented Sep 17, 2024

  1. We've provided comprehensive docstrings for the class and all methods, explaining their purpose, parameters, and return values in detail.

pareto_optimizer.py

  1. We've removed the use of Any type where possible, using more specific types instead.
  2. pareto_utils is now a class (ParetoUtils) rather than a module with standalone functions.
  3. We've used mmm_data consistently as the variable name for the input data.
  4. We've changed output_models to model_outputs for consistency.
  5. We've renamed input_collect to mmm_data for clarity.
  6. We've removed the quiet parameter, assuming logging will be handled at a different level.
  7. We've renamed _prepare_data to _aggregate_model_data and provided a more detailed description of what it does.
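The renamed methods compose into a single pipeline inside `optimize`. A skeleton of that wiring, inferred from the class diagram — the private methods below are stubs, the default argument values are assumptions, and a plain dict stands in for `ParetoResult`:

```python
import pandas as pd

class ParetoOptimizer:
    """Skeleton of the optimize() pipeline implied by the class diagram;
    not the PR's implementation."""

    def optimize(self, pareto_fronts: str = "auto", min_candidates: int = 100,
                 calibration_constraint: float = 0.1, calibrated: bool = False) -> dict:
        data = self._aggregate_model_data(calibrated)
        fronts_df = self._compute_pareto_fronts(data, pareto_fronts, min_candidates)
        response_curves = self._compute_response_curves(fronts_df)
        plot_data = self._generate_plot_data(fronts_df, response_curves)
        return {"fronts": fronts_df, "curves": response_curves, "plots": plot_data}

    # --- stubs standing in for the real aggregation / front computation ---
    def _aggregate_model_data(self, calibrated): return {}
    def _compute_pareto_fronts(self, data, pareto_fronts, min_candidates): return pd.DataFrame()
    def _compute_response_curves(self, fronts_df): return {}
    def _generate_plot_data(self, fronts_df, curves): return {}
```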

pareto_utils.py

  1. We've added several new utility methods that could be useful in the context of Pareto optimization for marketing mix models:
  • calculate_nrmse and calculate_mape for specific error metrics
  • calculate_decomp_rssd for decomposition-related calculations
  • find_knee_point to identify the point of diminishing returns in a Pareto front
  • calculate_hypervolume to measure the quality of a Pareto front
  • normalize_objectives to handle multiple objectives with different scales
  • calculate_crowding_distance to maintain diversity in the Pareto front
  2. The class now maintains state with instance variables for reference_point, max_fronts, normalization_range, and a cache for the Pareto front. Methods now use instance variables where appropriate; for example, calculate_hypervolume uses the instance's reference_point.
  3. Setter methods: added methods to modify the instance variables, which can be useful for reconfiguring the utility object without creating a new instance.
  4. Flexibility: this structure allows for creating multiple ParetoUtils instances with different configurations if needed.
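As a rough illustration of this stateful design (reference_point held on the instance, with a setter), here is a minimal 2-D sketch. The rectangle-sweep hypervolume and the range-normalized NRMSE are standard formulations assumed for illustration, not code from this PR:

```python
import numpy as np

class ParetoUtils:
    """Minimal stateful sketch; attribute names follow the PR description,
    default values are assumptions."""

    def __init__(self, reference_point=(1.0, 1.0), max_fronts=3):
        self.reference_point = reference_point
        self.max_fronts = max_fronts

    def set_reference_point(self, point):
        """Setter, so the utility object can be reconfigured in place."""
        self.reference_point = point

    def calculate_hypervolume(self, points: np.ndarray) -> float:
        """2-D hypervolume dominated by `points` relative to self.reference_point
        (both objectives minimized), via a left-to-right rectangle sweep."""
        rx, ry = self.reference_point
        pts = points[(points[:, 0] < rx) & (points[:, 1] < ry)]
        pts = pts[np.argsort(pts[:, 0])]
        hv, prev_y = 0.0, ry
        for px, py in pts:
            if py < prev_y:  # dominated points add no area
                hv += (rx - px) * (prev_y - py)
                prev_y = py
        return hv

    @staticmethod
    def calculate_nrmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """RMSE normalized by the range of y_true (one common NRMSE convention)."""
        rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
        return float(rmse / (y_true.max() - y_true.min()))
```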

response_curve.py

  1. We've introduced two new dataclasses: HillParameters and ResponseCurveData. These provide a more structured way to represent the Hill function parameters and the response curve data, respectively.
  2. We've updated the method signatures to use more specific types and avoid open-ended dictionaries.
  3. We've renamed input_collect to mmm_data and output_collect to model_outputs for clarity and consistency.
  4. The calculate_response method now returns a ResponseCurveData object instead of a dictionary, providing a more structured and type-safe approach.
  5. We've added a calculate_all_responses method to calculate response curves for all channels in a given model.
  6. We've added a _get_spend_values method to handle the retrieval of spend data for a specific channel and date range.
  7. We've added a _get_model_channels method to retrieve the list of channels for a specific model.
  8. The date_range parameter in calculate_response is now a tuple of strings, allowing for more flexible date range specification.
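The HillParameters dataclass plus a response calculation can be sketched as below. The specific Hill form used here (gamma acting as the inflexion point, so response is 0.5 at spend == gamma) is an assumption inferred from the R implementation, not code from this PR:

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class HillParameters:
    alpha: float  # shape: larger alpha gives a steeper S-curve
    gamma: float  # inflexion point on the spend axis (assumed interpretation)

def calculate_response_values(hill: HillParameters, spend: np.ndarray) -> np.ndarray:
    """Hill saturation curve: spend^alpha / (spend^alpha + gamma^alpha)."""
    sa = spend ** hill.alpha
    return sa / (sa + hill.gamma ** hill.alpha)
```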

immediate_carryover.py

  1. We've removed the use of Any and replaced it with more specific types.
  2. We've introduced two new dataclasses: DateRange and EffectDecomposition. These provide a more structured way to represent date ranges and effect decompositions, respectively.
  3. The calculate method now returns a List[EffectDecomposition] instead of a DataFrame, providing a more structured and type-safe approach.
  4. We've added a calculate_all method to calculate effects for all solutions, which returns a DataFrame for easier analysis across solutions.
  5. We've updated the _get_date_range method to return a DateRange object instead of a dictionary.
  6. We've added _get_default_solution_id and _get_all_solution_ids helper methods to handle solution ID retrieval.
  7. The _calculate_decomposition method now uses the DateRange object for better type safety.
  8. We've renamed _prepare_result to _aggregate_effects, which more accurately describes what the method does: it aggregates the decomposed effects data into summary statistics for each channel.
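The two dataclasses, and the flattening that lets calculate_all return a DataFrame across solutions, might look like this (field names are from the class diagram; the sample values are made up):

```python
from dataclasses import dataclass, asdict
from datetime import datetime
import pandas as pd

@dataclass
class DateRange:
    start_date: datetime
    end_date: datetime
    start_index: int
    end_index: int

@dataclass
class EffectDecomposition:
    channel: str
    immediate_effect: float
    carryover_effect: float
    total_effect: float

# A List[EffectDecomposition] (the return type of calculate()) flattens
# cleanly into the DataFrame that calculate_all() returns.
effects = [
    EffectDecomposition("tv", 120.0, 30.0, 150.0),      # sample values, made up
    EffectDecomposition("search", 80.0, 5.0, 85.0),
]
effects_df = pd.DataFrame([asdict(e) for e in effects])
```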


alxlyj commented Sep 17, 2024

Updated diagrams.
