Skip to content

Commit

Permalink
update tf
Browse files Browse the repository at this point in the history
  • Loading branch information
lyy committed Oct 16, 2023
1 parent 1bba935 commit 37247f4
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/shafts/DL_model_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import tensorflow as tf
from keras import layers
from keras.utils.layer_utils import count_params
# from keras.utils.layer_utils import count_params


kernel_size_ref = {100: 20, 250: 40, 500: 80, 1000: 160}
Expand Down Expand Up @@ -825,6 +825,9 @@ def model_SEResNetAuxTF(target_resolution: int, log_scale=False, activation="rel
if psize == 15:
in_plane = 64
num_block = 2
elif psize == 20:
in_plane = 64
num_block = 2
elif psize == 30:
in_plane = 64
num_block = 1
Expand Down Expand Up @@ -853,9 +856,9 @@ def model_SEResNetAuxTF(target_resolution: int, log_scale=False, activation="rel
model.compute_output_shape(input_shape=(None, psize, psize, 7))
model.save(saved_path_tf)

total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights])
trainable_num = sum([count_params(w) for w in model.trainable_weights])
print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num)
# total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights])
# trainable_num = sum([count_params(w) for w in model.trainable_weights])
# print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num)
return model


Expand All @@ -864,6 +867,9 @@ def model_SEResNetMTLAuxTF(target_resolution: int, crossed=False, log_scale=Fals
if psize == 15:
in_plane = 64
num_block = 2
elif psize == 20:
in_plane = 64
num_block = 2
elif psize == 30:
in_plane = 64
num_block = 1
Expand Down Expand Up @@ -893,9 +899,9 @@ def model_SEResNetMTLAuxTF(target_resolution: int, crossed=False, log_scale=Fals
model.compute_output_shape(input_shape=(None, psize, psize, 7))
model.save(saved_path_tf)

total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights])
trainable_num = sum([count_params(w) for w in model.trainable_weights])
print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num)
# total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights])
# trainable_num = sum([count_params(w) for w in model.trainable_weights])
# print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num)
return model


Expand Down

0 comments on commit 37247f4

Please sign in to comment.