diff --git a/lnbits/db.py b/lnbits/db.py index d9a86609..316bb217 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -15,22 +15,26 @@ class Database: return self def __exit__(self, exc_type, exc_val, exc_tb): + self.connection.commit() self.cursor.close() self.connection.close() def fetchall(self, query: str, values: tuple = ()) -> list: """Given a query, return cursor.fetchall() rows.""" - self.cursor.execute(query, values) + self.execute(query, values) return self.cursor.fetchall() def fetchone(self, query: str, values: tuple = ()): - self.cursor.execute(query, values) + self.execute(query, values) return self.cursor.fetchone() def execute(self, query: str, values: tuple = ()) -> None: """Given a query, cursor.execute() it.""" - self.cursor.execute(query, values) - self.connection.commit() + try: + self.cursor.execute(query, values) + except sqlite3.Error as exc: + self.connection.rollback() + raise exc def open_db(db_name: str = "database") -> Database: