
Update for PyTorch 0.4.0 #21

Open
mratsim opened this issue Jun 10, 2018 · 11 comments

Comments

@mratsim

mratsim commented Jun 10, 2018

PyTorch 0.4.0 was released on April 24, and unfortunately the pre-trained weights saved with earlier versions are not compatible with it.

In the notebook I get:

style_model = Net(ngf=128)
style_model.load_state_dict(torch.load('21styles.model'), False)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-ce41c62c2272> in <module>()
      1 style_model = Net(ngf=128)
----> 2 style_model.load_state_dict(torch.load('21styles.model'), False)

/usr/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                                self.__class__.__name__, "\n\t".join(error_msgs)))
    722 
    723     def parameters(self):

RuntimeError: Error(s) in loading state_dict for Net:
	Unexpected running stats buffer(s) "model1.1.running_mean" and "model1.1.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.0.running_mean" and "model1.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.3.running_mean" and "model1.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.3.conv_block.6.running_mean" and "model1.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.0.running_mean" and "model1.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.3.running_mean" and "model1.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model1.4.conv_block.6.running_mean" and "model1.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.1.running_mean" and "model.0.1.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.0.running_mean" and "model.0.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.3.running_mean" and "model.0.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.3.conv_block.6.running_mean" and "model.0.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.0.running_mean" and "model.0.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.3.running_mean" and "model.0.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.0.4.conv_block.6.running_mean" and "model.0.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.0.running_mean" and "model.2.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.3.running_mean" and "model.2.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.2.conv_block.6.running_mean" and "model.2.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.0.running_mean" and "model.3.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.3.running_mean" and "model.3.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.3.conv_block.6.running_mean" and "model.3.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.0.running_mean" and "model.4.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.3.running_mean" and "model.4.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.4.conv_block.6.running_mean" and "model.4.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.0.running_mean" and "model.5.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.3.running_mean" and "model.5.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.5.conv_block.6.running_mean" and "model.5.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.0.running_mean" and "model.6.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.3.running_mean" and "model.6.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.6.conv_block.6.running_mean" and "model.6.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.0.running_mean" and "model.7.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.3.running_mean" and "model.7.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.7.conv_block.6.running_mean" and "model.7.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.0.running_mean" and "model.8.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.3.running_mean" and "model.8.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.8.conv_block.6.running_mean" and "model.8.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.0.running_mean" and "model.9.conv_block.0.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.3.running_mean" and "model.9.conv_block.3.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.9.conv_block.6.running_mean" and "model.9.conv_block.6.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
	Unexpected running stats buffer(s) "model.10.running_mean" and "model.10.running_var" for InstanceNorm2d with track_running_stats=False. If state_dict is a checkpoint saved before 0.4.0, this may be expected because InstanceNorm2d does not track running stats by default since 0.4.0. Please remove these keys from state_dict. If the running stats are actually needed, instead set track_running_stats=True in InstanceNorm2d to enable them. See the documentation of InstanceNorm2d for details.
@zhanghang1989
Owner

Setting track_running_stats=True in InstanceNorm2d should fix this.
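A minimal sketch of what this suggestion changes, assuming a recent PyTorch. The 128 channels and affine=True are illustrative choices, not values read from the repository's net.py. With the default settings InstanceNorm2d registers no running_mean/running_var buffers, which is why the old checkpoint's keys are reported as unexpected. With track_running_stats=True the buffers exist again.

```python
import torch.nn as nn

# 128 channels and affine=True are illustrative, not taken from net.py.
plain = nn.InstanceNorm2d(128, affine=True)
tracked = nn.InstanceNorm2d(128, affine=True, track_running_stats=True)

# Iterating over a state_dict yields its keys.
plain_keys = set(plain.state_dict())
tracked_keys = set(tracked.state_dict())

# Default layer: only the affine parameters are saved, no running statistics.
print(sorted(plain_keys))
# Tracking layer: running_mean and running_var buffers are registered again
# (recent versions also add a num_batches_tracked counter).
print(sorted(tracked_keys))
```

As documented for InstanceNorm2d, a layer with track_running_stats=True normalizes with these stored estimates in eval() mode instead of per-image statistics, so the stylized output can differ from the default setting.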

@mratsim
Author

mratsim commented Jun 10, 2018

track_running_stats=True seems buggy and does not work for me (or I missed something).

Instead, I went the other way and removed the running-stats keys from the checkpoint:

# https://github.com/zhanghang1989/PyTorch-Multi-Style-Transfer/issues/21
# Compatibility shim for PyTorch 0.4

model_dict = torch.load('21styles.model')
model_dict_clone = model_dict.copy() # We can't mutate while iterating

for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

### Next cell

style_model = Net(ngf=128)
style_model.load_state_dict(model_dict, False)
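The same shim can be packaged as a small function that builds a filtered copy instead of deleting from the original. This is a sketch: the name strip_running_stats is hypothetical, and the function is plain Python so it runs without PyTorch. With a real checkpoint you would pass it the result of torch.load('21styles.model').

```python
from collections import OrderedDict
from typing import Mapping

# Buffers that InstanceNorm2d no longer registers by default since PyTorch 0.4.0.
RUNNING_STAT_SUFFIXES = ("running_mean", "running_var")


def strip_running_stats(state_dict: Mapping) -> OrderedDict:
    """Return a filtered copy of state_dict without running_mean/running_var keys.

    Building a new OrderedDict avoids deleting entries while iterating and keeps
    the original key order. Note: this would also drop BatchNorm statistics if
    the model contained any BatchNorm layers.
    """
    filtered = OrderedDict(
        (key, value)
        for key, value in state_dict.items()
        if not key.endswith(RUNNING_STAT_SUFFIXES)
    )
    # torch state dicts carry per-module version info in a `_metadata` attribute;
    # keep it on the copy, as Module.load_state_dict does when it copies one.
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        filtered._metadata = metadata
    return filtered


# Usage with a made-up checkpoint (string values stand in for tensors).
old = OrderedDict([
    ("model1.1.weight", "w"),
    ("model1.1.running_mean", "m"),
    ("model1.1.running_var", "v"),
    ("model.10.bias", "b"),
])
print(list(strip_running_stats(old)))  # ['model1.1.weight', 'model.10.bias']
```

With PyTorch this would be used as style_model.load_state_dict(strip_running_stats(torch.load('21styles.model')), False).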

@alvinwan

alvinwan commented Aug 5, 2018

  1. I had to downgrade PyTorch to get it working:
pip install torch==0.3.0.post4
  2. In the camera_demo.py and main.py files, the above translates into changing
style_model = Net(ngf=args.ngf)
style_model.load_state_dict(torch.load(args.model))

to

style_model = Net(ngf=args.ngf)
model_dict = torch.load(args.model)
model_dict_clone = model_dict.copy() # We can't mutate while iterating
for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]
style_model.load_state_dict(model_dict, False)
  3. Change style_v.data() to style_v.data.
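Step 3 is needed because PyTorch 0.4.0 merged Variable into Tensor, so .data is an attribute that returns a Tensor, and calling it fails. A minimal sketch with a random stand-in for the network output (its shape is made up) shows the error and the .detach() form that the 0.4.0 migration guide recommends over .data:

```python
import torch

# Hypothetical stand-in for the stylized output: a batch of one 3-channel image.
style_v = torch.rand(1, 3, 8, 8, requires_grad=True) * 255

# .data is a Tensor attribute, not a method, so calling it raises TypeError.
try:
    style_v.data()
except TypeError as err:
    print("TypeError:", err)

# .detach() returns a tensor cut off from autograd (required before .numpy()
# on a tensor that requires grad); .cpu() copies GPU tensors to host memory
# and is a no-op for tensors already on the CPU.
simg = style_v.detach().cpu().numpy()
print(simg.shape)  # (1, 3, 8, 8)
```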

@karenerobinson

karenerobinson commented Oct 5, 2018

Just got camera_demo.py and main.py working - thanks @alvinwan and @mratsim for the hints above.

For a while I was getting this error:

  File "camera_demo.py", line 105, in <module>
    main()
  File "camera_demo.py", line 102, in main
    run_demo(args, mirror=True)
  File "camera_demo.py", line 75, in run_demo
    simg = simg.transpose(1, 2, 0).astype('uint8')
ValueError: axes don't match array

The quick way to debug this was to replace python on my command line with python -m pdb and, once it crashed and gave me a prompt, check the shape of simg. Evidently simg now has 4 dimensions rather than 3, which I fixed with the reshape in step 3 below.
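The extra dimension is the batch axis: the network output has shape (N, C, H, W) with N = 1, while transpose(1, 2, 0) lists only three axes. A minimal NumPy sketch (the 480x640 size is made up) reproduces the error and shows an alternative to a fixed-size reshape: indexing out the single batch element, which does not depend on the image size.

```python
import numpy as np

# Hypothetical stand-in for the network output after .numpy(): shape (N, C, H, W).
simg = np.random.rand(1, 3, 480, 640) * 255

# transpose(1, 2, 0) names three axes, but the array has four (the batch axis).
try:
    simg.transpose(1, 2, 0)
except ValueError as err:
    print("ValueError:", err)

# Dropping the batch axis works for any H and W, unlike reshape((3, 512, 512)).
img = simg[0].transpose(1, 2, 0).astype("uint8")
print(img.shape)  # (480, 640, 3)
```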

My full fixes were:

1. Downgrade torch:

pip uninstall torch
pip install torch==0.3.0.post4

2. In camera_demo.py and main.py replace

        style_model = Net(ngf=args.ngf)

With

        model_dict = torch.load(args.model) # or args.resume, 
                                           # matching what's in the line with style_model.load_state_dict
        model_dict_clone = model_dict.copy() # We can't mutate while iterating
        for key, value in model_dict_clone.items():
                if key.endswith(('running_mean', 'running_var')):
                        del model_dict[key]

        style_model = Net(ngf=128) # to run with torch-0.3.0.post4
#        style_model = Net(ngf=args.ngf) # to run main.py with torch-0.4.0

Replace

	style_model.load_state_dict(torch.load(args.model)) # or (args.resume) one place

With

        style_model.load_state_dict(model_dict, False)

3. Replace

            simg = style_v.data().numpy()

With

            simg = style_v.data.numpy().reshape((3,512,512))

QUESTION
Instead of downgrading torch, I also tried setting track_running_stats=True for InstanceNorm2d in net.py. I had to do this in a few places: follow norm_layer through the code, including in the Bottleneck and UpBottleneck classes.

(Note that the documentation shows track_running_stats=True as the default for the BatchNorm layer classes, while InstanceNorm2d defaults to False.)
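Rather than editing every call site in net.py (or the installed PyTorch sources), one option is to hand the network a norm_layer factory built with functools.partial. This sketch assumes every InstanceNorm2d in the model is created through a norm_layer(num_features) call, as described above for Bottleneck and UpBottleneck; TinyBlock is a hypothetical stand-in, not the repository's class.

```python
import functools

import torch.nn as nn

# A norm-layer factory with tracking switched on: any code that calls
# norm_layer(num_features) now builds InstanceNorm2d(num_features, track_running_stats=True).
norm_layer = functools.partial(nn.InstanceNorm2d, track_running_stats=True)


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a block such as Bottleneck that takes norm_layer."""

    def __init__(self, channels, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm = norm_layer(channels)

    def forward(self, x):
        return self.norm(self.conv(x))


block = TinyBlock(8, norm_layer=norm_layer)
print(block.norm.track_running_stats)  # True
print("norm.running_mean" in block.state_dict())  # True
```

If Net itself accepts a norm_layer argument (an assumption here), the same factory could be passed as Net(ngf=128, norm_layer=norm_layer).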

I've gotten main.py working with the upgraded torch, but camera_demo.py gives an all-black image as output. Any comments or ideas are welcome!

@mratsim
Author

mratsim commented Oct 6, 2018

@karenerobinson I think the most reasonable approach is to wait for PyTorch 1.0, which should be released within days, so that the APIs are more stable and we don't have to fix something new once again when it lands.

@mertgerdan

mertgerdan commented Mar 3, 2019

How do you set track_running_stats = True? I am a beginner, so sorry if it's obvious; I haven't been able to find it for the past hour or so.

Thanks

@nile649

nile649 commented Mar 3, 2019

Try what @mratsim mentioned above:

model_dict = torch.load('21styles.model')
model_dict_clone = model_dict.copy() # We can't mutate while iterating

for key, value in model_dict_clone.items():
    if key.endswith(('running_mean', 'running_var')):
        del model_dict[key]

### Next cell

style_model = Net(ngf=128)
style_model.load_state_dict(model_dict, False)

@mertgerdan

I fixed my issue: I went to the torch.nn package in my Python site-packages directory and set track_running_stats=True in the InstanceNorm module file. I hadn't known how to do that before. After a bit more tweaking, I got it to work. Thanks anyway :)

@zhanghang1989
Owner

I really appreciate the comments on fixing the compatibility issue. I haven't worked on this project for a while. Could you consider submitting a pull request to the master branch? Thanks a lot :)

@jianchao-li
Contributor

jianchao-li commented Aug 3, 2019

Thanks to @alvinwan for sharing the fixes. I have tried them, and they worked for both main.py and camera_demo.py. @zhanghang1989 As this is still not fixed in the master branch, I have created a pull request for it (including another fix for load_lua) here.

@IamRafh

IamRafh commented Mar 27, 2021

Quoting @mertgerdan: "I fixed my issue, I went to NN packages in my python site packages dir and set track_running_stats=True on the instanceNorm file. I didn't know how to do that. After a bit more tweaking, I got it to work. Thanks anyways :)"

I have been looking for this for 50 hours, thanks


8 participants