Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

copy should not modify data #124

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions stockstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,24 @@

import numpy as np
import pandas as pd
import copy

__author__ = 'Cedric Zhuang'


def wrap(df, index_column=None):
def wrap(df, index_column=None, lowerCase=True):
""" wraps a pandas DataFrame to StockDataFrame

:param df: pandas DataFrame
:param index_column: the name of the index column, default to ``date``
:return: an object of StockDataFrame
"""
return StockDataFrame.retype(df, index_column)

return StockDataFrame.retype(df, index_column, lowerCase)

def unwrap(sdf):
""" convert a StockDataFrame back to a pandas DataFrame """
return pd.DataFrame(sdf)


class StockDataFrame(pd.DataFrame):
# Start of options.
KDJ_PARAM = (2.0 / 3.0, 1.0 / 3.0)
Expand Down Expand Up @@ -1247,7 +1246,7 @@ def __init_column(self, key):

def __getitem__(self, item):
try:
result = wrap(super(StockDataFrame, self).__getitem__(item))
result = wrap(super(StockDataFrame, self).__getitem__(item), lowerCase=False)
except KeyError:
try:
if isinstance(item, list):
Expand All @@ -1257,7 +1256,7 @@ def __getitem__(self, item):
self.__init_column(item)
except AttributeError:
pass
result = wrap(super(StockDataFrame, self).__getitem__(item))
result = wrap(super(StockDataFrame, self).__getitem__(item), lowerCase=False)
return result

def till(self, end_date):
Expand All @@ -1270,7 +1269,7 @@ def within(self, start_date, end_date):
return self.start_from(start_date).till(end_date)

def copy(self, deep=True):
return wrap(super(StockDataFrame, self).copy(deep))
return wrap(super(StockDataFrame, self).copy(deep), lowerCase=False)

def _ensure_type(self, obj):
""" override the method in pandas, omit the check
Expand All @@ -1280,7 +1279,7 @@ def _ensure_type(self, obj):
return obj

@staticmethod
def retype(value, index_column=None):
def retype(value, index_column=None, lowerCase=True):
""" if the input is a `DataFrame`, convert it to this class.

:param index_column: name of the index column, default to `date`
Expand All @@ -1293,8 +1292,9 @@ def retype(value, index_column=None):
if isinstance(value, StockDataFrame):
return value
elif isinstance(value, pd.DataFrame):
# use all lower case for column name
value.columns = map(lambda c: c.lower(), value.columns)
if lowerCase:
# use all lower case for column name
value.columns = map(lambda c: c.lower(), value.columns)

if index_column in value.columns:
value.set_index(index_column, inplace=True)
Expand Down