Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: P2P ACP #3317

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions acp/identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/cyware/ssi-sdk/did/key"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/sourcenetwork/immutable"
acptypes "github.com/sourcenetwork/sourcehub/x/acp/bearer_token"
Expand Down Expand Up @@ -127,6 +128,28 @@ func (identity *Identity) UpdateToken(
audience immutable.Option[string],
authorizedAccount immutable.Option[string],
) error {
signedToken, err := identity.NewToken(duration, audience, authorizedAccount)
if err != nil {
return err
}

identity.BearerToken = string(signedToken)
return nil
}

// NewToken creates and returns a new `BearerToken`.
//
// - duration: The [time.Duration] that this identity is valid for.
// - audience: The audience that this identity is valid for. This is required
// by the Defra http client. For example `github.com/sourcenetwork/defradb`
// - authorizedAccount: An account that this identity is authorizing to make
// SourceHub calls on behalf of this actor. This is currently required when
// using SourceHub ACP.
func (identity Identity) NewToken(
duration time.Duration,
audience immutable.Option[string],
authorizedAccount immutable.Option[string],
) ([]byte, error) {
var signedToken []byte
subject := hex.EncodeToString(identity.PublicKey.SerializeCompressed())
now := time.Now()
Expand All @@ -144,21 +167,39 @@ func (identity *Identity) UpdateToken(

token, err := jwtBuilder.Build()
if err != nil {
return err
return nil, err
}

if authorizedAccount.HasValue() {
err = token.Set(acptypes.AuthorizedAccountClaim, authorizedAccount.Value())
if err != nil {
return err
return nil, err
}
}

signedToken, err = jwt.Sign(token, jwt.WithKey(BearerTokenSignatureScheme, identity.PrivateKey.ToECDSA()))
if err != nil {
return nil, err
}

return signedToken, nil
}

// VerifyAuthToken verifies that the jwt auth token is valid and that the signature
// matches the identity of the subject.
func VerifyAuthToken(ident Identity, audience string) error {
_, err := jwt.Parse([]byte(ident.BearerToken), jwt.WithVerify(false), jwt.WithAudience(audience))
if err != nil {
return err
}

_, err = jws.Verify(
[]byte(ident.BearerToken),
jws.WithKey(BearerTokenSignatureScheme, ident.PublicKey.ToECDSA()),
)
if err != nil {
return err
}

identity.BearerToken = string(signedToken)
return nil
}
22 changes: 1 addition & 21 deletions http/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"net/http"
"strings"

"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/sourcenetwork/immutable"

acpIdentity "github.com/sourcenetwork/defradb/acp/identity"
Expand All @@ -30,24 +28,6 @@ const (
authSchemaPrefix = "Bearer "
)

// verifyAuthToken verifies that the jwt auth token is valid and that the signature
// matches the identity of the subject.
func verifyAuthToken(identity acpIdentity.Identity, audience string) error {
_, err := jwt.Parse([]byte(identity.BearerToken), jwt.WithVerify(false), jwt.WithAudience(audience))
if err != nil {
return err
}

_, err = jws.Verify(
[]byte(identity.BearerToken),
jws.WithKey(acpIdentity.BearerTokenSignatureScheme, identity.PublicKey.ToECDSA()),
)
if err != nil {
return err
}
return nil
}

// AuthMiddleware authenticates an actor and sets their identity for all subsequent actions.
func AuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
Expand All @@ -63,7 +43,7 @@ func AuthMiddleware(next http.Handler) http.Handler {
return
}

err = verifyAuthToken(ident, strings.ToLower(req.Host))
err = acpIdentity.VerifyAuthToken(ident, strings.ToLower(req.Host))
if err != nil {
http.Error(rw, "forbidden", http.StatusForbidden)
return
Expand Down
6 changes: 3 additions & 3 deletions http/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestVerifyAuthToken(t *testing.T) {
err = identity.UpdateToken(time.Hour, immutable.Some(audience), immutable.None[string]())
require.NoError(t, err)

err = verifyAuthToken(identity, audience)
err = acpIdentity.VerifyAuthToken(identity, audience)
require.NoError(t, err)
}

Expand All @@ -48,7 +48,7 @@ func TestVerifyAuthTokenErrorsWithNonMatchingAudience(t *testing.T) {
err = identity.UpdateToken(time.Hour, immutable.Some("valid"), immutable.None[string]())
require.NoError(t, err)

err = verifyAuthToken(identity, "invalid")
err = acpIdentity.VerifyAuthToken(identity, "invalid")
assert.Error(t, err)
}

Expand All @@ -65,6 +65,6 @@ func TestVerifyAuthTokenErrorsWithExpired(t *testing.T) {
err = identity.UpdateToken(-time.Hour, immutable.Some(audience), immutable.None[string]())
require.NoError(t, err)

err = verifyAuthToken(identity, "123abc")
err = acpIdentity.VerifyAuthToken(identity, "123abc")
assert.Error(t, err)
}
4 changes: 2 additions & 2 deletions internal/db/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/sourcenetwork/defradb/client/request"
)

func (db *db) basicImport(ctx context.Context, filepath string) (err error) {
func (db *DB) basicImport(ctx context.Context, filepath string) (err error) {
f, err := os.Open(filepath)
if err != nil {
return NewErrOpenFile(err, filepath)
Expand Down Expand Up @@ -115,7 +115,7 @@ func (db *db) basicImport(ctx context.Context, filepath string) (err error) {
return nil
}

func (db *db) basicExport(ctx context.Context, config *client.BackupConfig) (err error) {
func (db *DB) basicExport(ctx context.Context, config *client.BackupConfig) (err error) {
// old key -> new Key
keyChangeCache := map[string]string{}

Expand Down
12 changes: 6 additions & 6 deletions internal/db/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ var _ client.Collection = (*collection)(nil)
// collection stores data records at Documents, which are gathered
// together under a collection name. This is analogous to SQL Tables.
type collection struct {
db *db
db *DB
def client.CollectionDefinition
indexes []CollectionIndex
fetcherFactory func() fetcher.Fetcher
Expand All @@ -55,7 +55,7 @@ type collection struct {
// CollectionOptions object.

// newCollection returns a pointer to a newly instantiated DB Collection
func (db *db) newCollection(desc client.CollectionDescription, schema client.SchemaDescription) *collection {
func (db *DB) newCollection(desc client.CollectionDescription, schema client.SchemaDescription) *collection {
return &collection{
db: db,
def: client.CollectionDefinition{Description: desc, Schema: schema},
Expand All @@ -77,7 +77,7 @@ func (c *collection) newFetcher() fetcher.Fetcher {
return lens.NewFetcher(innerFetcher, c.db.LensRegistry())
}

func (db *db) getCollectionByID(ctx context.Context, id uint32) (client.Collection, error) {
func (db *DB) getCollectionByID(ctx context.Context, id uint32) (client.Collection, error) {
txn := mustGetContextTxn(ctx)

col, err := description.GetCollectionByID(ctx, txn, id)
Expand All @@ -101,7 +101,7 @@ func (db *db) getCollectionByID(ctx context.Context, id uint32) (client.Collecti
}

// getCollectionByName returns an existing collection within the database.
func (db *db) getCollectionByName(ctx context.Context, name string) (client.Collection, error) {
func (db *DB) getCollectionByName(ctx context.Context, name string) (client.Collection, error) {
if name == "" {
return nil, ErrCollectionNameEmpty
}
Expand All @@ -120,7 +120,7 @@ func (db *db) getCollectionByName(ctx context.Context, name string) (client.Coll
//
// Inactive collections are not returned by default unless a specific schema version ID
// is provided.
func (db *db) getCollections(
func (db *DB) getCollections(
ctx context.Context,
options client.CollectionFetchOptions,
) ([]client.Collection, error) {
Expand Down Expand Up @@ -219,7 +219,7 @@ func (db *db) getCollections(
}

// getAllActiveDefinitions returns all queryable collection/views and any embedded schema used by them.
func (db *db) getAllActiveDefinitions(ctx context.Context) ([]client.CollectionDefinition, error) {
func (db *DB) getAllActiveDefinitions(ctx context.Context) ([]client.CollectionDefinition, error) {
txn := mustGetContextTxn(ctx)

cols, err := description.GetActiveCollections(ctx, txn)
Expand Down
10 changes: 5 additions & 5 deletions internal/db/collection_define.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
"github.com/sourcenetwork/defradb/internal/db/description"
)

func (db *db) createCollections(
func (db *DB) createCollections(
ctx context.Context,
newDefinitions []client.CollectionDefinition,
) ([]client.CollectionDefinition, error) {
Expand Down Expand Up @@ -112,7 +112,7 @@ func (db *db) createCollections(
return returnDescriptions, nil
}

func (db *db) patchCollection(
func (db *DB) patchCollection(
ctx context.Context,
patchString string,
) error {
Expand Down Expand Up @@ -224,7 +224,7 @@ func (db *db) patchCollection(
// provided. This includes GQL queries and Collection operations.
//
// It will return an error if the provided schema version ID does not exist.
func (db *db) setActiveSchemaVersion(
func (db *DB) setActiveSchemaVersion(
ctx context.Context,
schemaVersionID string,
) error {
Expand Down Expand Up @@ -311,7 +311,7 @@ func (db *db) setActiveSchemaVersion(
return db.loadSchema(ctx)
}

func (db *db) getActiveCollectionDown(
func (db *DB) getActiveCollectionDown(
ctx context.Context,
colsByID map[uint32]client.CollectionDescription,
id uint32,
Expand All @@ -338,7 +338,7 @@ func (db *db) getActiveCollectionDown(
return db.getActiveCollectionDown(ctx, colsByID, sources[0].SourceCollectionID)
}

func (db *db) getActiveCollectionUp(
func (db *DB) getActiveCollectionUp(
ctx context.Context,
colsBySourceID map[uint32][]client.CollectionDescription,
id uint32,
Expand Down
6 changes: 3 additions & 3 deletions internal/db/collection_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
)

// setCollectionIDs sets the IDs on a collection description, including field IDs, mutating the input set.
func (db *db) setCollectionIDs(ctx context.Context, newCollections []client.CollectionDefinition) error {
func (db *DB) setCollectionIDs(ctx context.Context, newCollections []client.CollectionDefinition) error {
err := db.setCollectionID(ctx, newCollections)
if err != nil {
return err
Expand All @@ -32,7 +32,7 @@ func (db *db) setCollectionIDs(ctx context.Context, newCollections []client.Coll

// setCollectionID sets the IDs directly on a collection description, excluding stuff like field IDs,
// mutating the input set.
func (db *db) setCollectionID(ctx context.Context, newCollections []client.CollectionDefinition) error {
func (db *DB) setCollectionID(ctx context.Context, newCollections []client.CollectionDefinition) error {
colSeq, err := db.getSequence(ctx, keys.CollectionIDSequenceKey{})
if err != nil {
return err
Expand Down Expand Up @@ -64,7 +64,7 @@ func (db *db) setCollectionID(ctx context.Context, newCollections []client.Colle
}

// setFieldIDs sets the field IDs hosted on the given collections, mutating the input set.
func (db *db) setFieldIDs(ctx context.Context, definitions []client.CollectionDefinition) error {
func (db *DB) setFieldIDs(ctx context.Context, definitions []client.CollectionDefinition) error {
collectionsByName := map[string]client.CollectionDescription{}
schemasByName := map[string]client.SchemaDescription{}
for _, def := range definitions {
Expand Down
8 changes: 4 additions & 4 deletions internal/db/collection_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
)

// createCollectionIndex creates a new collection index and saves it to the database in its system store.
func (db *db) createCollectionIndex(
func (db *DB) createCollectionIndex(
ctx context.Context,
collectionName string,
desc client.IndexDescriptionCreateRequest,
Expand All @@ -44,7 +44,7 @@ func (db *db) createCollectionIndex(
return col.CreateIndex(ctx, desc)
}

func (db *db) dropCollectionIndex(
func (db *DB) dropCollectionIndex(
ctx context.Context,
collectionName, indexName string,
) error {
Expand All @@ -56,7 +56,7 @@ func (db *db) dropCollectionIndex(
}

// getAllIndexDescriptions returns all the index descriptions in the database.
func (db *db) getAllIndexDescriptions(
func (db *DB) getAllIndexDescriptions(
ctx context.Context,
) (map[client.CollectionName][]client.IndexDescription, error) {
// callers of this function must set a context transaction
Expand Down Expand Up @@ -92,7 +92,7 @@ func (db *db) getAllIndexDescriptions(
return indexes, nil
}

func (db *db) fetchCollectionIndexDescriptions(
func (db *DB) fetchCollectionIndexDescriptions(
ctx context.Context,
colID uint32,
) ([]client.IndexDescription, error) {
Expand Down
Loading
Loading