diff --git a/atcoder/__main__.py b/atcoder/__main__.py index 03856c0..9040434 100644 --- a/atcoder/__main__.py +++ b/atcoder/__main__.py @@ -69,7 +69,6 @@ def import_module(self, import_from: Optional[str], name: str, imports = iter_child_nodes(ast.parse(source)) import_lines = [] - import_list = [] for import_info in imports: result += self.import_module( import_info.import_from, import_info.name, @@ -78,11 +77,6 @@ def import_module(self, import_from: Optional[str], name: str, import_info.end_lineno): import_lines.append(line) - if import_info.import_from is None: - import_list.append(import_info.name) - else: - import_list.append(import_info.import_from) - for lineno, line in enumerate(lines): if lineno not in import_lines: continue @@ -97,18 +91,24 @@ def import_module(self, import_from: Optional[str], name: str, result += '\n'.join(lines) result += '"""\n\n' result += f"{module_name} = types.ModuleType('{module_name}')\n" - result += f'exec({code}, {module_name}.__dict__)\n' + # TODO(not): asname imported = [] - for import_ in import_list: - modules = import_.split('.') - for i in range(len(modules)): - import_name = '.'.join(modules[:i + 1]) - if import_name in imported: - continue - imported.append(import_name) - result += f"{module_name}.__dict__['{import_name}']" \ - f" = {import_name}\n" + for import_info in imports: + if import_info.import_from is None: + modules = import_info.name.split('.') + for i in range(len(modules)): + import_name = '.'.join(modules[:i + 1]) + if import_name in imported: + continue + imported.append(import_name) + result += f"{module_name}.__dict__['{import_name}']" \ + f" = {import_name}\n" + else: + result += f"{module_name}.__dict__['{import_info.name}']" \ + f" = {import_info.import_from}.{import_info.name}\n" + + result += f'exec({code}, {module_name}.__dict__)\n' if import_from is None: if asname is None: