[Question] AssertionError: Values must be float if value_is_float is set to True, got int64: [1] #729

Open · ganyuancao opened this issue Jun 13, 2024 · 3 comments

ganyuancao commented Jun 13, 2024

Hello, I am trying to wrap a U-Net with Concrete-ML. When I try to compile the model, I get the following error:

AssertionError: Values must be float if value_is_float is set to True, got int64: [1]

My U-Net looks like this:

import torch
import torch.nn as nn
import brevitas.nn as qnn

class CustomUpsample(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor):
        super(CustomUpsample, self).__init__()
        self.scale_factor = scale_factor
        self.conv = qnn.QuantConv2d(in_channels, out_channels, kernel_size=3, padding=1)
    
    def forward(self, x):
        # Nearest-neighbor upsampling using expand and reshape:
        # insert the repeat axes next to their spatial dimensions so the
        # reshape interleaves pixels correctly
        batch_size, channels, height, width = x.shape
        x = x.unsqueeze(3).unsqueeze(5)
        x = x.expand(batch_size, channels, height, self.scale_factor, width, self.scale_factor)
        x = x.reshape(batch_size, channels, height * self.scale_factor, width * self.scale_factor)
        return self.conv(x)

# Redefine the SmallUNet class to use the CustomUpsample
class SmallUNet(nn.Module):
    def __init__(self):
        super(SmallUNet, self).__init__()
        
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
            layers = []
            layers += [qnn.QuantConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]
            return nn.Sequential(*layers)
        
        self.enc1 = nn.Sequential(
            CBR2d(in_channels=1, out_channels=32),
            CBR2d(in_channels=32, out_channels=32)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        self.enc2 = nn.Sequential(
            CBR2d(in_channels=32, out_channels=64),
            CBR2d(in_channels=64, out_channels=64)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        
        self.enc3 = nn.Sequential(
            CBR2d(in_channels=64, out_channels=128),
            CBR2d(in_channels=128, out_channels=128)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        
        self.bottleneck = nn.Sequential(
            CBR2d(in_channels=128, out_channels=256),
            CBR2d(in_channels=256, out_channels=256)
        )
        
        self.up3 = CustomUpsample(in_channels=256, out_channels=128, scale_factor=2)
        self.dec3 = nn.Sequential(
            CBR2d(in_channels=256, out_channels=128),
            CBR2d(in_channels=128, out_channels=128)
        )
        
        self.up2 = CustomUpsample(in_channels=128, out_channels=64, scale_factor=2)
        self.dec2 = nn.Sequential(
            CBR2d(in_channels=128, out_channels=64),
            CBR2d(in_channels=64, out_channels=64)
        )
        
        self.up1 = CustomUpsample(in_channels=64, out_channels=32, scale_factor=2)
        self.dec1 = nn.Sequential(
            CBR2d(in_channels=64, out_channels=32),
            CBR2d(in_channels=32, out_channels=32)
        )
        
        self.final = qnn.QuantConv2d(in_channels=32, out_channels=1, kernel_size=1, stride=1, padding=0)
        
    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        enc3 = self.enc3(self.pool2(enc2))
    
        bottleneck = self.bottleneck(self.pool3(enc3))
    
        up3 = self.up3(bottleneck)
        dec3 = self.dec3(torch.cat((up3, enc3), dim=1))
    
        up2 = self.up2(dec3)
        dec2 = self.dec2(torch.cat((up2, enc2), dim=1))
    
        up1 = self.up1(dec2)
        dec1 = self.dec1(torch.cat((up1, enc1), dim=1))
    
        return torch.sigmoid(self.final(dec1))


# Example of creating and printing the model
model = SmallUNet()
print(model)

And this is how I compile the model:

from torch.utils.data import DataLoader
from concrete.ml.torch.compile import compile_brevitas_qat_model

# Load the calibration set (only batch_size=1 is allowed)
cali_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
cali_image, cali_mask = next(iter(cali_loader))
cali_image = cali_image.cpu()
cali_mask = cali_mask.cpu() 

# Load the test set
test_mini_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
test_image, test_mask = next(iter(test_mini_loader))
test_image = test_image.cpu()
test_mask = test_mask.cpu()

fhe_model = compile_brevitas_qat_model(model, cali_image, show_mlir=True)

Everything works fine when I train and predict in plaintext. Any idea why this error occurs? Thank you very much in advance!

RomanBredehoft (Collaborator) commented Jun 13, 2024

Hello @ganyuancao,
My guess is that you are missing a QuantIdentity as your model's first layer. As we explain a bit in this documentation section, it is needed to make sure inputs are properly quantized.
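For illustration, here is a minimal sketch of the fix (the class name and bit-widths below are hypothetical, not taken from the issue):

import torch.nn as nn
import brevitas.nn as qnn

class QuantInputNet(nn.Module):
    def __init__(self, n_bits=4):
        super().__init__()
        # Quantize the float inputs before they reach any QuantConv2d;
        # without this first layer, the compiler sees raw float inputs
        self.quant_input = qnn.QuantIdentity(bit_width=n_bits, return_quant_tensor=True)
        self.conv = qnn.QuantConv2d(1, 32, kernel_size=3, padding=1, weight_bit_width=n_bits)

    def forward(self, x):
        x = self.quant_input(x)
        return self.conv(x)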

Since this is not the first time users have faced similar issues, we'll try to improve our documentation to better explain this part, so thanks for reporting 🙂

Hope that fixes your issue!

ganyuancao (Author) commented

Hi @RomanBredehoft, thanks for your answer. Yes, I fixed that problem by adding a QuantIdentity and changing ReLU to QuantReLU. But now compilation fails with "Function you are trying to compile cannot be compiled" and prints the graph of the circuit. Do you have any clue why this may happen?

RomanBredehoft (Collaborator) commented
Hello again @ganyuancao,
Ah yes indeed, when using torch operators like ReLU, you either need to add a QuantIdentity around them or use Brevitas' QuantReLU, like you did!
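A minimal sketch of both options (the 4-bit width is an assumed value):

import torch.nn as nn
import brevitas.nn as qnn

# Option 1: Brevitas' quantized activation
act = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)

# Option 2: keep torch's ReLU, but surround it with quantized identities
act_alt = nn.Sequential(
    qnn.QuantIdentity(bit_width=4),
    nn.ReLU(),
    qnn.QuantIdentity(bit_width=4),
)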

As for your question, I would need to see the full traceback, as this error comes from the compiler. Most probably, some accumulators are reaching the allowed bit-width limit. In any case, we have an example about this issue in this documentation section. Feel free to check out the several possible solutions we provide!
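For instance, a hedged sketch of compile-time options that often help when accumulators exceed the limit (the parameter values are assumptions; check the API of your Concrete-ML version):

from concrete.ml.torch.compile import compile_brevitas_qat_model

fhe_model = compile_brevitas_qat_model(
    model,
    cali_image,
    rounding_threshold_bits=6,  # round accumulators down to fewer bits
    p_error=0.01,               # tolerate a small per-PBS error probability
)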
