summaryrefslogtreecommitdiff
path: root/libpod
diff options
context:
space:
mode:
Diffstat (limited to 'libpod')
-rw-r--r--libpod/errors.go3
-rw-r--r--libpod/sql_state.go9
-rw-r--r--libpod/sql_state_internal.go131
-rw-r--r--libpod/sql_state_test.go7
4 files changed, 149 insertions, 1 deletions
diff --git a/libpod/errors.go b/libpod/errors.go
index 180ca51db..782104cf0 100644
--- a/libpod/errors.go
+++ b/libpod/errors.go
@@ -59,6 +59,9 @@ var (
// ErrDBClosed indicates that the connection to the state database has
// already been closed
ErrDBClosed = errors.New("database connection already closed")
+ // ErrDBBadConfig indicates that the database has a different schema or
+ // was created by a libpod with a different config
+ ErrDBBadConfig = errors.New("database configuration mismatch")
// ErrNotImplemented indicates that the requested functionality is not
// yet present
diff --git a/libpod/sql_state.go b/libpod/sql_state.go
index 8b18a8959..7c2061fca 100644
--- a/libpod/sql_state.go
+++ b/libpod/sql_state.go
@@ -14,6 +14,10 @@ import (
_ "github.com/mattn/go-sqlite3"
)
+// DBSchema is the current DB schema version
+// Increments every time a change is made to the database's tables
+const DBSchema = 1
+
// SQLState is a state implementation backed by a persistent SQLite3 database
type SQLState struct {
db *sql.DB
@@ -69,6 +73,11 @@ func NewSQLState(dbPath, lockPath, specsDir string, runtime *Runtime) (State, er
return nil, err
}
+ // Ensure that the database matches our config
+ if err := checkDB(db, runtime); err != nil {
+ return nil, err
+ }
+
state.db = db
state.valid = true
diff --git a/libpod/sql_state_internal.go b/libpod/sql_state_internal.go
index 58a6daa58..6e0142b9b 100644
--- a/libpod/sql_state_internal.go
+++ b/libpod/sql_state_internal.go
@@ -15,6 +15,137 @@ import (
_ "github.com/mattn/go-sqlite3"
)
+// Checks that the DB configuration matches the runtime's configuration
+func checkDB(db *sql.DB, r *Runtime) (err error) {
+ // Create a table to hold runtime information
+ // TODO: Include UID/GID mappings
+ const runtimeTable = `
+ CREATE TABLE runtime(
+ Id INTEGER NOT NULL PRIMARY KEY,
+ SchemaVersion INTEGER NOT NULL,
+ StaticDir TEXT NOT NULL,
+ TmpDir TEXT NOT NULL,
+ RunRoot TEXT NOT NULL,
+ GraphRoot TEXT NOT NULL,
+ GraphDriverName TEXT NOT NULL,
+ CHECK (Id=0)
+ );
+ `
+ const fillRuntimeTable = `INSERT INTO runtime VALUES (
+ ?, ?, ?, ?, ?, ?, ?
+ );`
+
+ const selectRuntimeTable = `SELECT SchemaVersion,
+ StaticDir,
+ TmpDir,
+ RunRoot,
+ GraphRoot,
+ GraphDriverName
+ FROM runtime WHERE id=0;`
+
+ const checkRuntimeExists = "SELECT name FROM sqlite_master WHERE type='table' AND name='runtime';"
+
+ tx, err := db.Begin()
+ if err != nil {
+ return errors.Wrapf(err, "error beginning database transaction")
+ }
+ defer func() {
+ if err != nil {
+ if err2 := tx.Rollback(); err2 != nil {
+ logrus.Errorf("Error rolling back transaction to check runtime table: %v", err2)
+ }
+ }
+
+ }()
+
+ row := tx.QueryRow(checkRuntimeExists)
+ var table string
+ if err := row.Scan(&table); err != nil {
+ // There is no runtime table
+ // Create and populate the runtime table
+ if err == sql.ErrNoRows {
+ if _, err := tx.Exec(runtimeTable); err != nil {
+ return errors.Wrapf(err, "error creating runtime table in database")
+ }
+
+ _, err := tx.Exec(fillRuntimeTable,
+ 0,
+ DBSchema,
+ r.config.StaticDir,
+ r.config.TmpDir,
+ r.config.StorageConfig.RunRoot,
+ r.config.StorageConfig.GraphRoot,
+ r.config.StorageConfig.GraphDriverName)
+ if err != nil {
+ return errors.Wrapf(err, "error populating runtime table in database")
+ }
+
+ if err := tx.Commit(); err != nil {
+ return errors.Wrapf(err, "error committing runtime table transaction in database")
+ }
+
+ return nil
+ }
+
+ return errors.Wrapf(err, "error checking for presence of runtime table in database")
+ }
+
+ // There is a runtime table
+ // Retrieve its contents
+ var (
+ schemaVersion int
+ staticDir string
+ tmpDir string
+ runRoot string
+ graphRoot string
+ graphDriverName string
+ )
+
+ row = tx.QueryRow(selectRuntimeTable)
+ err = row.Scan(
+ &schemaVersion,
+ &staticDir,
+ &tmpDir,
+ &runRoot,
+ &graphRoot,
+ &graphDriverName)
+ if err != nil {
+ return errors.Wrapf(err, "error retrieving runtime information from database")
+ }
+
+ // Compare the information in the database against our runtime config
+ if schemaVersion != DBSchema {
+ return errors.Wrapf(ErrDBBadConfig, "database schema version %d does not match our schema version %d",
+ schemaVersion, DBSchema)
+ }
+ if staticDir != r.config.StaticDir {
+ return errors.Wrapf(ErrDBBadConfig, "database static directory %s does not match our static directory %s",
+ staticDir, r.config.StaticDir)
+ }
+ if tmpDir != r.config.TmpDir {
+ return errors.Wrapf(ErrDBBadConfig, "database temp directory %s does not match our temp directory %s",
+ tmpDir, r.config.TmpDir)
+ }
+ if runRoot != r.config.StorageConfig.RunRoot {
+ return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s",
+ runRoot, r.config.StorageConfig.RunRoot)
+ }
+ if graphRoot != r.config.StorageConfig.GraphRoot {
+ return errors.Wrapf(ErrDBBadConfig, "database graph root directory %s does not match our graph root directory %s",
+ graphRoot, r.config.StorageConfig.GraphRoot)
+ }
+ if graphDriverName != r.config.StorageConfig.GraphDriverName {
+ return errors.Wrapf(ErrDBBadConfig, "database runroot directory %s does not match our runroot directory %s",
+ graphDriverName, r.config.StorageConfig.GraphDriverName)
+ }
+
+ if err := tx.Commit(); err != nil {
+ return errors.Wrapf(err, "error committing runtime table transaction in database")
+ }
+
+ return nil
+}
+
// Performs database setup including by not limited to initializing tables in
// the database
func prepareDB(db *sql.DB) (err error) {
diff --git a/libpod/sql_state_test.go b/libpod/sql_state_test.go
index 9f6b5d078..124959544 100644
--- a/libpod/sql_state_test.go
+++ b/libpod/sql_state_test.go
@@ -9,6 +9,7 @@ import (
"testing"
"time"
+ "github.com/containers/storage"
"github.com/opencontainers/runtime-tools/generate"
"github.com/stretchr/testify/assert"
)
@@ -102,7 +103,11 @@ func getEmptyState() (s State, p string, err error) {
dbPath := filepath.Join(tmpDir, "db.sql")
lockPath := filepath.Join(tmpDir, "db.lck")
- state, err := NewSQLState(dbPath, lockPath, tmpDir, nil)
+ runtime := new(Runtime)
+ runtime.config = new(RuntimeConfig)
+ runtime.config.StorageConfig = storage.StoreOptions{}
+
+ state, err := NewSQLState(dbPath, lockPath, tmpDir, runtime)
if err != nil {
return nil, "", err
}