From 37247f483ba9362ce302edff1cd19b2a3da2deb0 Mon Sep 17 00:00:00 2001 From: lyy Date: Mon, 16 Oct 2023 15:16:02 +0100 Subject: [PATCH] update tf --- src/shafts/DL_model_tf.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/shafts/DL_model_tf.py b/src/shafts/DL_model_tf.py index 38edc90..d4a9b3b 100644 --- a/src/shafts/DL_model_tf.py +++ b/src/shafts/DL_model_tf.py @@ -4,7 +4,7 @@ import torch import tensorflow as tf from keras import layers -from keras.utils.layer_utils import count_params +# from keras.utils.layer_utils import count_params kernel_size_ref = {100: 20, 250: 40, 500: 80, 1000: 160} @@ -825,6 +825,9 @@ def model_SEResNetAuxTF(target_resolution: int, log_scale=False, activation="rel if psize == 15: in_plane = 64 num_block = 2 + elif psize == 20: + in_plane = 64 + num_block = 2 elif psize == 30: in_plane = 64 num_block = 1 @@ -853,9 +856,9 @@ def model_SEResNetAuxTF(target_resolution: int, log_scale=False, activation="rel model.compute_output_shape(input_shape=(None, psize, psize, 7)) model.save(saved_path_tf) - total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights]) - trainable_num = sum([count_params(w) for w in model.trainable_weights]) - print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num) + # total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights]) + # trainable_num = sum([count_params(w) for w in model.trainable_weights]) + # print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num) return model @@ -864,6 +867,9 @@ def model_SEResNetMTLAuxTF(target_resolution: int, crossed=False, log_scale=Fals if psize == 15: in_plane = 64 num_block = 2 + elif psize == 20: + in_plane = 64 + num_block = 2 elif psize == 30: in_plane = 64 num_block = 1 @@ -893,9 +899,9 @@ def model_SEResNetMTLAuxTF(target_resolution: int, crossed=False, log_scale=Fals model.compute_output_shape(input_shape=(None, psize, psize, 7)) model.save(saved_path_tf) - total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights]) - trainable_num = sum([count_params(w) for w in model.trainable_weights]) - print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num) + # total_num = sum([count_params(w) for w in model.trainable_weights]) + sum([count_params(w) for w in model.non_trainable_weights]) + # trainable_num = sum([count_params(w) for w in model.trainable_weights]) + # print("Total parameter of SEResNet: ", total_num, " Trainable parameter of SEResNet: ", trainable_num) return model