diff options
Diffstat (limited to 'libpod')
-rw-r--r-- | libpod/container.go | 2 | ||||
-rw-r--r-- | libpod/sql_state.go | 55 | ||||
-rw-r--r-- | libpod/sql_state_test.go | 520 |
3 files changed, 558 insertions, 19 deletions
diff --git a/libpod/container.go b/libpod/container.go index ffc4c6314..728a29dec 100644 --- a/libpod/container.go +++ b/libpod/container.go @@ -67,7 +67,7 @@ type containerRuntimeInfo struct { RunDir string `json:"runDir,omitempty"` // Mounted indicates whether the container's storage has been mounted // for use - Mounted bool `json:"-"` + Mounted bool `json:"mounted,omitempty"` // MountPoint contains the path to the container's mounted storage Mountpoint string `json:"mountPoint,omitempty"` // StartedTime is the time the container was started diff --git a/libpod/sql_state.go b/libpod/sql_state.go index 034fc03e1..466ad66a5 100644 --- a/libpod/sql_state.go +++ b/libpod/sql_state.go @@ -109,6 +109,10 @@ func (s *SQLState) Container(id string) (*Container, error) { containerState ON containers.Id = containerState.Id WHERE containers.Id=?;` + if id == "" { + return nil, ErrEmptyID + } + if !s.valid { return nil, ErrDBClosed } @@ -138,6 +142,10 @@ func (s *SQLState) LookupContainer(idOrName string) (*Container, error) { containerState ON containers.Id = containerState.Id WHERE (containers.Id LIKE ?) OR containers.Name=?;` + if idOrName == "" { + return nil, ErrEmptyID + } + if !s.valid { return nil, ErrDBClosed } @@ -178,6 +186,10 @@ func (s *SQLState) LookupContainer(idOrName string) (*Container, error) { func (s *SQLState) HasContainer(id string) (bool, error) { const query = "SELECT 1 FROM containers WHERE Id=?;" + if id == "" { + return false, ErrEmptyID + } + if !s.valid { return false, ErrDBClosed } @@ -225,23 +237,6 @@ func (s *SQLState) AddContainer(ctr *Container) (err error) { return errors.Wrapf(err, "error marshaling container %s labels to JSON", ctr.ID()) } - // Save the container's runtime spec to disk - specJSON, err := json.Marshal(ctr.config.Spec) - if err != nil { - return errors.Wrapf(err, "error marshalling container %s spec to JSON", ctr.ID()) - } - specPath := getSpecPath(s.specsDir, ctr.ID()) - if err := ioutil.WriteFile(specPath, specJSON, 0750); err != nil { - return errors.Wrapf(err, "error saving container %s spec JSON to disk", ctr.ID()) - } - defer func() { - if err != nil { - if err2 := os.Remove(specPath); err2 != nil { - logrus.Errorf("Error removing container %s JSON spec from state: %v", ctr.ID(), err2) - } - } - }() - s.lock.Lock() defer s.lock.Unlock() @@ -288,6 +283,23 @@ func (s *SQLState) AddContainer(ctr *Container) (err error) { return errors.Wrapf(err, "error adding container %s state to database", ctr.ID()) } + // Save the container's runtime spec to disk + specJSON, err := json.Marshal(ctr.config.Spec) + if err != nil { + return errors.Wrapf(err, "error marshalling container %s spec to JSON", ctr.ID()) + } + specPath := getSpecPath(s.specsDir, ctr.ID()) + if err := ioutil.WriteFile(specPath, specJSON, 0750); err != nil { + return errors.Wrapf(err, "error saving container %s spec JSON to disk", ctr.ID()) + } + defer func() { + if err != nil { + if err2 := os.Remove(specPath); err2 != nil { + logrus.Errorf("Error removing container %s JSON spec from state: %v", ctr.ID(), err2) + } + } + }() + if err := tx.Commit(); err != nil { return errors.Wrapf(err, "error committing transaction to add container %s", ctr.ID()) } @@ -411,7 +423,7 @@ func (s *SQLState) SaveContainer(ctr *Container) error { }() // Add container state to the database - _, err = tx.Exec(update, + result, err := tx.Exec(update, ctr.state.State, ctr.state.ConfigPath, ctr.state.RunDir, @@ -423,6 +435,13 @@ func (s *SQLState) SaveContainer(ctr *Container) error { if err != nil { return errors.Wrapf(err, "error updating container %s state in database", ctr.ID()) } + rows, err := result.RowsAffected() + if err != nil { + return errors.Wrapf(err, "error retrieving number of rows modified by update of container %s", ctr.ID()) + } + if rows == 0 { + return ErrNoSuchCtr + } if err := tx.Commit(); err != nil { return errors.Wrapf(err, "error committing transaction to update container %s", ctr.ID()) diff --git a/libpod/sql_state_test.go b/libpod/sql_state_test.go new file mode 100644 index 000000000..b50f3aced --- /dev/null +++ b/libpod/sql_state_test.go @@ -0,0 +1,520 @@ +package libpod + +import ( + "encoding/json" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + "time" + + "github.com/opencontainers/runtime-tools/generate" + "github.com/stretchr/testify/assert" +) + +func getTestContainer(id, name string) *Container { + ctr := &Container{ + config: &containerConfig{ + ID: id, + Name: name, + RootfsImageID: id, + RootfsImageName: "testimg", + UseImageConfig: true, + StaticDir: "/does/not/exist/", + Stdin: true, + Labels: make(map[string]string), + StopSignal: 0, + CreatedTime: time.Now(), + }, + state: &containerRuntimeInfo{ + State: ContainerStateRunning, + ConfigPath: "/does/not/exist/specs/" + id, + RunDir: "/does/not/exist/tmp/", + Mounted: true, + Mountpoint: "/does/not/exist/tmp/" + id, + }, + valid: true, + } + + g := generate.New() + ctr.config.Spec = g.Spec() + + ctr.config.Labels["test"] = "testing" + + return ctr +} + +// This horrible hack tests if containers are equal in a way that should handle +// empty arrays being dropped to nil pointers in the spec JSON +func testContainersEqual(a, b *Container) bool { + if a == nil && b == nil { + return true + } else if a == nil || b == nil { + return false + } + + if a.valid != b.valid { + return false + } + + aConfigJSON, err := json.Marshal(a.config) + if err != nil { + return false + } + + bConfigJSON, err := json.Marshal(b.config) + if err != nil { + return false + } + + if !reflect.DeepEqual(aConfigJSON, bConfigJSON) { + return false + } + + aStateJSON, err := json.Marshal(a.state) + if err != nil { + return false + } + + bStateJSON, err := json.Marshal(b.state) + if err != nil { + return false + } + + return reflect.DeepEqual(aStateJSON, bStateJSON) +} + +// Get an empty state for use in tests +// An empty Runtime is provided +func getEmptyState() (s State, p string, err error) { + tmpDir, err := ioutil.TempDir("", "libpod_state_test_") + if err != nil { + return nil, "", err + } + defer func() { + if err != nil { + os.RemoveAll(tmpDir) + } + }() + + dbPath := filepath.Join(tmpDir, "db.sql") + lockPath := filepath.Join(tmpDir, "db.lck") + + state, err := NewSQLState(dbPath, lockPath, tmpDir, nil) + if err != nil { + return nil, "", err + } + + return state, tmpDir, nil +} + +func TestAddAndGetContainer(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + retrievedCtr, err := state.Container(testCtr.ID()) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, retrievedCtr) { + assert.EqualValues(t, testCtr, retrievedCtr) + } +} + +func TestAddAndGetContainerFromMultiple(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr1 := getTestContainer("11111111111111111111111111111111", "test1") + testCtr2 := getTestContainer("22222222222222222222222222222222", "test2") + + err = state.AddContainer(testCtr1) + assert.NoError(t, err) + + err = state.AddContainer(testCtr2) + assert.NoError(t, err) + + retrievedCtr, err := state.Container(testCtr1.ID()) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr1, retrievedCtr) { + assert.EqualValues(t, testCtr1, retrievedCtr) + } +} + +func TestAddInvalidContainerFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + err = state.AddContainer(&Container{}) + assert.Error(t, err) +} + +func TestAddDuplicateIDFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr1 := getTestContainer("11111111111111111111111111111111", "test1") + testCtr2 := getTestContainer(testCtr1.ID(), "test2") + + err = state.AddContainer(testCtr1) + assert.NoError(t, err) + + err = state.AddContainer(testCtr2) + assert.Error(t, err) +} + +func TestAddDuplicateNameFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr1 := getTestContainer("11111111111111111111111111111111", "test1") + testCtr2 := getTestContainer("22222222222222222222222222222222", testCtr1.Name()) + + err = state.AddContainer(testCtr1) + assert.NoError(t, err) + + err = state.AddContainer(testCtr2) + assert.Error(t, err) +} + +func TestGetNonexistantContainerFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + _, err = state.Container("does not exist") + assert.Error(t, err) +} + +func TestGetContainerWithEmptyIDFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + _, err = state.Container("") + assert.Error(t, err) +} + +func TestLookupContainerWithEmptyIDFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + _, err = state.LookupContainer("") + assert.Error(t, err) +} + +func TestLookupNonexistantContainerFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + + _, err = state.LookupContainer("does not exist") + assert.Error(t, err) +} + +func TestLookupContainerByFullID(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + retrievedCtr, err := state.LookupContainer(testCtr.ID()) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, retrievedCtr) { + assert.EqualValues(t, testCtr, retrievedCtr) + } +} + +func TestLookupContainerByUniquePartialID(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + retrievedCtr, err := state.LookupContainer(testCtr.ID()[0:8]) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, retrievedCtr) { + assert.EqualValues(t, testCtr, retrievedCtr) + } +} + +func TestLookupContainerByNonUniquePartialIDFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr1 := getTestContainer("00000000000000000000000000000000", "test1") + testCtr2 := getTestContainer("00000000000000000000000000000001", "test2") + + err = state.AddContainer(testCtr1) + assert.NoError(t, err) + + err = state.AddContainer(testCtr2) + assert.NoError(t, err) + + _, err = state.LookupContainer(testCtr1.ID()[0:8]) + assert.Error(t, err) +} + +func TestLookupContainerByName(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + retrievedCtr, err := state.LookupContainer(testCtr.Name()) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, retrievedCtr) { + assert.EqualValues(t, testCtr, retrievedCtr) + } +} + +func TestHasContainerEmptyIDFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + _, err = state.HasContainer("") + assert.Error(t, err) +} + +func TestHasContainerNoSuchContainerReturnsFalse(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + exists, err := state.HasContainer("does not exist") + assert.NoError(t, err) + assert.False(t, exists) +} + +func TestHasContainerFindsContainer(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + exists, err := state.HasContainer(testCtr.ID()) + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestSaveAndUpdateContainer(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + retrievedCtr, err := state.Container(testCtr.ID()) + assert.NoError(t, err) + + retrievedCtr.state.State = ContainerStateStopped + retrievedCtr.state.ExitCode = 127 + retrievedCtr.state.FinishedTime = time.Now() + + err = state.SaveContainer(retrievedCtr) + assert.NoError(t, err) + + err = state.UpdateContainer(testCtr) + assert.NoError(t, err) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, retrievedCtr) { + assert.EqualValues(t, testCtr, retrievedCtr) + } +} + +func TestUpdateContainerNotInDatabaseReturnsError(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.UpdateContainer(testCtr) + assert.Error(t, err) + assert.False(t, testCtr.valid) +} + +func TestUpdateInvalidContainerReturnsError(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + err = state.UpdateContainer(&Container{}) + assert.Error(t, err) +} + +func TestSaveInvalidContainerReturnsError(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + err = state.SaveContainer(&Container{}) + assert.Error(t, err) +} + +func TestSaveContainerNotInStateReturnsError(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.SaveContainer(testCtr) + assert.Error(t, err) +} + +func TestRemoveContainer(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + ctrs, err := state.AllContainers() + assert.NoError(t, err) + assert.Equal(t, 1, len(ctrs)) + + err = state.RemoveContainer(testCtr) + assert.NoError(t, err) + + ctrs2, err := state.AllContainers() + assert.NoError(t, err) + assert.Equal(t, 0, len(ctrs2)) +} + +func TestRemoveNonexistantContainerFails(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.RemoveContainer(testCtr) + assert.Error(t, err) +} + +func TestGetAllContainersOnNewStateIsEmpty(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + ctrs, err := state.AllContainers() + assert.NoError(t, err) + assert.Equal(t, 0, len(ctrs)) +} + +func TestGetAllContainersWithOneContainer(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr := getTestContainer("0123456789ABCDEF0123456789ABCDEF", "test") + + err = state.AddContainer(testCtr) + assert.NoError(t, err) + + ctrs, err := state.AllContainers() + assert.NoError(t, err) + assert.Equal(t, 1, len(ctrs)) + + // Use assert.EqualValues if the test fails to pretty print diff + // between actual and expected + if !testContainersEqual(testCtr, ctrs[0]) { + assert.EqualValues(t, testCtr, ctrs[0]) + } +} + +func TestGetAllContainersTwoContainers(t *testing.T) { + state, path, err := getEmptyState() + assert.NoError(t, err) + defer os.RemoveAll(path) + defer state.Close() + + testCtr1 := getTestContainer("11111111111111111111111111111111", "test1") + testCtr2 := getTestContainer("22222222222222222222222222222222", "test2") + + err = state.AddContainer(testCtr1) + assert.NoError(t, err) + + err = state.AddContainer(testCtr2) + assert.NoError(t, err) + + ctrs, err := state.AllContainers() + assert.NoError(t, err) + assert.Equal(t, 2, len(ctrs)) +} |