diff --git a/server/gorm/client.go b/server/gorm/client.go index a043152ca..8ed01d38d 100644 --- a/server/gorm/client.go +++ b/server/gorm/client.go @@ -79,7 +79,6 @@ func NewClient(ctx context.Context, gormDBName, gormConfig string) (*Client, err mylock() clientCount++ clientTotal++ - //log.Printf("CREATING CLIENT (%d/%d %d)", clientCount, clientTotal, getGID()) switch gormDBName { case "sqlite3": db, err := gorm.Open(sqlite.Open(gormConfig), config()) @@ -128,7 +127,6 @@ func (c *Client) Close() { } func (c *Client) close() { - // log.Printf("CLOSING (%d) %d", clientCount, getGID()) clientCount-- sqlDB, _ := c.db.DB() sqlDB.Close() @@ -213,7 +211,8 @@ func (c *Client) Put(ctx context.Context, k storage.Key, v interface{}) (storage } c.db.Transaction( func(tx *gorm.DB) error { - rowsAffected := tx.Model(v).Where("key = ?", k.(*Key).Name).Updates(v).RowsAffected + // Update all fields from model: https://gorm.io/docs/update.html#Update-Selected-Fields + rowsAffected := tx.Model(v).Select("*").Where("key = ?", k.(*Key).Name).Updates(v).RowsAffected if rowsAffected == 0 { err := tx.Create(v).Error if err != nil { diff --git a/server/gorm/gorm_test.go b/server/gorm/gorm_test.go index a7243806d..90cd4f309 100644 --- a/server/gorm/gorm_test.go +++ b/server/gorm/gorm_test.go @@ -22,8 +22,49 @@ import ( "github.com/apigee/registry/server/models" "github.com/apigee/registry/server/storage" + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/testing/protocmp" ) +func TestFieldClearing(t *testing.T) { + ctx := context.TODO() + + c, err := NewClient(ctx, "sqlite3", "/tmp/testing.db") + if err != nil { + t.Fatalf("NewClient returned error: %s", err) + } + defer c.Close() + c.reset() + + original := &models.Project{ + ProjectID: "my-project", + Description: "My Project", + } + + k := c.NewKey(storage.ProjectEntityName, original.Name()) + if _, err := c.Put(ctx, k, original); err != nil { + t.Fatalf("Setup: Put(%q, %+v) returned error: %s", k, original, err) + } + + update := &models.Project{ + ProjectID: original.ProjectID, + Description: "", + } + + if _, err := c.Put(ctx, k, update); err != nil { + t.Fatalf("Put(%q, %+v) returned error: %s", k, update, err) + } + + got := new(models.Project) + if err := c.Get(ctx, k, got); err != nil { + t.Fatalf("Get(%q) returned error: %s", k, err) + } + + if !cmp.Equal(got, update, protocmp.Transform()) { + t.Errorf("Get(%q) returned unexpected diff (-want +got):\n%s", k, cmp.Diff(update, got, protocmp.Transform())) + } +} + func TestCRUD(t *testing.T) { ctx := context.TODO()