Skip to content

Commit

Permalink
Merge pull request #43 from baraldian/patch-5
Browse files Browse the repository at this point in the history
Restored sampling feature in load_acs (density parameter)
  • Loading branch information
mrtzh committed May 1, 2024
2 parents 731b8d1 + 8367c15 commit 3157807
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions folktables/load_acs.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def load_acs(root_dir, states=None, year=2018, horizon='1-Year',

if states is None:
states = state_list

random.seed(random_seed)

base_datadir = os.path.join(root_dir, str(year), horizon)
os.makedirs(base_datadir, exist_ok=True)

file_names = []
for state in states:
file_names.append(
Expand All @@ -114,7 +114,12 @@ def load_acs(root_dir, states=None, year=2018, horizon='1-Year',
dtypes = {'PINCP': np.float64, 'RT': str, 'SOCP': str, 'SERIALNO': str, 'NAICSP': str}
df_list = []
for file_name in file_names:
df = pd.read_csv(file_name, dtype=dtypes).replace(' ','')
if serial_filter_list is None and density < 1:
skip_prob = 1 - density
df = pd.read_csv(file_name, dtype=dtypes, skiprows=lambda x: x > 0 and random.random() < skip_prob)
else:
df = pd.read_csv(file_name, dtype=dtypes)
df = df.replace(' ', '')
if serial_filter_list is not None:
df = df[df['SERIALNO'].isin(serial_filter_list)]
df_list.append(df)
Expand Down

0 comments on commit 3157807

Please sign in to comment.