Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Example] Add RegAE example #660

Merged
merged 7 commits into from
Jan 23, 2024
Merged

Conversation

xusuyong
Copy link
Contributor

@xusuyong xusuyong commented Nov 22, 2023

PR types

New features

PR changes

Others

Describe

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

辛苦修改一下

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

markdown文档用 vscode markdownlint 插件格式化一下

Comment on lines 27 to 33
criterion = nn.MSELoss()
kl_loss = KLLoss()


def loss_expr(output_dict, label_dict, weight_dict=None):

return kl_loss(output_dict) + criterion(output_dict["p_hat"], label_dict["p_hat"])
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 这里用F.mse_loss和KLLoss里的函数吧,简化下代码
  2. loss_expr 函数定义放到sup_constraint = ... 的上方,不用作为全局函数
  3. import不要直接导入某函数或者类,通过导入上一级的模块再访问

# set random seed for reproducibility
ppsci.utils.misc.set_random_seed(cfg.seed)
# initialize logger
logger.init_logger("ppsci", osp.join(cfg.output_dir, "train.log"), "info")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

==> eval.log

@@ -0,0 +1,142 @@
"""
输入数据类型 10^5 * 100 * 100
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

数据类型为什么是一个乘法数值?不应该是float或者double吗?还是说10^5表示数值范围,后面的100*100表示样本数?感觉表述可以更改得准确一点


def transform(self, data):
mean = (
paddle.to_tensor(self.mean).type_as(data).to(data.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

==> to_tensor(self.mean, dtype=data.dtype)

self.train_data = data[: self.train_len]
self.test_data = data[self.train_len :]

self.scaler = ScalerStd()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaler跟AMP的scaler存在歧义,改为 normalizer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API文档里加一下AutoEncoder

from ppsci.arch import base


# copy from AISTUDIO
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行删掉吧

Comment on lines 26 to 34
class AutoEncoder(base.Arch):
def __init__(
self,
input_keys: Tuple[str, ...],
output_keys: Tuple[str, ...],
input_dim,
latent_dim,
hidden_dim,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

完善docstring、type hint

Comment on lines +40 to +45
mu, log_sigma = output_dict["mu"], output_dict["log_sigma"]

base = paddle.exp(2.0 * log_sigma) + paddle.pow(mu, 2) - 1.0 - 2.0 * log_sigma
loss = 0.5 * paddle.sum(base) / mu.shape[0]

return loss
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是KL散度损失吗?但是看起来跟nn.KLDivLoss不是很像,如果不是很通用的Loss的话还是放在案例里通过FunctionalLoss使用吧,如果是比较通用的,就需要写成KLDiv的形式,即 $p*(\ln{\frac{p}{q}})$ ,并写好带有数学公式的 docstring

Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 254278a into PaddlePaddle:develop Jan 23, 2024
3 checks passed
@xusuyong xusuyong deleted the RegAE branch April 22, 2024 08:29
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* add RegAE example

* add RegAE

---------

Co-authored-by: HydrogenSulfate <490868991@qq.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants