Skip to content

Commit 2c35864

Browse files
committed
test: add tests for Fal.AI, ForgeModules and fix existing tests
- Add tests for FalAiEndpointRepositoryImpl, FalAiGenerationRepositoryImpl, ForgeModulesRepositoryImpl - Add tests for TestFalAiApiKeyUseCaseImpl, GetForgeModulesUseCaseImpl, GetGalleryPagedIdsUseCaseImpl - Add tests for FalAiGenerationUseCaseImpl, ConnectToFalAiUseCaseImpl - Fix existing tests: add mediaFileManager, falAiApiKey, mediaPath/inputMediaPath mocks - Add mock files for FalAi and ForgeModule entities
1 parent a8df743 commit 2c35864

28 files changed

Lines changed: 1370 additions & 12 deletions

data/src/test/java/com/shifthackz/aisdv1/data/mocks/AiGenerationResultMocks.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.shifthackz.aisdv1.data.mocks
22

33
import com.shifthackz.aisdv1.domain.entity.AiGenerationResult
4+
import com.shifthackz.aisdv1.domain.entity.MediaType
45
import java.util.Date
56

67
val mockAiGenerationResult = AiGenerationResult(
@@ -22,6 +23,9 @@ val mockAiGenerationResult = AiGenerationResult(
2223
subSeedStrength = 5598f,
2324
denoisingStrength = 1504f,
2425
hidden = false,
26+
mediaPath = "",
27+
inputMediaPath = "",
28+
mediaType = MediaType.IMAGE,
2529
)
2630

2731
val mockAiGenerationResults = listOf(mockAiGenerationResult)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package com.shifthackz.aisdv1.data.mocks
2+
3+
import com.shifthackz.aisdv1.domain.entity.FalAiEndpoint
4+
import com.shifthackz.aisdv1.domain.entity.FalAiEndpointCategory
5+
import com.shifthackz.aisdv1.domain.entity.FalAiEndpointSchema
6+
import com.shifthackz.aisdv1.domain.entity.FalAiInputProperty
7+
import com.shifthackz.aisdv1.domain.entity.FalAiPayload
8+
import com.shifthackz.aisdv1.domain.entity.FalAiPropertyType
9+
10+
val mockFalAiInputProperty = FalAiInputProperty(
11+
name = "prompt",
12+
title = "Prompt",
13+
description = "The prompt to generate an image from",
14+
type = FalAiPropertyType.STRING,
15+
default = null,
16+
minimum = null,
17+
maximum = null,
18+
enumValues = null,
19+
isRequired = true,
20+
isImageInput = false,
21+
)
22+
23+
val mockFalAiEndpointSchema = FalAiEndpointSchema(
24+
baseUrl = "https://queue.fal.run",
25+
submissionPath = "/fal-ai/flux/schnell",
26+
inputProperties = listOf(mockFalAiInputProperty),
27+
requiredProperties = listOf("prompt"),
28+
propertyOrder = listOf("prompt"),
29+
)
30+
31+
val mockFalAiEndpoint = FalAiEndpoint(
32+
id = "fal-ai/flux/schnell",
33+
endpointId = "fal-ai/flux/schnell",
34+
title = "FLUX.1 [schnell]",
35+
description = "Fast text to image generation",
36+
category = FalAiEndpointCategory.TEXT_TO_IMAGE,
37+
group = "FLUX",
38+
thumbnailUrl = "https://fal.ai/thumbnails/flux-schnell.jpg",
39+
playgroundUrl = "https://fal.ai/models/fal-ai/flux/schnell",
40+
documentationUrl = "https://fal.ai/models/fal-ai/flux/schnell/api",
41+
isCustom = false,
42+
schema = mockFalAiEndpointSchema,
43+
)
44+
45+
val mockFalAiEndpoints = listOf(mockFalAiEndpoint)
46+
47+
val mockFalAiPayload = FalAiPayload(
48+
endpointId = "fal-ai/flux/schnell",
49+
parameters = mapOf(
50+
"prompt" to "a beautiful sunset",
51+
"num_inference_steps" to 4,
52+
),
53+
)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.shifthackz.aisdv1.data.mocks
2+
3+
import com.shifthackz.aisdv1.domain.entity.ForgeModule
4+
5+
val mockForgeModule = ForgeModule(
6+
name = "ADetailer",
7+
path = "extensions/adetailer",
8+
)
9+
10+
val mockForgeModules = listOf(
11+
mockForgeModule,
12+
ForgeModule(
13+
name = "ControlNet",
14+
path = "extensions/sd-webui-controlnet",
15+
),
16+
)

data/src/test/java/com/shifthackz/aisdv1/data/mocks/GenerationResultEntityMocks.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ val mockGenerationResultEntity = GenerationResultEntity(
2323
subSeedStrength = 5598f,
2424
denoisingStrength = 1504f,
2525
hidden = false,
26+
mediaPath = "",
27+
inputMediaPath = "",
28+
mediaType = "IMAGE",
2629
)
2730

2831
val mockGenerationResultEntities = listOf(mockGenerationResultEntity)
Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package com.shifthackz.aisdv1.data.repository
2+
3+
import com.shifthackz.aisdv1.data.mocks.mockFalAiEndpoint
4+
import com.shifthackz.aisdv1.data.mocks.mockFalAiEndpoints
5+
import com.shifthackz.aisdv1.domain.datasource.FalAiEndpointDataSource
6+
import com.shifthackz.aisdv1.domain.preference.PreferenceManager
7+
import io.mockk.every
8+
import io.mockk.mockk
9+
import io.mockk.verify
10+
import io.reactivex.rxjava3.core.Completable
11+
import io.reactivex.rxjava3.core.Observable
12+
import io.reactivex.rxjava3.core.Single
13+
import org.junit.Test
14+
15+
class FalAiEndpointRepositoryImplTest {
16+
17+
private val stubException = Throwable("Something went wrong.")
18+
private val stubBuiltInDataSource = mockk<FalAiEndpointDataSource.BuiltIn>()
19+
private val stubRemoteDataSource = mockk<FalAiEndpointDataSource.Remote>()
20+
private val stubLocalDataSource = mockk<FalAiEndpointDataSource.Local>()
21+
private val stubPreferenceManager = mockk<PreferenceManager>(relaxed = true)
22+
23+
private val repository = FalAiEndpointRepositoryImpl(
24+
builtInDataSource = stubBuiltInDataSource,
25+
remoteDataSource = stubRemoteDataSource,
26+
localDataSource = stubLocalDataSource,
27+
preferenceManager = stubPreferenceManager,
28+
)
29+
30+
@Test
31+
fun `given attempt to observe all endpoints, both sources return data, expected combined list`() {
32+
val customEndpoint = mockFalAiEndpoint.copy(id = "custom-endpoint", isCustom = true)
33+
34+
every {
35+
stubBuiltInDataSource.getAll()
36+
} returns Single.just(mockFalAiEndpoints)
37+
38+
every {
39+
stubLocalDataSource.observeAll()
40+
} returns Observable.just(listOf(customEndpoint))
41+
42+
repository
43+
.observeAll()
44+
.test()
45+
.assertNoErrors()
46+
.assertValueAt(1) { it.size == 2 && it.containsAll(mockFalAiEndpoints + customEndpoint) }
47+
.dispose()
48+
}
49+
50+
@Test
51+
fun `given attempt to get all endpoints, both sources return data, expected combined list`() {
52+
val customEndpoint = mockFalAiEndpoint.copy(id = "custom-endpoint", isCustom = true)
53+
54+
every {
55+
stubBuiltInDataSource.getAll()
56+
} returns Single.just(mockFalAiEndpoints)
57+
58+
every {
59+
stubLocalDataSource.getAll()
60+
} returns Single.just(listOf(customEndpoint))
61+
62+
repository
63+
.getAll()
64+
.test()
65+
.assertNoErrors()
66+
.assertValue { it.size == 2 && it.containsAll(mockFalAiEndpoints + customEndpoint) }
67+
.await()
68+
.assertComplete()
69+
}
70+
71+
@Test
72+
fun `given attempt to get all endpoints, local source fails, expected built-in endpoints only`() {
73+
every {
74+
stubBuiltInDataSource.getAll()
75+
} returns Single.just(mockFalAiEndpoints)
76+
77+
every {
78+
stubLocalDataSource.getAll()
79+
} returns Single.error(stubException)
80+
81+
repository
82+
.getAll()
83+
.test()
84+
.assertNoErrors()
85+
.assertValue(mockFalAiEndpoints)
86+
.await()
87+
.assertComplete()
88+
}
89+
90+
@Test
91+
fun `given attempt to get endpoint by id, endpoint exists, expected valid endpoint`() {
92+
every {
93+
stubBuiltInDataSource.getAll()
94+
} returns Single.just(mockFalAiEndpoints)
95+
96+
every {
97+
stubLocalDataSource.getAll()
98+
} returns Single.just(emptyList())
99+
100+
repository
101+
.getById(mockFalAiEndpoint.id)
102+
.test()
103+
.assertNoErrors()
104+
.assertValue(mockFalAiEndpoint)
105+
.await()
106+
.assertComplete()
107+
}
108+
109+
@Test
110+
fun `given attempt to get endpoint by id, endpoint not found, expected error`() {
111+
every {
112+
stubBuiltInDataSource.getAll()
113+
} returns Single.just(mockFalAiEndpoints)
114+
115+
every {
116+
stubLocalDataSource.getAll()
117+
} returns Single.just(emptyList())
118+
119+
repository
120+
.getById("non-existent-id")
121+
.test()
122+
.assertError { it is NoSuchElementException }
123+
.assertNoValues()
124+
.await()
125+
.assertNotComplete()
126+
}
127+
128+
@Test
129+
fun `given attempt to get selected endpoint, preference has valid id, expected valid endpoint`() {
130+
every {
131+
stubPreferenceManager.falAiSelectedEndpointId
132+
} returns mockFalAiEndpoint.id
133+
134+
every {
135+
stubBuiltInDataSource.getAll()
136+
} returns Single.just(mockFalAiEndpoints)
137+
138+
every {
139+
stubLocalDataSource.getAll()
140+
} returns Single.just(emptyList())
141+
142+
repository
143+
.getSelected()
144+
.test()
145+
.assertNoErrors()
146+
.assertValue(mockFalAiEndpoint)
147+
.await()
148+
.assertComplete()
149+
}
150+
151+
@Test
152+
fun `given attempt to get selected endpoint, preference is blank, expected first built-in endpoint`() {
153+
every {
154+
stubPreferenceManager.falAiSelectedEndpointId
155+
} returns ""
156+
157+
every {
158+
stubBuiltInDataSource.getAll()
159+
} returns Single.just(mockFalAiEndpoints)
160+
161+
repository
162+
.getSelected()
163+
.test()
164+
.assertNoErrors()
165+
.assertValue(mockFalAiEndpoints.first())
166+
.await()
167+
.assertComplete()
168+
}
169+
170+
@Test
171+
fun `given attempt to set selected endpoint, expected preference updated`() {
172+
every {
173+
stubPreferenceManager.falAiSelectedEndpointId = any()
174+
} returns Unit
175+
176+
repository
177+
.setSelected("new-endpoint-id")
178+
.test()
179+
.assertNoErrors()
180+
.await()
181+
.assertComplete()
182+
183+
verify { stubPreferenceManager.falAiSelectedEndpointId = "new-endpoint-id" }
184+
}
185+
186+
@Test
187+
fun `given attempt to import from url, remote returns endpoint, expected endpoint saved and returned`() {
188+
val importedEndpoint = mockFalAiEndpoint.copy(id = "imported", isCustom = true)
189+
190+
every {
191+
stubRemoteDataSource.fetchFromUrl(any())
192+
} returns Single.just(importedEndpoint)
193+
194+
every {
195+
stubLocalDataSource.save(any())
196+
} returns Completable.complete()
197+
198+
repository
199+
.importFromUrl("https://fal.ai/api/openapi/queue/openapi.json?endpoint_id=test")
200+
.test()
201+
.assertNoErrors()
202+
.assertValue(importedEndpoint)
203+
.await()
204+
.assertComplete()
205+
206+
verify { stubLocalDataSource.save(importedEndpoint) }
207+
}
208+
209+
@Test
210+
fun `given attempt to import from url, remote fails, expected error`() {
211+
every {
212+
stubRemoteDataSource.fetchFromUrl(any())
213+
} returns Single.error(stubException)
214+
215+
repository
216+
.importFromUrl("https://fal.ai/api/invalid")
217+
.test()
218+
.assertError(stubException)
219+
.assertNoValues()
220+
.await()
221+
.assertNotComplete()
222+
}
223+
224+
@Test
225+
fun `given attempt to delete endpoint, expected delete called on local source`() {
226+
every {
227+
stubLocalDataSource.delete(any())
228+
} returns Completable.complete()
229+
230+
repository
231+
.delete("endpoint-to-delete")
232+
.test()
233+
.assertNoErrors()
234+
.await()
235+
.assertComplete()
236+
237+
verify { stubLocalDataSource.delete("endpoint-to-delete") }
238+
}
239+
240+
@Test
241+
fun `given attempt to delete endpoint, local source fails, expected error`() {
242+
every {
243+
stubLocalDataSource.delete(any())
244+
} returns Completable.error(stubException)
245+
246+
repository
247+
.delete("endpoint-to-delete")
248+
.test()
249+
.assertError(stubException)
250+
.await()
251+
.assertNotComplete()
252+
}
253+
}

0 commit comments

Comments
 (0)