diff --git a/deepvision/models/classification/mix_transformer/mit_pt.py b/deepvision/models/classification/mix_transformer/mit_pt.py index 690d84b..4b32940 100644 --- a/deepvision/models/classification/mix_transformer/mit_pt.py +++ b/deepvision/models/classification/mix_transformer/mit_pt.py @@ -85,7 +85,7 @@ def __init__( for i in range(self.num_stages): patch_embed_layer = OverlappingPatchingAndEmbedding( - in_channels=3 if i == 0 else embed_dims[i - 1], + in_channels=input_shape[0] if i == 0 else embed_dims[i - 1], out_channels=embed_dims[0] if i == 0 else embed_dims[i], patch_size=7 if i == 0 else 3, stride=4 if i == 0 else 2, diff --git a/deepvision/models/classification/mix_transformer/mit_tf.py b/deepvision/models/classification/mix_transformer/mit_tf.py index f6ef404..5992f44 100644 --- a/deepvision/models/classification/mix_transformer/mit_tf.py +++ b/deepvision/models/classification/mix_transformer/mit_tf.py @@ -70,7 +70,7 @@ def __init__( for i in range(num_stages): patch_embed_layer = OverlappingPatchingAndEmbedding( - in_channels=3 if i == 0 else embed_dims[i - 1], + in_channels=input_shape[-1] if i == 0 else embed_dims[i - 1], out_channels=embed_dims[0] if i == 0 else embed_dims[i], patch_size=7 if i == 0 else 3, stride=4 if i == 0 else 2,