Skip to content

Commit

Permalink
remove plotting functionality
Browse files Browse the repository at this point in the history
Removed the plotto dependency and all the .plot() methods. This is a temporary change; plotting functionality will be added back with a different API.
  • Loading branch information
bernardodionisi committed Dec 9, 2023
1 parent 04836a9 commit cded79c
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 534 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ dependencies = [
"pyhdfe>=0.1.2",
"tqdm>=4.64.1",
"joblib>=1.2.0",
"plotto==0.1.3",
"typing_extensions >= 4.0.0",
]
dynamic = ["version"]
Expand All @@ -59,7 +58,7 @@ write_to = "src/differences/_version.py"
where = ["src"]

[tool.setuptools.package-data]
"differences.datasets.data" = ["*.csv", "*.parquet", "datasets/*.csv", "datasets/*.parquet"]
"differences.datasets" = ["*.csv", "*.parquet"]

[tool.distutils.bdist_wheel]
universal = false
Expand Down
207 changes: 51 additions & 156 deletions src/differences/attgt/attgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from ..tools.panel_utility import (delta_col_to_create,
find_time_varying_covars, is_panel_balanced)
from ..tools.panel_validation import _ValiDIData
from ..tools.utility import capitalize_details
from . import plot as attgt_plot

from .aggregate import _AggregateGT, get_weights
from .attgt_cal import get_att_gt, get_standard_errors
from .difference import (_Difference, difference_ntl_to_dataframe,
Expand Down Expand Up @@ -87,13 +86,13 @@ class ATTgt:
data = _ValiDIData()

def __init__(
self,
data: DataFrame,
cohort_name: str,
strata_name: str = None, # extra treatment information
base_period: str = "varying", # or 'universal'
anticipation: int = 0,
freq: str = None,
self,
data: DataFrame,
cohort_name: str,
strata_name: str = None, # extra treatment information
base_period: str = "varying", # or 'universal'
anticipation: int = 0,
freq: str = None,
):

# from now on 'base_period' is called 'base_period_type'
Expand Down Expand Up @@ -228,7 +227,7 @@ def group_time(self, feasible: bool = False) -> list[dict]:
d
for d in cbt
if (d["time"], d["cohort"]) in self._feasible_gt
or (d["base_period"] == d["time"])
or (d["base_period"] == d["time"])
]

return cbt
Expand All @@ -241,7 +240,7 @@ def group_time(self, feasible: bool = False) -> list[dict]:
d
for d in cbtg
if (d["time"], d["cohort"], d["stratum"]) in self._feasible_gt
or (d["base_period"] == d["time"])
or (d["base_period"] == d["time"])
]

return cbtg
Expand All @@ -268,7 +267,7 @@ def _create_data_matrix(self, is_panel: bool):
return y_matrix

def _preprocess_covariates(
self, is_panel: bool, base_delta: str | list | dict, y_matrix
self, is_panel: bool, base_delta: str | list | dict, y_matrix
) -> None:

if is_panel: # y_matrix is available only if is_panel
Expand Down Expand Up @@ -315,7 +314,7 @@ def _preprocess_covariates(
self._x_base_delta = {"base": self._x_base, "delta": self._x_delta}

def _create_result_dict(
self, split_sample_by: Callable | str | dict | None
self, split_sample_by: Callable | str | dict | None
) -> None:
"""
creates self._result_dict dictionary
Expand Down Expand Up @@ -358,11 +357,11 @@ def _create_result_dict(
self._result_dict = {"full_sample": {}}

def _get_clusters_for_difference(
self,
cluster_var: list | str | None,
difference_samples: list,
data_mask: np.ndarray,
iterate_samples: list,
self,
cluster_var: list | str | None,
difference_samples: list,
data_mask: np.ndarray,
iterate_samples: list,
) -> np.ndarray | dict:

# clusters
Expand Down Expand Up @@ -395,21 +394,21 @@ def _get_clusters_for_difference(

# att gt
def fit(
self,
formula: str,
weights_name: str = None,
control_group: str = "never_treated",
base_delta: str | list | dict = "base",
est_method: str | Callable = "dr",
as_repeated_cross_section: bool = None,
boot_iterations: int = 0, # if > 0 mboot will be called
random_state: int = None,
alpha: float = 0.05,
cluster_var: list | str = None,
split_sample_by: Callable | str | dict = None,
n_jobs: int = 1,
backend: str = "loky",
progress_bar: bool = True,
self,
formula: str,
weights_name: str = None,
control_group: str = "never_treated",
base_delta: str | list | dict = "base",
est_method: str | Callable = "dr",
as_repeated_cross_section: bool = None,
boot_iterations: int = 0, # if > 0 mboot will be called
random_state: int = None,
alpha: float = 0.05,
cluster_var: list | str = None,
split_sample_by: Callable | str | dict = None,
n_jobs: int = 1,
backend: str = "loky",
progress_bar: bool = True,
) -> DataFrame:
"""
Computes the cohort-time-(stratum) average treatment effects:
Expand Down Expand Up @@ -710,7 +709,7 @@ def fit(
progress_bar=progress_bar,
sample_name=s if s != "full_sample" else None, # just for progress_bar
release_workers=(
not bool(boot_iterations) and (s_idx + 1 == n_sample_names)
not bool(boot_iterations) and (s_idx + 1 == n_sample_names)
),
)

Expand Down Expand Up @@ -738,16 +737,16 @@ def fit(
return self._fit_res

def aggregate(
self,
type_of_aggregation: str | None = "simple",
overall: bool = False,
difference: bool | list | dict[str, list] = False,
alpha: float = 0.05,
cluster_var: list | str = None,
boot_iterations: int = 0,
random_state: int = None,
n_jobs: int = 1,
backend: str = "loky",
self,
type_of_aggregation: str | None = "simple",
overall: bool = False,
difference: bool | list | dict[str, list] = False,
alpha: float = 0.05,
cluster_var: list | str = None,
boot_iterations: int = 0,
random_state: int = None,
n_jobs: int = 1,
backend: str = "loky",
) -> DataFrame:
"""
Aggregate the ATTgt
Expand Down Expand Up @@ -866,7 +865,7 @@ def aggregate(
return self._fit_res

if isinstance(
cluster_var, str
cluster_var, str
): # entity cluster is automatic, exclude from list
cluster_var = [
c for c in [cluster_var] if c != self._data_matrix.index.names[0]
Expand Down Expand Up @@ -1131,13 +1130,13 @@ def estimation_details(self, type_of_aggregation: str = None):
return details

def results(
self,
type_of_aggregation: str = None,
overall: bool = False,
difference: bool = False,
# sample_name: str = None,
to_dataframe: bool = True,
add_info: bool = False,
self,
type_of_aggregation: str = None,
overall: bool = False,
difference: bool = False,
# sample_name: str = None,
to_dataframe: bool = True,
add_info: bool = False,
):
"""
provides easy access to cached results.
Expand Down Expand Up @@ -1229,107 +1228,3 @@ def results(
return output

return output

def plot(
    self,
    type_of_aggregation: str = None,
    overall: bool = False,
    difference: bool = False,  # forwarded to self.results()
    # sample_name: str = None,
    estimation_details: bool = True,
    estimate_in_x_axis: bool = False,
    **plotting_parameters,
):
    """
    Plot the cached estimates selected by the given arguments.

    The estimates are fetched as a DataFrame through ``self.results`` and
    handed to one of the plotting functions of the ``attgt_plot`` module.

    Parameters
    ----------
    type_of_aggregation: *str* | None, default: ``None``

        - ``None``
            to plot the (non-aggregated) results; dispatches to
            ``attgt_plot.plot_att_gt``.

        - ``"simple"``
            to plot the weighted average of all cohort-time average
            treatment effects, with weights proportional to the cohort size.

        - ``"event"``
            to plot the average effects in each relative period:
            periods relative to the treatment; as in an event study.

        - ``"cohort"``
            to plot the average treatment effect in each cohort.

        - ``"time"``
            to plot the average treatment effect in each time period.

    overall: *bool*, default: ``False``
        to plot the average effect within each type_of_aggregation.

        - if type_of_aggregation is set to ``"event"``
            to plot the average effect of the treatment across positive
            relative periods

        - if type_of_aggregation is set to ``"cohort"``
            to plot the average effect of the treatment across cohorts

        - if type_of_aggregation is set to ``"time"``
            to plot the average effect of the treatment across time periods

    difference: *bool*, default: ``False``
        take the difference of the estimates; ``True`` plots the difference
        between 2 samples or 2 strata of treatments. The value is passed
        unchanged to ``self.results``.

    estimation_details: *bool* | *list* | *str*, default: ``True``
        include the estimation details in the plot. When exactly ``True``,
        the details are built from ``self.estimation_details`` and passed
        through ``capitalize_details``; any other value is forwarded to the
        plotting function as is. The format can be modified through
        plotting_parameters.

    estimate_in_x_axis: *bool*, default: ``False``
        whether to display the ATT estimates in the x-axis.
        NOTE(review): this argument is accepted but never referenced in the
        body of this method -- confirm whether it should be forwarded.

    plotting_parameters
        a set of parameters to customize the plot. Please refer to the
        separate documentation for the plotting functionalities built in
        the library

    Returns
    -------
    Whatever the selected ``attgt_plot`` function returns (documented
    upstream as an interactive plot for the requested estimates).
    """
    # only the literal ``True`` triggers the automatic details; ``False`` and
    # user-supplied (non-bool) values go to the plotting function untouched
    if estimation_details is True:
        raw_details = self.estimation_details(
            type_of_aggregation=type_of_aggregation
        )
        estimation_details = capitalize_details(estimation_details=raw_details)

    # extra info columns are requested only when no aggregation is selected
    include_info = not type_of_aggregation

    df = self.results(
        type_of_aggregation=type_of_aggregation,
        overall=overall,
        difference=difference,
        # sample_name=sample_name,
        to_dataframe=True,
        add_info=include_info,
    )

    # choose the plotting routine, then call it once with the shared arguments
    if type_of_aggregation is None:
        plot_func = attgt_plot.plot_att_gt
    elif type_of_aggregation == "simple" or overall:
        plot_func = attgt_plot.plot_overall_agg
    else:
        # per-aggregation plot, looked up by name, e.g. plot_event_agg
        plot_func = getattr(attgt_plot, f"plot_{type_of_aggregation}_agg")

    return plot_func(
        df=df,
        plotting_parameters=plotting_parameters,
        estimation_details=estimation_details,
    )
Loading

0 comments on commit cded79c

Please sign in to comment.