Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jul 10, 2024
1 parent 53ca74b commit afc9727
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def best_video_stream(self) -> StreamMetadata:
return self.streams[self.best_video_stream_index]


def get_video_metadata(decoder: torch.tensor) -> VideoMetadata:
def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:

container_dict = json.loads(_get_container_json_metadata(decoder))
streams_metadata = []
Expand Down
16 changes: 8 additions & 8 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,23 @@ def get_frames_in_range_abstract(


@impl_abstract("torchcodec_ns::get_json_metadata")
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
return torch.empty_like("")
def get_json_metadata_abstract(decoder: torch.Tensor) -> torch.Tensor:
return torch.empty_like("") # type: ignore[arg-type]


@impl_abstract("torchcodec_ns::get_container_json_metadata")
def get_container_json_metadata_abstract(decoder: torch.Tensor) -> str:
return torch.empty_like("")
def get_container_json_metadata_abstract(decoder: torch.Tensor) -> torch.Tensor:
return torch.empty_like("") # type: ignore[arg-type]


@impl_abstract("torchcodec_ns::get_stream_json_metadata")
def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> str:
return torch.empty_like("")
def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> torch.Tensor:
return torch.empty_like("") # type: ignore[arg-type]


@impl_abstract("torchcodec_ns::_get_json_ffmpeg_library_versions")
def _get_json_ffmpeg_library_versions_abstract() -> str:
return torch.empty_like("")
def _get_json_ffmpeg_library_versions_abstract() -> torch.Tensor:
return torch.empty_like("") # type: ignore[arg-type]


@impl_abstract("torchcodec_ns::scan_all_streams_to_update_metadata")
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, source: Union[str, Path, bytes, torch.Tensor]):
core.add_video_stream(self._decoder)

self.stream_metadata = _get_and_validate_stream_metadata(self._decoder)
self._num_frames = self.stream_metadata.num_frames_computed
self._num_frames: int = self.stream_metadata.num_frames_computed # type: ignore[assignment]
self._stream_index = self.stream_metadata.stream_index

def __len__(self) -> int:
Expand Down
8 changes: 3 additions & 5 deletions src/torchcodec/samplers/video_clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(
self.sampler_args = sampler_args
self.decorder_args = DecoderArgs() if decorder_args is None else decorder_args

def forward(self, video_data: Tensor) -> Union[List[List[Tensor]], List[Tensor]]:
def forward(self, video_data: Tensor) -> Union[List[Any]]:
"""Sample video clips from the video data
Args:
Expand Down Expand Up @@ -162,8 +162,7 @@ def forward(self, video_data: Tensor) -> Union[List[List[Tensor]], List[Tensor]]
num_threads=self.decorder_args.num_threads,
)

clips = []

clips: List[Any] = []
# Cast sampler args to be time based or index based
if isinstance(self.sampler_args, TimeBasedSamplerArgs):
time_based_sampler_args = self.sampler_args
Expand Down Expand Up @@ -218,7 +217,6 @@ def _get_clips_for_index_based_sampling(
)
sampler_type = index_based_sampler_args.sampler_type

clip_start_idxs = []
if sampler_type == "random":
clip_start_idxs = torch.randint(
sample_start_index,
Expand Down Expand Up @@ -252,7 +250,7 @@ def _get_start_seconds(
self,
metadata_json: Dict[str, Any],
time_based_sampler_args: TimeBasedSamplerArgs,
):
) -> List[float]:
"""Get start seconds for each clip.
Given different sampler type, the API returns different clip start seconds.
Expand Down

0 comments on commit afc9727

Please sign in to comment.