[PIR AMP]Split tracer and amp attr #61297

Merged
merged 5 commits into from
Jan 31, 2024

@0x45f 0x45f commented Jan 29, 2024

PR types

Others

PR changes

Others

Description

Pcard-67164

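This is the first PR toward supporting AMP training under PIR. It decouples the tracer from the AMP training attributes. Specifically:

  • Adds class AMPState and moves the use_promote_, amp_level_, and amp_dtype_ members, previously held by the tracer, into it. AMP training under PIR needs these three attributes too, but under PIR the tracer is null, so they cannot be read from it. A global AMPState now manages the AMP attributes, so both dynamic graph mode and PIR can access them.
  • Modifies AutoCastGuard so that internally it is associated with AMPState instead of the tracer. AutoCastGuard can therefore be used in both dynamic graph mode and PIR.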
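
The two changes above can be sketched roughly as follows. This is a minimal illustration, not the code merged in this PR: the `AmpLevel` enum values, the default attribute values, the singleton accessor `Instance()`, and the guard's constructor signature are all assumptions made for the example. Only the attribute names and the Get/Set method names come from the description.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical AMP levels; the real enum in Paddle may differ.
enum class AmpLevel { O0, O1, O2 };

// Process-wide holder for the three attributes moved out of the tracer.
// The singleton accessor and the default values are assumptions.
class AMPState {
 public:
  static AMPState& Instance() {
    static AMPState state;  // initialized once, on first use
    return state;
  }

  bool GetUsePromote() const { return use_promote_; }
  void SetUsePromote(bool use_promote) { use_promote_ = use_promote; }

  AmpLevel GetAmpLevel() const { return amp_level_; }
  void SetAmpLevel(AmpLevel level) { amp_level_ = level; }

  std::string GetAmpDtype() const { return amp_dtype_; }
  void SetAmpDtype(std::string amp_dtype) { amp_dtype_ = std::move(amp_dtype); }

 private:
  AMPState() = default;

  bool use_promote_ = false;
  AmpLevel amp_level_ = AmpLevel::O0;
  std::string amp_dtype_ = "float32";
};

// RAII guard that talks only to AMPState, so it needs no tracer.
// It sets the level on construction and restores the old one on destruction.
class AutoCastGuard {
 public:
  explicit AutoCastGuard(AmpLevel level)
      : saved_level_(AMPState::Instance().GetAmpLevel()) {
    AMPState::Instance().SetAmpLevel(level);
  }
  ~AutoCastGuard() { AMPState::Instance().SetAmpLevel(saved_level_); }

  AutoCastGuard(const AutoCastGuard&) = delete;
  AutoCastGuard& operator=(const AutoCastGuard&) = delete;

 private:
  AmpLevel saved_level_;
};

// Usage: the level is changed inside the scope and restored afterwards.
inline void AutoCastGuardExample() {
  AMPState::Instance().SetAmpLevel(AmpLevel::O0);
  {
    AutoCastGuard guard(AmpLevel::O1);
    assert(AMPState::Instance().GetAmpLevel() == AmpLevel::O1);
  }
  assert(AMPState::Instance().GetAmpLevel() == AmpLevel::O0);
}
```

Because the guard never dereferences a tracer, the same scope-based pattern would work where the tracer is null, which is the situation the description gives for PIR.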

NOTE:

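  • This PR moves the tracer's use_promote_, amp_level_, and amp_dtype_ attributes into AMPState, but the tracer keeps its related interfaces such as GetUsePromote, SetAmpLevel, and GetAmpDtype. After this PR, these tracer interfaces call the corresponding AMPState interfaces internally. Why were the tracer interfaces not deleted? Both the Python side and the C++ side still call them, and they were left in place only to speed up development of the PIR AMP feature. The cleanup is a TODO: it will be done in a separate follow-up PR or coordinated with external developers. The interfaces to clean up include, but are not limited to:

    • GetUsePromote, SetUsePromote, GetAmpLevel, SetAmpLevel, GetAmpDtype, and SetAmpDtype in paddle/fluid/imperative/tracer.h/cc
    • The _use_promote, _amp_level, and _amp_dtype properties of Tracer in paddle/fluid/pybind/imperative.cc
    • All call sites of these interfaces on the Python and C++ sides, which need to call the corresponding AMPState interfaces instead
  • The same applies to AmpOperators. Ideally AmpOperators would be merged with AMPState; this will also be handled in a separate PR.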
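
The forwarding arrangement described in this note might look like the sketch below. It is an assumption-laden illustration rather than the merged code: a trimmed AMPState carrying only the dtype attribute is repeated so the snippet compiles on its own, and the singleton accessor `Instance()` and the default value are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Trimmed stand-in for the global AMPState (dtype attribute only).
// Instance() and the default value are assumptions for this sketch.
class AMPState {
 public:
  static AMPState& Instance() {
    static AMPState state;
    return state;
  }
  std::string GetAmpDtype() const { return amp_dtype_; }
  void SetAmpDtype(std::string amp_dtype) { amp_dtype_ = std::move(amp_dtype); }

 private:
  AMPState() = default;
  std::string amp_dtype_ = "float32";
};

// The tracer keeps its legacy interface but no longer owns the attribute:
// each accessor simply forwards to AMPState.
class Tracer {
 public:
  std::string GetAmpDtype() const {
    return AMPState::Instance().GetAmpDtype();
  }
  void SetAmpDtype(std::string amp_dtype) {
    AMPState::Instance().SetAmpDtype(std::move(amp_dtype));
  }
};

// Usage: writes through the tracer are visible in AMPState and vice versa.
inline void TracerForwardingExample() {
  Tracer tracer;
  tracer.SetAmpDtype("float16");
  assert(AMPState::Instance().GetAmpDtype() == "float16");

  AMPState::Instance().SetAmpDtype("bfloat16");
  assert(tracer.GetAmpDtype() == "bfloat16");
}
```

With this shape, existing callers of the tracer interfaces keep working unchanged, and the later cleanup amounts to repointing each call site at AMPState before removing the forwarding methods.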

@Aurelius84 Aurelius84 left a comment

Great work !

std::string GetAmpDtype() const;
void SetAmpDtype(std::string amp_dtype);
phi::DataType GetAmpPhiDtype() const;
// void Reset();
Contributor

This unused comment can be removed.

Contributor Author

OK, I'll remove it in the next PR. Thanks!

@0x45f 0x45f merged commit 250d6a0 into PaddlePaddle:develop Jan 31, 2024
30 checks passed
@0x45f 0x45f deleted the refine-trace-amp branch January 31, 2024 08:22