diff --git a/CHANGELOG.md b/CHANGELOG.md index b377f53..7044355 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,14 @@ ## Unreleased +## v2.5.0 (2024/08/26) + +- 잘못된 승객 타입이 지정되는 문제 수정 + ([#267](https://github.com/ryanking13/SRT/pull/267)) + ## v2.4.1 (2024/02/09) - Add exception handling - ([#259](https://github.com/ryanking13/SRT/issues/259)) + ([#259](https://github.com/ryanking13/SRT/pull/259)) ## v2.4.0 (2024/01/11) diff --git a/SRT/passenger.py b/SRT/passenger.py index c4404fc..cde41b9 100644 --- a/SRT/passenger.py +++ b/SRT/passenger.py @@ -21,14 +21,18 @@ def __init_internal__(self, name, type_code, count): def __repr__(self): return f"{self.name} {self.count}명" - def __add__(self, other): - assert isinstance(other, self.__class__) + def __add__(self, other: "Passenger") -> "Passenger": + if not isinstance(other, self.__class__): + raise TypeError("Passenger types must be the same") + if self.type_code == other.type_code: new_count = self.count + other.count return self.__class__(count=new_count) + raise ValueError("Passenger types must be the same") + @classmethod - def combine(cls, passengers): + def combine(cls, passengers: list["Passenger"]) -> list["Passenger"]: if list(filter(lambda x: not isinstance(x, Passenger), passengers)): raise TypeError("Passengers must be based on Passenger") @@ -36,18 +40,13 @@ def combine(cls, passengers): combined_passengers = [] while tmp_passengers: passenger = tmp_passengers.pop() - same_class = list( - filter( - lambda x, base_passenger=passenger: isinstance( - x, base_passenger.__class__ - ), - tmp_passengers, - ) - ) - new_passenger = None - if not same_class: - new_passenger = passenger - else: + same_class: list[Passenger] = [] + for p in tmp_passengers: + if isinstance(p, passenger.__class__): + same_class.append(p) + + new_passenger = passenger + if same_class: for same in same_class: new_passenger = passenger + same tmp_passengers.remove(same) @@ -69,7 +68,9 @@ def total_count(passengers): return str(total_count) @staticmethod - def get_passenger_dict(passengers, special_seat=False, window_seat=None): + def get_passenger_dict( + passengers, special_seat=False, window_seat=None + ) -> dict[str, str]: if list(filter(lambda x: not isinstance(x, Passenger), passengers)): raise TypeError("Passengers must be based on Passenger") @@ -77,21 +78,22 @@ def get_passenger_dict(passengers, special_seat=False, window_seat=None): "totPrnb": Passenger.total_count(passengers), "psgGridcnt": str(len(passengers)), } - for i, passenger in enumerate(passengers): - data[f"psgTpCd{i + 1}"] = passenger.type_code - data[f"psgInfoPerPrnb{i + 1}"] = str(passenger.count) + for _, passenger in enumerate(passengers): + code = passenger.type_code + data[f"psgTpCd{code}"] = passenger.type_code + data[f"psgInfoPerPrnb{code}"] = str(passenger.count) # seat location ('000': 기본, '012': 창측, '013': 복도측) - data[f"locSeatAttCd{i + 1}"] = WINDOW_SEAT[window_seat] + data[f"locSeatAttCd{code}"] = WINDOW_SEAT[window_seat] # seat requirement ('015': 일반, '021': 휠체어) # TODO: 선택 가능하게 - data[f"rqSeatAttCd{i + 1}"] = "015" + data[f"rqSeatAttCd{code}"] = "015" # seat direction ('009': 정방향) - data[f"dirSeatAttCd{i + 1}"] = "009" + data[f"dirSeatAttCd{code}"] = "009" - data[f"smkSeatAttCd{i + 1}"] = "000" - data[f"etcSeatAttCd{i + 1}"] = "000" + data[f"smkSeatAttCd{code}"] = "000" + data[f"etcSeatAttCd{code}"] = "000" # seat type: ('1': 일반실, '2': 특실) - data[f"psrmClCd{i + 1}"] = "2" if special_seat else "1" + data[f"psrmClCd{code}"] = "2" if special_seat else "1" return data diff --git a/SRT/srt.py b/SRT/srt.py index 3de8e60..20250e6 100644 --- a/SRT/srt.py +++ b/SRT/srt.py @@ -395,10 +395,6 @@ def _reserve( "dirSeatAttCd1": "009", # 방향좌석속성코드 "locSeatAttCd1": "000", # 위치좌석속성코드1 "rqSeatAttCd1": "015", # 요구좌석속성코드1 - "etcSeatAttCd1": "000", # 기타좌석속성코드1 - "smkSeatAttCd2": "000", # 흡연좌석속성코드2 - "dirSeatAttCd2": "009", # 방향좌석속성코드2 - "rqSeatAttCd2": "015", # 요구좌석속성코드2 "mblPhone": mblPhone, } @@ -410,13 +406,11 @@ def _reserve( } ) - # jobid가 RESERVE_JOBID["PERSONAL"]일 경우, data에 windowSeat 추가 - if jobid == RESERVE_JOBID["PERSONAL"]: - data.update( - Passenger.get_passenger_dict( - passengers, special_seat=is_special_seat, window_seat=window_seat - ) + data.update( + Passenger.get_passenger_dict( + passengers, special_seat=is_special_seat, window_seat=window_seat ) + ) r = self._session.post(url=url, data=data) parser = SRTResponseData(r.text) diff --git a/tests/test_passenger.py b/tests/test_passenger.py new file mode 100644 index 0000000..c19e1a5 --- /dev/null +++ b/tests/test_passenger.py @@ -0,0 +1,44 @@ +from SRT.passenger import ( + Adult, + Child, + Disability1To3, + Disability4To6, + Passenger, + Senior, +) + + +def test_get_passenger_dict(): + passengers = [ + Adult(), + Child(), + Child(), + ] + + passengers = Passenger.combine(passengers) + + data = Passenger.get_passenger_dict(passengers) + assert data["totPrnb"] == "3" + assert data["psgGridcnt"] == "2" + assert data["psgTpCd1"] == "1" + assert data["psgInfoPerPrnb1"] == "1" + assert data["psgTpCd5"] == "5" + assert data["psgInfoPerPrnb5"] == "2" + + passengers2 = [ + Senior(), + Disability1To3(), + Disability4To6(), + ] + + passengers2 = Passenger.combine(passengers2) + + data = Passenger.get_passenger_dict(passengers2) + assert data["totPrnb"] == "3" + assert data["psgGridcnt"] == "3" + assert data["psgTpCd4"] == "4" + assert data["psgInfoPerPrnb4"] == "1" + assert data["psgTpCd2"] == "2" + assert data["psgInfoPerPrnb2"] == "1" + assert data["psgTpCd3"] == "3" + assert data["psgInfoPerPrnb3"] == "1"