
Problem with kmnist dataset #273

Open
bukson opened this issue Jun 20, 2023 · 1 comment

bukson commented Jun 20, 2023

Hello,

I am trying to use the pretrained B_16 model on the TFDS `kmnist` dataset (which, like MNIST, consists of 28x28 grayscale images).

The problem is that I get this error:

```
Initializer expected to generate shape (16, 16, 3, 768) but got shape (16, 16, 1, 768)
```

This is probably because the images have only 1 color channel instead of 3.

I had no problem running the pretrained model on a custom color dataset. Is this method only available for 3-channel datasets, or are MNIST-like datasets also supported?
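For context, a sketch of why the channel count matters, assuming the usual ViT design: the first layer cuts the image into non-overlapping 16x16 patches and projects each patch to the hidden size (768 for B_16) with a convolution whose kernel has shape (16, 16, input_channels, 768). A checkpoint pretrained on RGB images has 3 input channels in that kernel, while a model built for 1-channel images expects 1, which matches the two shapes in the error. The NumPy function below is an illustrative stand-in for that layer, not the repository's code (`patch_embed` is a hypothetical name); it only shows that the kernel's third axis must match the image's channel axis.

```python
import numpy as np


def patch_embed(image, kernel):
    """Hypothetical stand-in for a ViT patch embedding.

    Equivalent to a 'VALID' convolution whose stride equals the patch size.
    image:  (H, W, C) array
    kernel: (P, P, C, D) array, the layout shown in the error message
    returns (H // P, W // P, D) array
    """
    p, p2, c_in, d = kernel.shape
    assert p == p2, "square patches assumed"
    h, w, c = image.shape
    if c != c_in:
        raise ValueError(f"kernel expects {c_in} input channels, image has {c}")
    # Crop to a multiple of the patch size, then cut into (P, P, C) patches.
    gh, gw = h // p, w // p
    x = image[: gh * p, : gw * p]
    x = x.reshape(gh, p, gw, p, c).transpose(0, 2, 1, 3, 4)  # (gh, gw, P, P, C)
    # Flatten each patch and project it with the flattened kernel.
    return x.reshape(gh, gw, p * p * c) @ kernel.reshape(p * p * c_in, d)


# Usage with a stand-in for the pretrained B_16 kernel (zeros, not real weights).
rgb_kernel = np.zeros((16, 16, 3, 768), np.float32)
gray = np.zeros((224, 224, 1), np.float32)
print(patch_embed(np.repeat(gray, 3, axis=-1), rgb_kernel).shape)  # (14, 14, 768)
# patch_embed(gray, rgb_kernel) raises ValueError: 3 vs 1 input channels.
```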

andsteing (Collaborator) commented:

I would simply repeat the channels in the preprocessing function `_pp`:

```python
import tensorflow as tf

# image_decoder, mode, image_size and num_classes are defined in the
# enclosing scope (this function is nested inside the data-loading code).
def _pp(data):
  im = image_decoder(data['image'])
  if mode == 'train':
    channels = im.shape[-1]
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(im),
        tf.zeros([0, 0, 4], tf.float32),
        area_range=(0.05, 1.0),
        min_object_covered=0,  # Don't enforce a minimum area.
        use_image_if_no_bounding_boxes=True)
    im = tf.slice(im, begin, size)
    # Unfortunately, the above operation loses the depth-dimension. So we
    # need to restore it the manual way.
    im.set_shape([None, None, channels])
    im = tf.image.resize(im, [image_size, image_size])
    if tf.random.uniform(shape=[]) > 0.5:
      im = tf.image.flip_left_right(im)
  else:
    im = tf.image.resize(im, [image_size, image_size])
  im = (im - 127.5) / 127.5
  label = tf.one_hot(data['label'], num_classes)  # pylint: disable=no-value-for-parameter
  return {'image': im, 'label': label}
```
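Presumably the repeat would go right after `image_decoder(...)`, before cropping, resizing and scaling; in TensorFlow that is `tf.repeat(im, 3, axis=-1)`. The NumPy sketch below mirrors only the channel repeat and the final `(im - 127.5) / 127.5` scaling (cropping, flipping and resizing are left out). `to_model_input` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def to_model_input(image):
    """Hypothetical helper: turn a uint8 image of shape (H, W, 1) or (H, W, 3)
    into a float32 array with 3 channels scaled to [-1, 1]."""
    image = np.asarray(image, dtype=np.float32)
    if image.shape[-1] == 1:
        # Grayscale: copy the single channel into all three RGB channels.
        image = np.repeat(image, 3, axis=-1)
    # Same scaling as `(im - 127.5) / 127.5` in `_pp`: 0 -> -1.0, 255 -> 1.0.
    return (image - 127.5) / 127.5


# Usage: a 28x28 grayscale image (KMNIST/MNIST size) becomes 28x28x3.
print(to_model_input(np.zeros((28, 28, 1), np.uint8)).shape)  # (28, 28, 3)
```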

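Something like this (shown for `mnist`; `kmnist` has the same 28x28x1 image shape):

```python
import tensorflow_datasets as tfds
import tensorflow as tf

ds = tfds.load('mnist', split='train')
# Repeat the single grayscale channel 3 times: (28, 28, 1) -> (28, 28, 3).
ds = ds.map(lambda d: {
    'label': d['label'],
    'image': tf.repeat(d['image'], 3, axis=2),
})
ds = ds.batch(2)
b = next(iter(ds))
assert b['image'].shape.as_list() == [2, 28, 28, 3]
```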
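Repeating the channel means the pretrained (16, 16, 3, 768) kernel can be used unchanged. For a grayscale patch, the result is the same as using a 1-channel kernel obtained by summing the RGB kernel over its input-channel axis; that summed-kernel variant is an alternative some people use to adapt RGB weights to grayscale input, not what is suggested above. The NumPy check below uses random arrays as stand-ins for the pretrained weights and a single 16x16 patch.

```python
import numpy as np

rng = np.random.default_rng(0)
P, D = 16, 768
kernel_rgb = rng.standard_normal((P, P, 3, D))  # stand-in for the (16, 16, 3, 768) kernel
patch_gray = rng.standard_normal((P, P, 1))     # one 16x16 grayscale patch

# Option A: repeat the gray channel 3 times and keep the RGB kernel as is.
out_repeat = np.einsum("abc,abcd->d", np.repeat(patch_gray, 3, axis=-1), kernel_rgb)

# Option B: keep 1 channel and sum the kernel over its input-channel axis.
kernel_gray = kernel_rgb.sum(axis=2, keepdims=True)  # (16, 16, 1, 768)
out_summed = np.einsum("abc,abcd->d", patch_gray, kernel_gray)

# Both give the same 768-dimensional patch embedding.
print(np.allclose(out_repeat, out_summed))  # True
```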
