From 641b23a4925e08534d02c6f274e6a56c729a583b Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Fri, 5 Feb 2010 18:01:43 -0500 Subject: Changing get to return a list of all items. --- kronos/storage.py | 31 +++++++++++++++++++------------ kronos/tests/itest_sqlite_storage.py | 14 ++++++++++---- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/kronos/storage.py b/kronos/storage.py index 4add988..2cb37a7 100644 --- a/kronos/storage.py +++ b/kronos/storage.py @@ -41,16 +41,20 @@ class BaseStorageBackEnd(object): back-end. """ name = model_obj.__name__ - result = self.select(name, **kwargs) - instance = model_obj() - for key, value in result.items(): - if key != 'id': - setattr(instance, key, value) - else: - instance.__db_id__ = value + output = [] + for item in self.select(name, **kwargs): + instance = model_obj() + + for key, value in item.items(): + if key != 'id': + setattr(instance, key, value) + else: + instance.__db_id__ = value - return instance + output.append(instance) + + return output def save(self, model_obj): """ @@ -121,11 +125,14 @@ class SQLiteBackEnd(BaseStorageBackEnd): self.connection = self.engine.connect(database) def select(self, table, **kwargs): - sql = "SELECT * FROM {0} WHERE ".format(table) - for key in kwargs.keys(): - sql += "{0}=? ".format(key) + sql = "SELECT * FROM {0}".format(table) + + if kwargs: + sql += " WHERE " + for key in kwargs.keys(): + sql += "{0}=? ".format(key) - return self._get_normalized_results(sql, **kwargs)[0] + return self._get_normalized_results(sql, **kwargs) def _get_normalized_results(self, sql, **kwargs): self._check_connection() diff --git a/kronos/tests/itest_sqlite_storage.py b/kronos/tests/itest_sqlite_storage.py index 4628a23..e2a44d2 100644 --- a/kronos/tests/itest_sqlite_storage.py +++ b/kronos/tests/itest_sqlite_storage.py @@ -37,13 +37,14 @@ class TestSQLitBackEnd(unittest.TestCase): self.model1.bar = '456' self.model1.baz = '789' + self.storage.save(self.model1) + def test_no_connect_should_cause_error(self): self.storage.connection = None assert_raises(NotConnected, self.storage.save, self.model1) def test_save_and_select(self): - self.storage.save(self.model1) - results = self.storage.get(SampleModel, foo='123') + results = self.storage.get(SampleModel, foo='123')[0] assert isinstance(results, SampleModel) assert_equals(results.foo, '123') @@ -51,8 +52,7 @@ class TestSQLitBackEnd(unittest.TestCase): assert_equals(results.baz, '789') def test_save_and_update(self): - self.storage.save(self.model1) - results = self.storage.get(SampleModel, foo='123') + results = self.storage.get(SampleModel, foo='123')[0] results.foo = 'test' self.storage.save(results) @@ -61,6 +61,12 @@ class TestSQLitBackEnd(unittest.TestCase): assert_equals(results.bar, '456') assert_equals(results.baz, '789') + def test_get_without_args(self): + self.storage.save(self.model1) + results = self.storage.get(SampleModel) + + assert_equals(len(results), 2) + if __name__ == "__main__": unittest.main() -- cgit v1.2.3