Skip to content

Commit a3ea95a

Browse files
committed
Big restructuration of cohere client
1 parent cf4be1d commit a3ea95a

10 files changed

Lines changed: 1071 additions & 685 deletions

File tree

Directory.Packages.props

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<PackageVersion Include="Azure.AI.OpenAI" Version="1.0.0-beta.17" />
88
<PackageVersion Include="CommandDotNet.Spectre" Version="3.0.2" />
99
<PackageVersion Include="Microsoft.Extensions.Http" Version="8.0.0" />
10-
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="8.5.0" />
10+
<PackageVersion Include="Microsoft.Extensions.Http.Resilience" Version="8.6.0" />
1111
<PackageVersion Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
1212
<PackageVersion Include="Microsoft.Extensions.Logging.Debug" Version="8.0.0" />
1313
<PackageVersion Include="Microsoft.KernelMemory.Abstractions" Version="0.61.240524.1" />
@@ -18,6 +18,7 @@
1818
<PackageVersion Include="Microsoft.SemanticKernel.Yaml" Version="1.13.0" />
1919
<PackageVersion Include="Microsoft.SemanticKernel.Abstractions" Version="1.13.0" />
2020
<PackageVersion Include="Microsoft.SemanticKernel.Core" Version="1.13.0" />
21+
<PackageVersion Include="Polly.Core" Version="8.4.1" />
2122
<PackageVersion Include="TiktokenSharp" Version="1.1.4" />
2223
<PackageVersion Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
2324
</ItemGroup>

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,46 @@ public async Task RunSample2()
2727
{
2828
var services = new ServiceCollection();
2929

30-
CohereConfiguration cohereConfiguration = new CohereConfiguration();
31-
cohereConfiguration.ApiKey = Dotenv.Get("COHERE_API_KEY");
30+
var apiKey = Dotenv.Get("COHERE_API_KEY")!;
31+
var cohereBaseUrl = Dotenv.Get("COHERE_BASE_API_KEY");
32+
if (string.IsNullOrEmpty(cohereBaseUrl))
33+
{
34+
services.ConfigureCohereChat(apiKey);
35+
}
36+
else
37+
{
38+
services.ConfigureCohereChat(apiKey, cohereBaseUrl);
39+
}
40+
//verify if rerank has a different api key (because the apikey point on azure ai studio)
41+
var rerankApiKey = Dotenv.Get("COHERE_RERANK_API_KEY");
42+
if (string.IsNullOrEmpty(rerankApiKey))
43+
{
44+
services.ConfigureCohereRerank(apiKey);
45+
}
46+
else
47+
{
48+
services.ConfigureCohereRerank(rerankApiKey);
49+
}
50+
51+
services.AddHttpClient<RawCohereChatClient>()
52+
.AddStandardResilienceHandler(options =>
53+
{
54+
// Configure standard resilience options here
55+
});
56+
services.AddHttpClient<RawCohereReRankerClient>()
57+
.AddStandardResilienceHandler(options =>
58+
{
59+
// Configure standard resilience options here
60+
});
61+
services.AddHttpClient<RawCohereEmbeddingClient>()
62+
.AddStandardResilienceHandler(options =>
63+
{
64+
// Configure standard resilience options here
65+
});
3266

3367
CohereCommandRQueryExecutorConfiguration coereCommandRagQueryExecutorConfiguration = new();
3468
coereCommandRagQueryExecutorConfiguration.MaxMemoryRecord = 10;
3569

36-
services.AddSingleton(cohereConfiguration);
3770
services.AddSingleton(coereCommandRagQueryExecutorConfiguration);
3871
services.AddSingleton<RawCohereClient>();
3972
services.AddSingleton<CohereCommandRQueryExecutor>();

src/KernelMemory.Extensions.FunctionalTests/Cohere/CohereTests.cs

Lines changed: 91 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,48 @@
66

77
namespace KernelMemory.Extensions.FunctionalTests.Cohere;
88

9-
public class CohereReRankTests
9+
public class CohereTests
1010
{
11-
private IHttpClientFactory _ihttpClientFactory;
11+
private ServiceProvider _serviceProvider;
1212

13-
public CohereReRankTests()
13+
private IHttpClientFactory _httpClientFactory;
14+
15+
public CohereTests()
1416
{
1517
var services = new ServiceCollection();
16-
services.AddHttpClient();
17-
var serviceProvider = services.BuildServiceProvider();
18-
_ihttpClientFactory = serviceProvider.GetRequiredService<IHttpClientFactory>();
18+
services.AddHttpClient<RawCohereChatClient>()
19+
.AddStandardResilienceHandler(options =>
20+
{
21+
// Configure standard resilience options here
22+
});
23+
services.AddHttpClient<RawCohereReRankerClient>()
24+
.AddStandardResilienceHandler(options =>
25+
{
26+
// Configure standard resilience options here
27+
});
28+
services.AddHttpClient<RawCohereEmbeddingClient>()
29+
.AddStandardResilienceHandler(options =>
30+
{
31+
// Configure standard resilience options here
32+
});
33+
34+
var cohereApiKey = Environment.GetEnvironmentVariable("COHERE_API_KEY");
35+
36+
if (string.IsNullOrEmpty(cohereApiKey))
37+
{
38+
throw new Exception("COHERE_API_KEY is not set");
39+
}
40+
41+
services.ConfigureCohere(cohereApiKey);
42+
43+
_serviceProvider = services.BuildServiceProvider();
44+
_httpClientFactory = _serviceProvider.GetRequiredService<IHttpClientFactory>();
1945
}
2046

2147
[Fact]
2248
public async Task Basic_cohere_reranking()
2349
{
24-
CohereConfiguration cohereConfig = CreateConfig();
25-
var cohereClient = new RawCohereClient(cohereConfig, _ihttpClientFactory);
50+
var cohereClient = new RawCohereClient(_serviceProvider);
2651
var ReRankResult = await cohereClient.ReRankAsync(new CohereReRankRequest("What is the capital of the United States?",
2752
["Carson City is the capital city of the American state of Nevada.",
2853
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
@@ -36,8 +61,7 @@ public async Task Basic_cohere_reranking()
3661
[Fact]
3762
public async Task Can_rerank_empty_document_list()
3863
{
39-
CohereConfiguration cohereConfig = CreateConfig();
40-
var cohereClient = new RawCohereClient(cohereConfig, _ihttpClientFactory);
64+
var cohereClient = new RawCohereClient(_serviceProvider);
4165
var ReRankResult = await cohereClient.ReRankAsync(new CohereReRankRequest("What is the capital of the United States?", []));
4266

4367
Assert.NotNull(ReRankResult);
@@ -47,8 +71,7 @@ public async Task Can_rerank_empty_document_list()
4771
[Fact]
4872
public async Task Basic_cohere_Rag_streaming()
4973
{
50-
CohereConfiguration cohereConfig = CreateConfig();
51-
var cohereClient = new RawCohereClient(cohereConfig, _ihttpClientFactory);
74+
var cohereClient = new RawCohereClient(_serviceProvider);
5275

5376
var records = new List<MemoryRecord>();
5477
records.Add(MemoryRecordTestUtilities.CreateMemoryRecord("doc1", "file1", 1, "Carson City is the capital city of the American state of Nevada."));
@@ -65,8 +88,7 @@ public async Task Basic_cohere_Rag_streaming()
6588
[Fact]
6689
public async Task Basic_cohere_Rag()
6790
{
68-
CohereConfiguration cohereConfig = CreateConfig();
69-
var cohereClient = new RawCohereClient(cohereConfig, _ihttpClientFactory);
91+
var cohereClient = new RawCohereClient(_serviceProvider);
7092

7193
var records = new List<MemoryRecord>();
7294
records.Add(MemoryRecordTestUtilities.CreateMemoryRecord("doc1", "file1", 1, "Carson City is the capital city of the American state of Nevada."));
@@ -86,8 +108,7 @@ public async Task Basic_cohere_Rag()
86108
[Fact]
87109
public async Task Basic_cohere_embed_test()
88110
{
89-
CohereConfiguration cohereConfig = CreateConfig();
90-
var cohereClient = new RawCohereClient(cohereConfig, _ihttpClientFactory);
111+
var cohereClient = new RawCohereClient(_serviceProvider);
91112

92113
var embedRequest = new CohereEmbedRequest
93114
{
@@ -106,76 +127,62 @@ public async Task Basic_cohere_embed_test()
106127
[Fact]
107128
public void Tokenizer_raw_test()
108129
{
109-
CohereTokenizer tokenizer = new(_ihttpClientFactory);
130+
CohereTokenizer tokenizer = new(_httpClientFactory);
110131
var count = tokenizer.CountToken("command-r-plus", "Now I'm using CommandR+ tokenizer");
111132
Assert.Equal(8, count);
112133
}
113134

114-
/// <summary>
115-
/// In azure ai studio we do not still have re-ranking, so we need to use re-ranker with a configuration
116-
/// and the executor with another configuration
117-
/// </summary>
118-
[Fact]
119-
public void Ability_to_use_azure()
120-
{
121-
var configReRank = new CohereConfiguration()
122-
{
123-
ApiKey = "Base Api Key",
124-
};
125-
var configRagExecutor = new CohereConfiguration()
126-
{
127-
ApiKey = "Azure configuration",
128-
BaseUrl = "https://api.azure.cohere.ai/",
129-
};
130-
131-
var services = new ServiceCollection();
132-
services.AddHttpClient();
133-
services.AddKeyedSingleton<CohereConfiguration>("rerank", configReRank);
134-
services.AddKeyedSingleton<CohereConfiguration>("executor", configRagExecutor);
135-
136-
services.AddKeyedSingleton<RawCohereClient>("rerank", (sp, key) =>
137-
{
138-
var options = sp.GetKeyedService<CohereConfiguration>(key);
139-
var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
140-
return new RawCohereClient(options, httpClientFactory);
141-
});
142-
services.AddKeyedSingleton<RawCohereClient>("executor", (sp, key) =>
143-
{
144-
var options = sp.GetKeyedService<CohereConfiguration>(key);
145-
var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
146-
return new RawCohereClient(options, httpClientFactory);
147-
});
148-
149-
var serviceProvider = services.BuildServiceProvider();
150-
var cohereConfigurationReRank = serviceProvider.GetKeyedService<CohereConfiguration>("rerank");
151-
var cohereConfigurationExecutor = serviceProvider.GetKeyedService<CohereConfiguration>("executor");
152-
153-
var cohereClientReRank = serviceProvider.GetKeyedService<RawCohereClient>("rerank");
154-
var cohereClientExecutor = serviceProvider.GetKeyedService<RawCohereClient>("executor");
155-
156-
//Base assertion, you can simply get by key
157-
Assert.Equal("Azure configuration", cohereConfigurationExecutor.ApiKey);
158-
Assert.Equal("https://api.azure.cohere.ai/", cohereConfigurationExecutor.BaseUrl);
159-
160-
Assert.Equal("https://api.cohere.ai/", cohereConfigurationReRank.BaseUrl);
161-
Assert.Equal("Base Api Key", cohereConfigurationReRank.ApiKey);
162-
163-
//now verify the two clients
164-
Assert.Equal(cohereClientReRank.GetFieldValue("_apiKey"), cohereConfigurationReRank.ApiKey);
165-
166-
}
167-
168-
private static CohereConfiguration CreateConfig()
169-
{
170-
var cohereConfig = new CohereConfiguration
171-
{
172-
ApiKey = Environment.GetEnvironmentVariable("COHERE_API_KEY"),
173-
};
174-
if (string.IsNullOrEmpty(cohereConfig.ApiKey))
175-
{
176-
throw new Exception("COHERE_API_KEY is not set");
177-
}
178-
179-
return cohereConfig;
180-
}
135+
// /// <summary>
136+
// /// In azure ai studio we do not still have re-ranking, so we need to use re-ranker with a configuration
137+
// /// and the executor with another configuration
138+
// /// </summary>
139+
// [Fact]
140+
// public void Ability_to_use_azure()
141+
// {
142+
// var configReRank = new CohereConfiguration()
143+
// {
144+
// ApiKey = "Base Api Key",
145+
// };
146+
// var configRagExecutor = new CohereConfiguration()
147+
// {
148+
// ApiKey = "Azure configuration",
149+
// BaseUrl = "https://api.azure.cohere.ai/",
150+
// };
151+
152+
// var services = new ServiceCollection();
153+
// services.AddHttpClient();
154+
// services.AddKeyedSingleton<CohereConfiguration>("rerank", configReRank);
155+
// services.AddKeyedSingleton<CohereConfiguration>("executor", configRagExecutor);
156+
157+
// services.AddKeyedSingleton<RawCohereClient>("rerank", (sp, key) =>
158+
// {
159+
// var options = sp.GetKeyedService<CohereConfiguration>(key);
160+
// var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
161+
// return new RawCohereClient(options, _serviceProvider);
162+
// });
163+
// services.AddKeyedSingleton<RawCohereClient>("executor", (sp, key) =>
164+
// {
165+
// var options = sp.GetKeyedService<CohereConfiguration>(key);
166+
// var httpClientFactory = sp.GetRequiredService<IHttpClientFactory>();
167+
// return new RawCohereClient(options, _serviceProvider);
168+
// });
169+
170+
// var serviceProvider = services.BuildServiceProvider();
171+
// var cohereConfigurationReRank = serviceProvider.GetKeyedService<CohereConfiguration>("rerank");
172+
// var cohereConfigurationExecutor = serviceProvider.GetKeyedService<CohereConfiguration>("executor");
173+
174+
// var cohereClientReRank = serviceProvider.GetKeyedService<RawCohereClient>("rerank");
175+
// var cohereClientExecutor = serviceProvider.GetKeyedService<RawCohereClient>("executor");
176+
177+
// //Base assertion, you can simply get by key
178+
// Assert.Equal("Azure configuration", cohereConfigurationExecutor.ApiKey);
179+
// Assert.Equal("https://api.azure.cohere.ai/", cohereConfigurationExecutor.BaseUrl);
180+
181+
// Assert.Equal("https://api.cohere.ai/", cohereConfigurationReRank.BaseUrl);
182+
// Assert.Equal("Base Api Key", cohereConfigurationReRank.ApiKey);
183+
184+
// //now verify the two clients
185+
// Assert.Equal(cohereClientReRank.GetFieldValue("_apiKey"), cohereConfigurationReRank.ApiKey);
186+
187+
// }
181188
}

0 commit comments

Comments
 (0)