3232import java .util .List ;
3333import java .util .Map ;
3434import java .util .Optional ;
35- import java .util .function .Supplier ;
3635import java .util .stream .Collectors ;
3736import java .util .stream .Stream ;
3837
3938public class ValuePoolRegistry {
4039 private final Method fuzzTestMethod ;
4140 private final Path baseDir ;
42- private final Map <String , Supplier < Stream < ?>>> pools ;
41+ private final Map <Method , List < ?>> supplierValuesCache = new LinkedHashMap <>() ;
4342 private final Map <Path , Optional <byte []>> pathToBytesCache = new LinkedHashMap <>();
4443
4544 public ValuePoolRegistry (Method fuzzTestMethod ) {
@@ -48,7 +47,6 @@ public ValuePoolRegistry(Method fuzzTestMethod) {
4847
4948 protected ValuePoolRegistry (Method fuzzTestMethod , Path baseDir ) {
5049 this .fuzzTestMethod = fuzzTestMethod ;
51- this .pools = extractValueSuppliers (fuzzTestMethod );
5250 this .baseDir = baseDir ;
5351 }
5452
@@ -96,32 +94,143 @@ public Stream<?> extractUserValues(AnnotatedType type) {
9694 .map (ValuePool ::value )
9795 .flatMap (Arrays ::stream )
9896 .filter (name -> !name .isEmpty ())
99- .flatMap (
100- name -> {
101- Supplier <Stream <?>> supplier = pools .get (name );
102- if (supplier == null ) {
103- throw new IllegalStateException (
104- "@ValuePool: No method named '"
105- + name
106- + "' found for type "
107- + type .getType ().getTypeName ()
108- + " in fuzz test method "
109- + fuzzTestMethod .getName ()
110- + ". Available provider methods: "
111- + String .join (", " , pools .keySet ()));
112- }
113- return supplier .get ();
114- })
97+ .map (String ::trim )
98+ .flatMap (this ::loadUserValuesFromSupplier )
11599 .distinct ();
116100
117- // Walking the file system only makes sense for ValuePool's that annotate byte[] types.
101+ // Walking the file system only makes sense for pools that annotate byte[] types.
118102 if (type .getType () == byte [].class ) {
119103 return Stream .concat (valuesFromSourceMethods , extractByteArraysFromPatterns (type ));
120104 } else {
121105 return valuesFromSourceMethods ;
122106 }
123107 }
124108
109+ private Stream <?> loadUserValuesFromSupplier (String supplierRef ) {
110+ Method supplier = resolveSupplier (supplierRef );
111+ return supplierValuesCache
112+ .computeIfAbsent (supplier , s -> loadValuesFromMethod (s , supplierRef ))
113+ .stream ();
114+ }
115+
116+ private Method resolveSupplier (String supplierRef ) {
117+ if (supplierRef .isEmpty ()) {
118+ throw new IllegalArgumentException ("@ValuePool: Supplier method cannot be blank" );
119+ }
120+
121+ int hashIndex = supplierRef .indexOf ('#' );
122+
123+ // Supplier method is in the fuzz test class
124+ if (hashIndex == -1 ) {
125+ return resolveSupplier (fuzzTestMethod .getDeclaringClass (), supplierRef );
126+ }
127+
128+ // Supplier method is not in the fuzz test class
129+ // Validate the format of the supplier reference before loading the class
130+ if (hashIndex != supplierRef .lastIndexOf ('#' )) {
131+ throw new IllegalArgumentException (
132+ "@ValuePool: Invalid supplier method reference (multiple '#'): " + supplierRef );
133+ }
134+ if (hashIndex == 0 || hashIndex == supplierRef .length () - 1 ) {
135+ throw new IllegalArgumentException (
136+ "@ValuePool: Invalid supplier method reference (expected 'ClassName#methodName'): "
137+ + supplierRef );
138+ }
139+
140+ String className = supplierRef .substring (0 , hashIndex );
141+ String methodName = supplierRef .substring (hashIndex + 1 );
142+ if (className .isEmpty () || methodName .isEmpty ()) {
143+ throw new IllegalArgumentException (
144+ "@ValuePool: Invalid supplier method reference (expected 'ClassName#methodName'): "
145+ + supplierRef );
146+ }
147+
148+ Class <?> clazz = loadClass (className );
149+ return resolveSupplier (clazz , methodName );
150+ }
151+
152+ private Method resolveSupplier (Class <?> clazz , String methodName ) {
153+ try {
154+ return clazz .getDeclaredMethod (methodName );
155+ } catch (NoSuchMethodException e ) {
156+ throw new IllegalArgumentException (
157+ "@ValuePool: No supplier method named '" + methodName + "' found in class " + clazz , e );
158+ }
159+ }
160+
161+ private Class <?> loadClass (String className ) {
162+ ClassLoader fuzzTestLoader = fuzzTestMethod .getDeclaringClass ().getClassLoader ();
163+ try {
164+ return Class .forName (className , false , fuzzTestLoader );
165+ } catch (ClassNotFoundException | LinkageError | SecurityException firstFailure ) {
166+ // Retry with the context class loader
167+ ClassLoader contextLoader = Thread .currentThread ().getContextClassLoader ();
168+ if (contextLoader != null && contextLoader != fuzzTestLoader ) {
169+ try {
170+ return Class .forName (className , false , contextLoader );
171+ } catch (ClassNotFoundException | LinkageError | SecurityException secondFailure ) {
172+ IllegalArgumentException ex =
173+ new IllegalArgumentException (
174+ "@ValuePool: Failed to load class '"
175+ + className
176+ + "' (fuzzTestLoader="
177+ + fuzzTestLoader
178+ + ", contextLoader="
179+ + contextLoader
180+ + ")" ,
181+ firstFailure );
182+ ex .addSuppressed (secondFailure );
183+ throw ex ;
184+ }
185+ }
186+ if (firstFailure instanceof ClassNotFoundException ) {
187+ throw new IllegalArgumentException (
188+ "@ValuePool: No class named '" + className + "' found" , firstFailure );
189+ }
190+ throw new IllegalArgumentException (
191+ "@ValuePool: Failed to load class '"
192+ + className
193+ + "' using class loader "
194+ + fuzzTestLoader ,
195+ firstFailure );
196+ }
197+ }
198+
199+ private List <Object > loadValuesFromMethod (Method supplier , String supplierRef ) {
200+ if (!Modifier .isStatic (supplier .getModifiers ())) {
201+ throw new IllegalStateException (
202+ "@ValuePool: supplier method '"
203+ + supplierRef
204+ + "' must be static in fuzz test method "
205+ + fuzzTestMethod .getName ());
206+ }
207+ if (!Stream .class .equals (supplier .getReturnType ())) {
208+ throw new IllegalStateException (
209+ "@ValuePool: supplier method '"
210+ + supplierRef
211+ + "' must return a Stream<?> in fuzz test method "
212+ + fuzzTestMethod .getName ());
213+ }
214+
215+ supplier .setAccessible (true );
216+
217+ try {
218+ List <Object > values = ((Stream <?>) supplier .invoke (null )).collect (Collectors .toList ());
219+ if (values .isEmpty ()) {
220+ throw new IllegalStateException (
221+ "@ValuePool: supplier method '" + supplierRef + "' returned no values." );
222+ }
223+ return values ;
224+ } catch (IllegalAccessException e ) {
225+ throw new RuntimeException ("@ValuePool: Access denied for supplier method " + supplierRef , e );
226+ } catch (InvocationTargetException e ) {
227+ Throwable cause = e .getCause ();
228+ throw new RuntimeException (
229+ "@ValuePool: Supplier method " + supplierRef + " threw an exception" ,
230+ cause != null ? cause : e );
231+ }
232+ }
233+
125234 private Stream <byte []> extractByteArraysFromPatterns (AnnotatedType type ) {
126235 List <ValuePool > annotations = getValuePoolAnnotations (type );
127236
@@ -170,51 +279,4 @@ private Optional<byte[]> tryReadFile(Path path) {
170279 }
171280 });
172281 }
173-
174- private static Map <String , Supplier <Stream <?>>> extractValueSuppliers (Method fuzzTestMethod ) {
175- return Arrays .stream (fuzzTestMethod .getDeclaringClass ().getDeclaredMethods ())
176- .filter (m -> m .getParameterCount () == 0 )
177- .filter (m -> Stream .class .equals (m .getReturnType ()))
178- .filter (m -> Modifier .isStatic (m .getModifiers ()))
179- .collect (Collectors .toMap (Method ::getName , ValuePoolRegistry ::createLazyStreamSupplier ));
180- }
181-
182- private static Supplier <Stream <?>> createLazyStreamSupplier (Method method ) {
183- return new Supplier <Stream <?>>() {
184- private volatile List <Object > cachedData = null ;
185-
186- @ Override
187- public Stream <?> get () {
188- if (cachedData == null ) {
189- synchronized (this ) {
190- if (cachedData == null ) {
191- cachedData = loadDataFromMethod (method );
192- }
193- }
194- if (cachedData .isEmpty ()) {
195- throw new IllegalStateException (
196- "@ValuePool: method '"
197- + method .getName ()
198- + "' returned no values. Value pool methods must return at least one value." );
199- }
200- }
201- return cachedData .stream ();
202- }
203- };
204- }
205-
206- private static List <Object > loadDataFromMethod (Method method ) {
207- method .setAccessible (true );
208- try {
209- Stream <?> stream = (Stream <?>) method .invoke (null );
210- return stream .collect (Collectors .toList ());
211- } catch (IllegalAccessException e ) {
212- throw new RuntimeException ("@ValuePool: Access denied for method " + method .getName (), e );
213- } catch (InvocationTargetException e ) {
214- Throwable cause = e .getCause ();
215- throw new RuntimeException (
216- "@ValuePool: Method " + method .getName () + " threw an exception" ,
217- cause != null ? cause : e );
218- }
219- }
220282}
0 commit comments