OLD | NEW |
1 # Copyright 2012 The Chromium Authors. All rights reserved. | 1 # Copyright 2017 The Chromium Authors. All rights reserved. |
2 # Use of this source code is governed by a BSD-style license that can be | 2 # Use of this source code is governed by a BSD-style license that can be |
3 # found in the LICENSE file. | 3 # found in the LICENSE file. |
4 | 4 |
5 import fnmatch | 5 """Shim for discover; will be removed once #3612 is fixed.""" |
6 import inspect | |
7 import os | |
8 import re | |
9 import sys | |
10 | 6 |
11 from py_utils import camel_case | 7 from py_utils import discover |
12 | 8 |
13 | 9 DiscoverClasses = discover.DiscoverClasses |
14 def DiscoverModules(start_dir, top_level_dir, pattern='*'): | 10 DiscoverClassesInModule = discover.DiscoverClassesInModule |
15 """Discover all modules in |start_dir| which match |pattern|. | |
16 | |
17 Args: | |
18 start_dir: The directory to recursively search. | |
19 top_level_dir: The top level of the package, for importing. | |
20 pattern: Unix shell-style pattern for filtering the filenames to import. | |
21 | |
22 Returns: | |
23 list of modules. | |
24 """ | |
25 # start_dir and top_level_dir must be consistent with each other. | |
26 start_dir = os.path.realpath(start_dir) | |
27 top_level_dir = os.path.realpath(top_level_dir) | |
28 | |
29 modules = [] | |
30 sub_paths = list(os.walk(start_dir)) | |
31 # We sort the directories & file paths to ensure a deterministic ordering when | |
32 # traversing |top_level_dir|. | |
33 sub_paths.sort(key=lambda paths_tuple: paths_tuple[0]) | |
34 for dir_path, _, filenames in sub_paths: | |
35 # Sort the directories to walk recursively by the directory path. | |
36 filenames.sort() | |
37 for filename in filenames: | |
38 # Filter out unwanted filenames. | |
39 if filename.startswith('.') or filename.startswith('_'): | |
40 continue | |
41 if os.path.splitext(filename)[1] != '.py': | |
42 continue | |
43 if not fnmatch.fnmatch(filename, pattern): | |
44 continue | |
45 | |
46 # Find the module. | |
47 module_rel_path = os.path.relpath( | |
48 os.path.join(dir_path, filename), top_level_dir) | |
49 module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0]) | |
50 | |
51 # Import the module. | |
52 try: | |
53 # Make sure that top_level_dir is the first path in the sys.path in case | |
54 # there are naming conflict in module parts. | |
55 original_sys_path = sys.path[:] | |
56 sys.path.insert(0, top_level_dir) | |
57 module = __import__(module_name, fromlist=[True]) | |
58 modules.append(module) | |
59 finally: | |
60 sys.path = original_sys_path | |
61 return modules | |
62 | |
63 | |
64 def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2): | |
65 for k in classes_by_key_1: | |
66 if k in classes_by_key_2: | |
67 assert classes_by_key_1[k] is classes_by_key_2[k], ( | |
68 'Found conflicting classes for the same key: ' | |
69 'key=%s, class_1=%s, class_2=%s' % ( | |
70 k, classes_by_key_1[k], classes_by_key_2[k])) | |
71 | |
72 | |
73 # TODO(dtu): Normalize all discoverable classes to have corresponding module | |
74 # and class names, then always index by class name. | |
75 def DiscoverClasses(start_dir, | |
76 top_level_dir, | |
77 base_class, | |
78 pattern='*', | |
79 index_by_class_name=True, | |
80 directly_constructable=False): | |
81 """Discover all classes in |start_dir| which subclass |base_class|. | |
82 | |
83 Base classes that contain subclasses are ignored by default. | |
84 | |
85 Args: | |
86 start_dir: The directory to recursively search. | |
87 top_level_dir: The top level of the package, for importing. | |
88 base_class: The base class to search for. | |
89 pattern: Unix shell-style pattern for filtering the filenames to import. | |
90 index_by_class_name: If True, use class name converted to | |
91 lowercase_with_underscores instead of module name in return dict keys. | |
92 directly_constructable: If True, will only return classes that can be | |
93 constructed without arguments | |
94 | |
95 Returns: | |
96 dict of {module_name: class} or {underscored_class_name: class} | |
97 """ | |
98 modules = DiscoverModules(start_dir, top_level_dir, pattern) | |
99 classes = {} | |
100 for module in modules: | |
101 new_classes = DiscoverClassesInModule( | |
102 module, base_class, index_by_class_name, directly_constructable) | |
103 # TODO(nednguyen): we should remove index_by_class_name once | |
104 # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied | |
105 # naming collisions to reduce the number of smoked benchmark tests. | |
106 # crbug.com/548652 | |
107 if index_by_class_name: | |
108 AssertNoKeyConflicts(classes, new_classes) | |
109 classes = dict(classes.items() + new_classes.items()) | |
110 return classes | |
111 | |
112 | |
113 # TODO(nednguyen): we should remove index_by_class_name once | |
114 # benchmark_smoke_unittest in chromium/src/tools/perf no longer relied | |
115 # naming collisions to reduce the number of smoked benchmark tests. | |
116 # crbug.com/548652 | |
117 def DiscoverClassesInModule(module, | |
118 base_class, | |
119 index_by_class_name=False, | |
120 directly_constructable=False): | |
121 """Discover all classes in |module| which subclass |base_class|. | |
122 | |
123 Base classes that contain subclasses are ignored by default. | |
124 | |
125 Args: | |
126 module: The module to search. | |
127 base_class: The base class to search for. | |
128 index_by_class_name: If True, use class name converted to | |
129 lowercase_with_underscores instead of module name in return dict keys. | |
130 | |
131 Returns: | |
132 dict of {module_name: class} or {underscored_class_name: class} | |
133 """ | |
134 classes = {} | |
135 for _, obj in inspect.getmembers(module): | |
136 # Ensure object is a class. | |
137 if not inspect.isclass(obj): | |
138 continue | |
139 # Include only subclasses of base_class. | |
140 if not issubclass(obj, base_class): | |
141 continue | |
142 # Exclude the base_class itself. | |
143 if obj is base_class: | |
144 continue | |
145 # Exclude protected or private classes. | |
146 if obj.__name__.startswith('_'): | |
147 continue | |
148 # Include only the module in which the class is defined. | |
149 # If a class is imported by another module, exclude those duplicates. | |
150 if obj.__module__ != module.__name__: | |
151 continue | |
152 | |
153 if index_by_class_name: | |
154 key_name = camel_case.ToUnderscore(obj.__name__) | |
155 else: | |
156 key_name = module.__name__.split('.')[-1] | |
157 if not directly_constructable or IsDirectlyConstructable(obj): | |
158 if key_name in classes and index_by_class_name: | |
159 assert classes[key_name] is obj, ( | |
160 'Duplicate key_name with different objs detected: ' | |
161 'key=%s, obj1=%s, obj2=%s' % (key_name, classes[key_name], obj)) | |
162 else: | |
163 classes[key_name] = obj | |
164 | |
165 return classes | |
166 | |
167 | |
168 def IsDirectlyConstructable(cls): | |
169 """Returns True if instance of |cls| can be construct without arguments.""" | |
170 assert inspect.isclass(cls) | |
171 if not hasattr(cls, '__init__'): | |
172 # Case |class A: pass|. | |
173 return True | |
174 if cls.__init__ is object.__init__: | |
175 # Case |class A(object): pass|. | |
176 return True | |
177 # Case |class (object):| with |__init__| other than |object.__init__|. | |
178 args, _, _, defaults = inspect.getargspec(cls.__init__) | |
179 if defaults is None: | |
180 defaults = () | |
181 # Return true if |self| is only arg without a default. | |
182 return len(args) == len(defaults) + 1 | |
183 | |
184 | |
185 _counter = [0] | |
186 | |
187 | |
188 def _GetUniqueModuleName(): | |
189 _counter[0] += 1 | |
190 return "module_" + str(_counter[0]) | |
OLD | NEW |