Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for module import prefix on Python compiler #17286

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/google/protobuf/compiler/python/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ GeneratorOptions Generator::ParseParameter(absl::string_view parameter,
options.annotate_pyi = true;
} else if (option.first == "experimental_strip_nonfunctional_codegen") {
options.strip_nonfunctional_codegen = true;
} else if (option.first == "module_import_prefix") {
options.module_import_prefix =
std::string(absl::StripSuffix(option.second, "."));
} else {
*error = absl::StrCat("Unknown generator option: ", option.first);
}
Expand All @@ -224,6 +227,9 @@ bool Generator::Generate(const FileDescriptor* file,
if (options.strip_nonfunctional_codegen) {
pyi_options.push_back("experimental_strip_nonfunctional_codegen");
}
if (!options.module_import_prefix.empty()) {
pyi_options.push_back(absl::StrCat("module_import_prefix", "=", options.module_import_prefix));
}
if (!pyi_generator.Generate(file, absl::StrJoin(pyi_options, ","), context,
error)) {
return false;
Expand Down Expand Up @@ -289,7 +295,7 @@ bool Generator::Generate(const FileDescriptor* file,
printer_ = &printer;

PrintTopBoilerplate();
PrintImports();
PrintImports(options.module_import_prefix);
PrintFileDescriptor();
printer_->Print("_globals = globals()\n");
if (GeneratingDescriptorProto()) {
Expand Down Expand Up @@ -403,7 +409,7 @@ void Generator::PrintTopBoilerplate() const {
}

// Prints Python imports for all modules imported by |file|.
void Generator::PrintImports() const {
void Generator::PrintImports(absl::string_view module_import_prefix) const {
bool has_importlib = false;
for (int i = 0; i < file_->dependency_count(); ++i) {
absl::string_view filename = file_->dependency(i)->name();
Expand All @@ -414,6 +420,9 @@ void Generator::PrintImports() const {
module_name =
std::string(absl::StripPrefix(module_name, kThirdPartyPrefix));
}
if (!module_import_prefix.empty()) {
module_name = absl::StrCat(module_import_prefix, ".", module_name);
}
if (ContainsPythonKeyword(module_name)) {
// If the module path contains a Python keyword, we have to quote the
// module name and import it using importlib. Otherwise the usual kind of
Expand Down Expand Up @@ -452,6 +461,9 @@ void Generator::PrintImports() const {
module_name =
std::string(absl::StripPrefix(module_name, kThirdPartyPrefix));
}
if (!module_import_prefix.empty()) {
module_name = absl::StrCat(module_import_prefix, ".", module_name);
}
printer_->Print("from $module$ import *\n", "module", module_name);
}
printer_->Print("\n");
Expand Down
4 changes: 3 additions & 1 deletion src/google/protobuf/compiler/python/generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ namespace python {
// CodeGenerator with the CommandLineInterface in your main() function.

struct GeneratorOptions {
std::string module_import_prefix;

bool generate_pyi = false;
bool annotate_pyi = false;
bool bootstrap = false;
Expand Down Expand Up @@ -82,7 +84,7 @@ class PROTOC_EXPORT Generator : public CodeGenerator {
private:
GeneratorOptions ParseParameter(absl::string_view parameter,
std::string* error) const;
void PrintImports() const;
void PrintImports(absl::string_view module_import_prefix) const;
template <typename DescriptorT>
std::string GetResolvedFeatures(const DescriptorT& descriptor) const;
void PrintResolvedFeatures() const;
Expand Down
40 changes: 37 additions & 3 deletions src/google/protobuf/compiler/python/plugin_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ class TestGenerator : public CodeGenerator {
}
};

// opposed to importlib) in the usual case where the .proto file paths do not
// not contain any Python keywords.
TEST(PythonPluginTest, ImportTest) {
void SetupTestFiles() {
// Create files test1.proto and test2.proto with the former importing the
// latter.
ABSL_CHECK_OK(
Expand All @@ -71,6 +69,12 @@ TEST(PythonPluginTest, ImportTest) {
"package foo;\n"
"message Message2 {}\n",
true));
}

// opposed to importlib) in the usual case where the .proto file paths do not
// not contain any Python keywords.
TEST(PythonPluginTest, ImportTest) {
SetupTestFiles();

compiler::CommandLineInterface cli;
cli.SetInputsAreProtoPathRelative(true);
Expand Down Expand Up @@ -100,6 +104,36 @@ TEST(PythonPluginTest, ImportTest) {
EXPECT_TRUE(found_expected_import);
}

TEST(PythonPluginTest, ImportPrefixTest) {
// SetupTestFiles();

compiler::CommandLineInterface cli;
cli.SetInputsAreProtoPathRelative(true);
python::Generator python_generator;
cli.RegisterGenerator("--python_out", &python_generator, "");
std::string proto_path = absl::StrCat("-I", ::testing::TempDir());
std::string python_out = absl::StrCat("--python_out=module_import_prefix=added_prefix:", ::testing::TempDir());
const char* argv[] = {"protoc", proto_path.c_str(), "-I.", python_out.c_str(),
"test1.proto"};
ASSERT_EQ(0, cli.Run(5, argv));

// Loop over the lines of the generated code and verify that we find the
// prefixed import.
std::string output;
ABSL_CHECK_OK(
File::GetContents(absl::StrCat(::testing::TempDir(), "/test1_pb2.py"),
&output, true));
std::vector<absl::string_view> lines = absl::StrSplit(output, '\n');
std::string expected_import = "from added_prefix";
bool found_expected_import = false;
for (absl::string_view line : lines) {
if (absl::StrContains(line, expected_import)) {
found_expected_import = true;
}
}
EXPECT_TRUE(found_expected_import);
}

} // namespace
} // namespace python
} // namespace compiler
Expand Down
21 changes: 16 additions & 5 deletions src/google/protobuf/compiler/python/pyi_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,13 @@ void CheckImportModules(const Descriptor* descriptor,

void PyiGenerator::PrintImportForDescriptor(
const FileDescriptor& desc, absl::flat_hash_set<std::string>* seen_aliases,
bool* has_importlib) const {
bool* has_importlib, absl::string_view module_import_prefix) const {
const std::string& filename = desc.name();
std::string module_name_owned = StrippedModuleName(filename);
if (!module_import_prefix.empty()) {
module_name_owned =
absl::StrCat(module_import_prefix, ".", module_name_owned);
}
absl::string_view module_name(module_name_owned);
size_t last_dot_pos = module_name.rfind('.');
std::string alias = absl::StrCat("_", module_name.substr(last_dot_pos + 1));
Expand Down Expand Up @@ -164,7 +168,7 @@ void PyiGenerator::PrintImportForDescriptor(
}
}

void PyiGenerator::PrintImports() const {
void PyiGenerator::PrintImports(absl::string_view module_import_prefix) const {
// Prints imported dependent _pb2 files.
absl::flat_hash_set<std::string> seen_aliases;
bool has_importlib = false;
Expand All @@ -173,10 +177,11 @@ void PyiGenerator::PrintImports() const {
if (strip_nonfunctional_codegen_ && IsKnownFeatureProto(dep->name())) {
continue;
}
PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib);
PrintImportForDescriptor(*dep, &seen_aliases, &has_importlib,
module_import_prefix);
for (int j = 0; j < dep->public_dependency_count(); ++j) {
PrintImportForDescriptor(*dep->public_dependency(j), &seen_aliases,
&has_importlib);
&has_importlib, module_import_prefix);
}
}

Expand Down Expand Up @@ -263,6 +268,9 @@ void PyiGenerator::PrintImports() const {
for (int i = 0; i < file_->public_dependency_count(); ++i) {
const FileDescriptor* public_dep = file_->public_dependency(i);
std::string module_name = StrippedModuleName(public_dep->name());
if (!module_import_prefix.empty()) {
module_name = absl::StrCat(module_import_prefix, ".", module_name);
}
// Top level messages in public imports
for (int i = 0; i < public_dep->message_type_count(); ++i) {
printer_->Print(
Expand Down Expand Up @@ -582,6 +590,9 @@ bool PyiGenerator::Generate(const FileDescriptor* file,
filename = option.first;
} else if (option.first == "experimental_strip_nonfunctional_codegen") {
strip_nonfunctional_codegen_ = true;
} else if (option.first == "module_import_prefix") {
module_import_prefix =
std::string(absl::StripSuffix(option.second, "."));
} else {
*error = absl::StrCat("Unknown generator option: ", option.first);
return false;
Expand All @@ -603,7 +614,7 @@ bool PyiGenerator::Generate(const FileDescriptor* file,
io::Printer printer(output.get(), printer_opt);
printer_ = &printer;

PrintImports();
PrintImports(module_import_prefix);
printer_->Print("DESCRIPTOR: _descriptor.FileDescriptor\n");

// Prints extensions and enums from imports.
Expand Down
6 changes: 4 additions & 2 deletions src/google/protobuf/compiler/python/pyi_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ class PROTOC_EXPORT PyiGenerator : public google::protobuf::compiler::CodeGenera
private:
void PrintImportForDescriptor(const FileDescriptor& desc,
absl::flat_hash_set<std::string>* seen_aliases,
bool* has_importlib) const;
bool* has_importlib,
absl::string_view module_import_prefix) const;
template <typename DescriptorT>
void Annotate(const std::string& label, const DescriptorT* descriptor) const;
void PrintImports() const;
void PrintImports(absl::string_view module_import_prefix) const;
void PrintTopLevelEnums() const;
void PrintEnum(const EnumDescriptor& enum_descriptor) const;
void PrintEnumValues(const EnumDescriptor& enum_descriptor,
Expand All @@ -92,6 +93,7 @@ class PROTOC_EXPORT PyiGenerator : public google::protobuf::compiler::CodeGenera
mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_.
mutable io::Printer* printer_; // Set in Generate(). Under mutex_.
mutable bool strip_nonfunctional_codegen_ = false; // Set in Generate().
mutable std::string module_import_prefix; // Set in Generate().
// import_map will be a mapping from filename to module alias, e.g.
// "google3/foo/bar.py" -> "_bar"
mutable absl::flat_hash_map<std::string, std::string> import_map_;
Expand Down
Loading