Skip to content

Commit 6b5129d

Browse files
committed
Improved SQLiteAdapter to be thread-safe (tested with Xcode 15.3 and Swift 5.10)
1 parent 1884b99 commit 6b5129d

1 file changed

Lines changed: 115 additions & 93 deletions

File tree

Sources/SQLiteAdapter/SQLiteAdapter.swift

Lines changed: 115 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ open class SQLite: SQLiteType {
8282
private let SQLITE_STATIC = unsafeBitCast(0, to: sqlite3_destructor_type.self)
8383
private let SQLITE_TRANSIENT = unsafeBitCast(-1, to: sqlite3_destructor_type.self)
8484

85+
private let serialQueue = DispatchQueue(label: "SQLite Serial Queue")
86+
8587
public var dateFormatter = DateFormatter()
8688

8789
public init(path: String, recreateDB: Bool = false) throws {
@@ -193,15 +195,18 @@ open class SQLite: SQLiteType {
193195
}
194196

195197
private func operation(sql: String, params: [Any]? = nil) throws {
196-
let sqlStatement = try prepareStatement(sql: sql)
197-
defer {
198-
sqlite3_finalize(sqlStatement)
199-
}
200-
201-
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
202-
203-
guard sqlite3_step(sqlStatement) == SQLITE_DONE else {
204-
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
198+
try serialQueue.sync {
199+
let sqlStatement = try prepareStatement(sql: sql)
200+
201+
defer {
202+
sqlite3_finalize(sqlStatement)
203+
}
204+
205+
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
206+
207+
guard sqlite3_step(sqlStatement) == SQLITE_DONE else {
208+
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
209+
}
205210
}
206211
}
207212

@@ -331,36 +336,41 @@ open class SQLite: SQLiteType {
331336
}
332337

333338
public func getRowCount(in table: SQLTable) throws -> Int {
334-
let sql = "SELECT count(*) FROM \(table.name);"
335-
let sqlStatement = try prepareStatement(sql: sql)
336-
defer {
337-
sqlite3_finalize(sqlStatement)
338-
}
339-
guard sqlite3_step(sqlStatement) == SQLITE_ROW else {
340-
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
339+
var count: Int32 = 0
340+
try serialQueue.sync {
341+
let sql = "SELECT count(*) FROM \(table.name);"
342+
let sqlStatement = try prepareStatement(sql: sql)
343+
defer {
344+
sqlite3_finalize(sqlStatement)
345+
}
346+
guard sqlite3_step(sqlStatement) == SQLITE_ROW else {
347+
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
348+
}
349+
count = sqlite3_column_int(sqlStatement, 0)
350+
log("successfully got a row count in \(table.name): \(count)")
341351
}
342-
let count = sqlite3_column_int(sqlStatement, 0)
343-
log("successfully got a row count in \(table.name): \(count)")
344352
return Int(count)
345353
}
346354

347355
public func getRowCountWithCondition(sql: String, params: [Any]? = nil) throws -> Int {
348356
guard sql.uppercased().trimmingCharacters(in: .whitespaces).hasPrefix("SELECT ") else {
349357
throw SQLiteError.Statement("Invalid SQL statement")
350358
}
351-
352-
let sqlStatement = try prepareStatement(sql: sql)
353-
defer {
354-
sqlite3_finalize(sqlStatement)
355-
}
356-
357-
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
358-
359-
guard sqlite3_step(sqlStatement) == SQLITE_ROW else {
360-
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
359+
var count: Int32 = 0
360+
try serialQueue.sync {
361+
let sqlStatement = try prepareStatement(sql: sql)
362+
defer {
363+
sqlite3_finalize(sqlStatement)
364+
}
365+
366+
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
367+
368+
guard sqlite3_step(sqlStatement) == SQLITE_ROW else {
369+
throw SQLiteError.Step(getErrorMessage(dbPointer: dbPointer))
370+
}
371+
count = sqlite3_column_int(sqlStatement, 0)
372+
log("successfully got a row count with condition: \(count), sql: \(sql)")
361373
}
362-
let count = sqlite3_column_int(sqlStatement, 0)
363-
log("successfully got a row count with condition: \(count), sql: \(sql)")
364374
return Int(count)
365375
}
366376

@@ -370,75 +380,78 @@ open class SQLite: SQLiteType {
370380
throw SQLiteError.Statement("Invalid SQL statement")
371381
}
372382

373-
let sqlStatement = try prepareStatement(sql: sql)
374-
defer {
375-
sqlite3_finalize(sqlStatement)
376-
}
377-
378-
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
379-
380383
var allRows: [SQLValues] = []
381-
var rowValues: SQLValues = SQLValues([])
382384

383-
guard let resultColumns = try? getResultColumns(table, sqlStatement: sqlStatement) else {
384-
throw SQLiteError.Column(getErrorMessage(dbPointer: dbPointer))
385-
}
386-
387-
while sqlite3_step(sqlStatement) == SQLITE_ROW {
388-
rowValues = SQLValues([])
389-
for (index, value) in resultColumns.enumerated() {
390-
391-
let index = Int32(index) // column serial number, should start with 0
392-
393-
// Check for data types of returned values
394-
guard sqlite3_column_type(sqlStatement, index) != SQLITE_NULL else {
395-
rowValues.append((value.type, nil))
396-
continue
397-
}
398-
switch value.type {
399-
case .INT:
400-
let intValue = sqlite3_column_int64(sqlStatement, index)
401-
rowValues.append((value.type, Int(intValue)))
402-
case .BOOL:
403-
let intValue = sqlite3_column_int(sqlStatement, index)
404-
rowValues.append((value.type, intValue == 1 ? true : false))
405-
case .TEXT:
406-
if let queryResult = sqlite3_column_text(sqlStatement, index) {
407-
let stringValue = String(cString: queryResult)
408-
rowValues.append((value.type, stringValue))
409-
} else {
410-
rowValues.append((value.type, nil))
411-
}
412-
case .REAL:
413-
let doubleValue = sqlite3_column_double(sqlStatement, index)
414-
rowValues.append((value.type, doubleValue))
415-
case .BLOB:
416-
if let queryResult = sqlite3_column_blob(sqlStatement, index) {
417-
let count = sqlite3_column_bytes(sqlStatement, index)
418-
let dataValue = Data(bytes: queryResult, count: Int(count))
419-
rowValues.append((value.type, dataValue))
420-
} else {
385+
try serialQueue.sync {
386+
let sqlStatement = try prepareStatement(sql: sql)
387+
defer {
388+
sqlite3_finalize(sqlStatement)
389+
}
390+
391+
try bindPlaceholders(sqlStatement: sqlStatement, params: params)
392+
393+
var rowValues: SQLValues = SQLValues([])
394+
395+
guard let resultColumns = try? getResultColumns(table, sqlStatement: sqlStatement) else {
396+
throw SQLiteError.Column(getErrorMessage(dbPointer: dbPointer))
397+
}
398+
399+
while sqlite3_step(sqlStatement) == SQLITE_ROW {
400+
rowValues = SQLValues([])
401+
for (index, value) in resultColumns.enumerated() {
402+
403+
let index = Int32(index) // column serial number, should start with 0
404+
405+
// Check for data types of returned values
406+
guard sqlite3_column_type(sqlStatement, index) != SQLITE_NULL else {
421407
rowValues.append((value.type, nil))
408+
continue
422409
}
423-
case .DATE:
424-
// If it's in date format
425-
if let queryResult = sqlite3_column_text(sqlStatement, index) {
426-
var dateStrValue = String(cString: queryResult)
427-
if dateStrValue.count == 10 {
428-
dateStrValue += " 00:00:00"
410+
switch value.type {
411+
case .INT:
412+
let intValue = sqlite3_column_int64(sqlStatement, index)
413+
rowValues.append((value.type, Int(intValue)))
414+
case .BOOL:
415+
let intValue = sqlite3_column_int(sqlStatement, index)
416+
rowValues.append((value.type, intValue == 1 ? true : false))
417+
case .TEXT:
418+
if let queryResult = sqlite3_column_text(sqlStatement, index) {
419+
let stringValue = String(cString: queryResult)
420+
rowValues.append((value.type, stringValue))
421+
} else {
422+
rowValues.append((value.type, nil))
423+
}
424+
case .REAL:
425+
let doubleValue = sqlite3_column_double(sqlStatement, index)
426+
rowValues.append((value.type, doubleValue))
427+
case .BLOB:
428+
if let queryResult = sqlite3_column_blob(sqlStatement, index) {
429+
let count = sqlite3_column_bytes(sqlStatement, index)
430+
let dataValue = Data(bytes: queryResult, count: Int(count))
431+
rowValues.append((value.type, dataValue))
432+
} else {
433+
rowValues.append((value.type, nil))
429434
}
430-
if let dateValue = dateFormatter.date(from: dateStrValue) {
431-
rowValues.append((value.type, dateValue))
432-
continue
435+
case .DATE:
436+
// If it's in date format
437+
if let queryResult = sqlite3_column_text(sqlStatement, index) {
438+
var dateStrValue = String(cString: queryResult)
439+
if dateStrValue.count == 10 {
440+
dateStrValue += " 00:00:00"
441+
}
442+
if let dateValue = dateFormatter.date(from: dateStrValue) {
443+
rowValues.append((value.type, dateValue))
444+
continue
445+
}
433446
}
447+
// If it's in time interval format
448+
let timeInterval = sqlite3_column_double(sqlStatement, index)
449+
let dateValue = Date(timeIntervalSince1970: timeInterval)
450+
rowValues.append((value.type, dateValue))
434451
}
435-
// If it's in time interval format
436-
let timeInterval = sqlite3_column_double(sqlStatement, index)
437-
let dateValue = Date(timeIntervalSince1970: timeInterval)
438-
rowValues.append((value.type, dateValue))
439452
}
453+
allRows.append(rowValues)
440454
}
441-
allRows.append(rowValues)
442455
}
443456

444457
log("successfully read row(s), count: \(allRows.count), sql: \(sql)")
@@ -499,21 +512,30 @@ open class SQLite: SQLiteType {
499512
}
500513

501514
public func getLastInsertID() -> Int {
502-
let id = Int(sqlite3_last_insert_rowid(dbPointer))
515+
var id = 0
516+
serialQueue.sync {
517+
id = Int(sqlite3_last_insert_rowid(dbPointer))
518+
}
503519
log("last inserted id: \(id)")
504520
return id
505521
}
506522

507523
/// Returns number of rows changed by last INSERT, UPDATE or DELETE statement
508524
public func getChanges() -> Int {
509-
let changes = Int(sqlite3_changes(dbPointer))
525+
var changes = 0
526+
serialQueue.sync {
527+
changes = Int(sqlite3_changes(dbPointer))
528+
}
510529
log("number of changes: \(changes)")
511530
return changes
512531
}
513532

514533
/// Returns number of rows changed by INSERT, UPDATE or DELETE statements since the DB was opened
515534
public func getTotalChanges() -> Int {
516-
let totalChanges = Int(sqlite3_total_changes(dbPointer))
535+
var totalChanges = 0
536+
serialQueue.sync {
537+
totalChanges = Int(sqlite3_total_changes(dbPointer))
538+
}
517539
log("number of total changes: \(totalChanges)")
518540
return totalChanges
519541
}

0 commit comments

Comments
 (0)