11using System . Collections . Generic ;
22using System . Globalization ;
3+ using System . Linq ;
34using System . Runtime . CompilerServices ;
4- using System . Text ;
55using System . Text . Json ;
66using GraphRag . Graphs ;
77using Microsoft . Extensions . Logging ;
88using Npgsql ;
9+ using NpgsqlTypes ;
910
1011namespace GraphRag . Storage . Postgres ;
1112
12- public sealed class PostgresGraphStore : IGraphStore
13+ public class PostgresGraphStore : IGraphStore , IAsyncDisposable
1314{
1415 private readonly string _connectionString ;
1516 private readonly string _graphName ;
@@ -50,8 +51,14 @@ public async Task UpsertNodeAsync(string id, string label, IReadOnlyDictionary<s
5051 ArgumentException . ThrowIfNullOrWhiteSpace ( label ) ;
5152 ArgumentNullException . ThrowIfNull ( properties ) ;
5253
53- var cypher = BuildNodeUpsertCypher ( id , label , properties ) ;
54- await ExecuteCypherAsync ( cypher , cancellationToken ) . ConfigureAwait ( false ) ;
54+ var query = $ "MERGE (n:{ EscapeLabel ( label ) } {{ id: $node_id }}) SET n += $props RETURN n";
55+ var parameters = new Dictionary < string , object ? >
56+ {
57+ [ "node_id" ] = id ,
58+ [ "props" ] = ConvertProperties ( properties )
59+ } ;
60+
61+ await ExecuteCypherAsync ( query , parameters , cancellationToken ) . ConfigureAwait ( false ) ;
5562 _logger . LogDebug ( "Upserted node {Id} ({Label}) into graph {GraphName}." , id , label , _graphName ) ;
5663 }
5764
@@ -62,8 +69,15 @@ public async Task UpsertRelationshipAsync(string sourceId, string targetId, stri
6269 ArgumentException . ThrowIfNullOrWhiteSpace ( type ) ;
6370 ArgumentNullException . ThrowIfNull ( properties ) ;
6471
65- var cypher = BuildRelationshipUpsertCypher ( sourceId , targetId , type , properties ) ;
66- await ExecuteCypherAsync ( cypher , cancellationToken ) . ConfigureAwait ( false ) ;
72+ var query = $ "MATCH (source {{ id: $source_id }}), (target {{ id: $target_id }}) MERGE (source)-[rel:{ EscapeLabel ( type ) } ]->(target) SET rel += $props RETURN rel";
73+ var parameters = new Dictionary < string , object ? >
74+ {
75+ [ "source_id" ] = sourceId ,
76+ [ "target_id" ] = targetId ,
77+ [ "props" ] = ConvertProperties ( properties )
78+ } ;
79+
80+ await ExecuteCypherAsync ( query , parameters , cancellationToken ) . ConfigureAwait ( false ) ;
6781 _logger . LogDebug ( "Upserted relationship {Source}-[{Type}]->{Target} in graph {GraphName}." , sourceId , type , targetId , _graphName ) ;
6882 }
6983
@@ -76,17 +90,25 @@ async IAsyncEnumerable<GraphRelationship> FetchAsync(string nodeId, [EnumeratorC
7690 {
7791 await using var connection = await OpenConnectionAsync ( token ) . ConfigureAwait ( false ) ;
7892 await using var command = connection . CreateCommand ( ) ;
79- command . CommandText = $ @ "
93+ command . CommandText = @"
8094SELECT
8195 source_id::text,
8296 target_id::text,
8397 edge_type::text,
8498 edge_props::text
85- FROM cypher(' { _graphName } ' , $$
86- MATCH (source {{ id: ' { EscapeString ( nodeId ) } ' } })-[rel]->(target)
99+ FROM cypher(@graph_name , $$
100+ MATCH (source { id: $node_id })-[rel]->(target)
87101 RETURN source.id AS source_id, target.id AS target_id, type(rel) AS edge_type, properties(rel) AS edge_props
88- $$) AS (source_id agtype, target_id agtype, edge_type agtype, edge_props agtype);
102+ $$, @params ) AS (source_id agtype, target_id agtype, edge_type agtype, edge_props agtype);
89103" ;
104+ command . Parameters . AddWithValue ( "graph_name" , _graphName ) ;
105+ command . Parameters . Add ( new NpgsqlParameter ( "params" , NpgsqlDbType . Jsonb )
106+ {
107+ Value = JsonSerializer . Serialize ( new Dictionary < string , object ? >
108+ {
109+ [ "node_id" ] = nodeId
110+ } )
111+ } ) ;
90112
91113 await using var reader = await command . ExecuteReaderAsync ( token ) . ConfigureAwait ( false ) ;
92114 while ( await reader . ReadAsync ( token ) . ConfigureAwait ( false ) )
@@ -102,11 +124,20 @@ FROM cypher('{_graphName}', $$
102124 }
103125 }
104126
105- private async Task ExecuteCypherAsync ( string statement , CancellationToken cancellationToken )
127+ protected virtual async Task ExecuteCypherAsync ( string query , IReadOnlyDictionary < string , object ? > parameters , CancellationToken cancellationToken )
106128 {
107129 await using var connection = await OpenConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
108130 await using var command = connection . CreateCommand ( ) ;
109- command . CommandText = statement ;
131+ command . CommandText = @"
132+ SELECT *
133+ FROM cypher(@graph_name, @query, @params) AS (result agtype);
134+ " ;
135+ command . Parameters . AddWithValue ( "graph_name" , _graphName ) ;
136+ command . Parameters . AddWithValue ( "query" , query ) ;
137+ command . Parameters . Add ( new NpgsqlParameter ( "params" , NpgsqlDbType . Jsonb )
138+ {
139+ Value = JsonSerializer . Serialize ( parameters )
140+ } ) ;
110141 await command . ExecuteNonQueryAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
111142 }
112143
@@ -131,74 +162,6 @@ private static async Task ExecuteNonQueryAsync(NpgsqlConnection connection, stri
131162 await command . ExecuteNonQueryAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
132163 }
133164
134- private string BuildNodeUpsertCypher ( string id , string label , IReadOnlyDictionary < string , object ? > properties )
135- {
136- var builder = new StringBuilder ( ) ;
137- builder . AppendLine ( $ "SELECT * FROM cypher('{ _graphName } ', $$") ;
138- builder . Append ( " MERGE (n:" ) ;
139- builder . Append ( EscapeLabel ( label ) ) ;
140- builder . Append ( " { id: '" ) ;
141- builder . Append ( EscapeString ( id ) ) ;
142- builder . Append ( "' })" ) ;
143-
144- var setClause = BuildSetClause ( "n" , properties , excludeId : true ) ;
145- if ( ! string . IsNullOrEmpty ( setClause ) )
146- {
147- builder . AppendLine ( ) ;
148- builder . Append ( " " ) ;
149- builder . Append ( setClause ) ;
150- }
151-
152- builder . AppendLine ( ) ;
153- builder . AppendLine ( "RETURN n" ) ;
154- builder . Append ( "$$) AS (n agtype);" ) ;
155- return builder . ToString ( ) ;
156- }
157-
158- private string BuildRelationshipUpsertCypher ( string sourceId , string targetId , string type , IReadOnlyDictionary < string , object ? > properties )
159- {
160- var builder = new StringBuilder ( ) ;
161- builder . AppendLine ( $ "SELECT * FROM cypher('{ _graphName } ', $$") ;
162- builder . AppendLine ( $ " MATCH (source {{ id: '{ EscapeString ( sourceId ) } ' }}), (target {{ id: '{ EscapeString ( targetId ) } ' }})") ;
163- builder . Append ( " MERGE (source)-[rel:" ) ;
164- builder . Append ( EscapeLabel ( type ) ) ;
165- builder . Append ( "]->(target)" ) ;
166-
167- var setClause = BuildSetClause ( "rel" , properties , excludeId : false ) ;
168- if ( ! string . IsNullOrEmpty ( setClause ) )
169- {
170- builder . AppendLine ( ) ;
171- builder . Append ( " " ) ;
172- builder . Append ( setClause ) ;
173- }
174-
175- builder . AppendLine ( ) ;
176- builder . AppendLine ( "RETURN rel" ) ;
177- builder . Append ( "$$) AS (rel agtype);" ) ;
178- return builder . ToString ( ) ;
179- }
180-
181- private static string BuildSetClause ( string alias , IReadOnlyDictionary < string , object ? > properties , bool excludeId )
182- {
183- var assignments = new List < string > ( ) ;
184- foreach ( var ( key , value ) in properties )
185- {
186- if ( value is null )
187- {
188- continue ;
189- }
190-
191- if ( excludeId && string . Equals ( key , "id" , StringComparison . OrdinalIgnoreCase ) )
192- {
193- continue ;
194- }
195-
196- assignments . Add ( $ "{ EscapePropertyKey ( key ) } : { FormatValue ( value ) } ") ;
197- }
198-
199- return assignments . Count == 0 ? string . Empty : $ "SET { alias } += {{{string.Join(", ", assignments)}}}" ;
200- }
201-
202165 private static string EscapeLabel ( string label )
203166 {
204167 if ( string . IsNullOrWhiteSpace ( label ) )
@@ -217,88 +180,60 @@ private static string EscapeLabel(string label)
217180 return label ;
218181 }
219182
220- private static string EscapePropertyKey ( string key )
183+ private static IDictionary < string , object ? > ConvertProperties ( IReadOnlyDictionary < string , object ? > properties )
221184 {
222- if ( string . IsNullOrWhiteSpace ( key ) )
223- {
224- throw new ArgumentException ( "Property key cannot be null or whitespace." , nameof ( key ) ) ;
225- }
185+ var result = new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
226186
227- foreach ( var ch in key )
187+ foreach ( var ( key , value ) in properties )
228188 {
229- if ( ! char . IsLetterOrDigit ( ch ) && ch != '_' )
189+ if ( value is null || string . Equals ( key , "id" , StringComparison . OrdinalIgnoreCase ) )
230190 {
231- throw new ArgumentException ( $ "Invalid character ' { ch } ' in property key ' { key } '." , nameof ( key ) ) ;
191+ continue ;
232192 }
193+
194+ result [ key ] = value switch
195+ {
196+ string s => s ,
197+ int or long or short or byte => Convert . ToInt64 ( value , CultureInfo . InvariantCulture ) ,
198+ float or double or decimal => Convert . ToDouble ( value , CultureInfo . InvariantCulture ) ,
199+ bool b => b ,
200+ DateTime dt => dt . ToUniversalTime ( ) ,
201+ IEnumerable < string > list => list . ToArray ( ) ,
202+ _ => value . ToString ( )
203+ } ;
233204 }
234205
235- return key ;
206+ return result ;
236207 }
237208
238- private static string EscapeString ( string value ) => value . Replace ( "'" , "''" , StringComparison . Ordinal ) ;
209+ private static string NormalizeAgTypeText ( string value ) => value . Trim ( '"' ) ;
239210
240- private static string FormatValue ( object value ) => value switch
211+ private static IReadOnlyDictionary < string , object ? > ParseProperties ( string json )
241212 {
242- null => "null" ,
243- string s => $ "'{ EscapeString ( s ) } '",
244- bool b => b ? "true" : "false" ,
245- int or long or short or byte => Convert . ToString ( value , CultureInfo . InvariantCulture ) ! ,
246- float f => f . ToString ( CultureInfo . InvariantCulture ) ,
247- double d => d . ToString ( CultureInfo . InvariantCulture ) ,
248- decimal dec => dec . ToString ( CultureInfo . InvariantCulture ) ,
249- Guid guid => $ "'{ guid : D} '",
250- DateTime dt => $ "'{ dt . ToUniversalTime ( ) : O} '",
251- DateTimeOffset dto => $ "'{ dto . ToUniversalTime ( ) : O} '",
252- _ => $ "'{ EscapeString ( value . ToString ( ) ?? string . Empty ) } '"
253- } ;
254-
255- private static Dictionary < string , object ? > ParseProperties ( string json )
256- {
257- if ( string . IsNullOrWhiteSpace ( json ) )
213+ try
258214 {
259- return new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
260- }
215+ using var document = JsonDocument . Parse ( json ) ;
216+ var result = new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
217+ foreach ( var property in document . RootElement . EnumerateObject ( ) )
218+ {
219+ result [ property . Name ] = property . Value . ValueKind switch
220+ {
221+ JsonValueKind . String => property . Value . GetString ( ) ,
222+ JsonValueKind . Number => property . Value . TryGetInt64 ( out var i64 ) ? i64 : property . Value . GetDouble ( ) ,
223+ JsonValueKind . True => true ,
224+ JsonValueKind . False => false ,
225+ JsonValueKind . Null => null ,
226+ _ => property . Value . GetRawText ( )
227+ } ;
228+ }
261229
262- using var document = JsonDocument . Parse ( json ) ;
263- if ( document . RootElement . ValueKind != JsonValueKind . Object )
264- {
265- return new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
230+ return result ;
266231 }
267-
268- var result = new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
269- foreach ( var property in document . RootElement . EnumerateObject ( ) )
232+ catch
270233 {
271- result [ property . Name ] = ConvertJsonElement ( property . Value ) ;
234+ return new Dictionary < string , object ? > ( StringComparer . OrdinalIgnoreCase ) ;
272235 }
273-
274- return result ;
275236 }
276237
277- private static object ? ConvertJsonElement ( JsonElement element ) => element . ValueKind switch
278- {
279- JsonValueKind . Null => null ,
280- JsonValueKind . String => element . GetString ( ) ,
281- JsonValueKind . Number when element . TryGetInt64 ( out var longValue ) => longValue ,
282- JsonValueKind . Number => element . GetDouble ( ) ,
283- JsonValueKind . True => true ,
284- JsonValueKind . False => false ,
285- JsonValueKind . Array => element . EnumerateArray ( ) . Select ( ConvertJsonElement ) . ToArray ( ) ,
286- JsonValueKind . Object => element . EnumerateObject ( ) . ToDictionary ( prop => prop . Name , prop => ConvertJsonElement ( prop . Value ) , StringComparer . OrdinalIgnoreCase ) ,
287- _ => null
288- } ;
289-
290- private static string NormalizeAgTypeText ( string value )
291- {
292- if ( string . IsNullOrEmpty ( value ) )
293- {
294- return value ;
295- }
296-
297- if ( value . Length >= 2 && ( ( value [ 0 ] == '"' && value [ ^ 1 ] == '"' ) || ( value [ 0 ] == '\' ' && value [ ^ 1 ] == '\' ' ) ) )
298- {
299- return value [ 1 ..^ 1 ] ;
300- }
301-
302- return value ;
303- }
238+ public ValueTask DisposeAsync ( ) => ValueTask . CompletedTask ;
304239}
0 commit comments