Skip to content

Commit 9021b74

Browse files
author
John Campion
committed
Move Attach methods to context
To match EF
1 parent c6ea9a0 commit 9021b74

11 files changed

Lines changed: 298 additions & 280 deletions

src/MongoFramework/IMongoDbContext.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,8 @@ public interface IMongoDbContext
1919

2020
void SaveChanges();
2121
Task SaveChangesAsync(CancellationToken cancellationToken = default(CancellationToken));
22+
23+
void Attach<TEntity>(TEntity entity) where TEntity : class;
24+
void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class;
2225
}
2326
}

src/MongoFramework/IMongoDbSet.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ public interface IMongoDbSet<TEntity> : IMongoDbSet, IQueryable<TEntity> where T
2323
TEntity Create();
2424
void Add(TEntity entity);
2525
void AddRange(IEnumerable<TEntity> entities);
26-
void Attach(TEntity entity);
27-
void AttachRange(IEnumerable<TEntity> entities);
2826
void Update(TEntity entity);
2927
void UpdateRange(IEnumerable<TEntity> entities);
3028
void Remove(TEntity entity);
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
namespace MongoFramework
1+
using System.Collections.Generic;
2+
3+
namespace MongoFramework
24
{
3-
public interface IMongoDbTenantContext : IMongoDbContext
4-
{
5+
public interface IMongoDbTenantContext : IMongoDbContext
6+
{
57
string TenantId { get; }
8+
void CheckEntity(IHaveTenantId entity);
9+
void CheckEntities(IEnumerable<IHaveTenantId> entity);
610
}
711
}

src/MongoFramework/MongoDbContext.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Reflection;
1010
using System.Threading;
1111
using System.Threading.Tasks;
12+
using MongoFramework.Utilities;
1213

1314
namespace MongoFramework
1415
{
@@ -107,6 +108,7 @@ public virtual async Task SaveChangesAsync(CancellationToken cancellationToken =
107108
ChangeTracker.CommitChanges();
108109
CommandStaging.CommitChanges();
109110
}
111+
110112
private static async Task InternalSaveChangesAsync<TEntity>(IMongoDbConnection connection, IEnumerable<IWriteCommand> commands, WriteModelOptions options, CancellationToken cancellationToken) where TEntity : class
111113
{
112114
await EntityIndexWriter.ApplyIndexingAsync<TEntity>(connection);
@@ -131,7 +133,30 @@ public IQueryable<TEntity> Query<TEntity>() where TEntity : class
131133
var provider = new MongoFrameworkQueryProvider<TEntity>(Connection);
132134
return new MongoFrameworkQueryable<TEntity>(provider);
133135
}
136+
137+
/// <summary>
138+
/// Marks the entity as unchanged in the change tracker and starts tracking.
139+
/// </summary>
140+
/// <param name="entity"></param>
141+
public virtual void Attach<TEntity>(TEntity entity) where TEntity : class
142+
{
143+
Check.NotNull(entity, nameof(entity));
144+
ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
145+
}
134146

147+
/// <summary>
148+
/// Marks the collection of entities as unchanged in the change tracker and starts tracking.
149+
/// </summary>
150+
/// <param name="entities"></param>
151+
public virtual void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class
152+
{
153+
Check.NotNull(entities, nameof(entities));
154+
foreach (var entity in entities)
155+
{
156+
ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
157+
}
158+
}
159+
135160
public void Dispose()
136161
{
137162
Dispose(true);

src/MongoFramework/MongoDbSet.cs

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -129,32 +129,8 @@ public virtual void AddRange(IEnumerable<TEntity> entities)
129129
{
130130
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.Added);
131131
}
132-
}
133-
134-
/// <summary>
135-
/// Marks the entity as unchanged in the change tracker and starts tracking.
136-
/// </summary>
137-
/// <param name="entity"></param>
138-
public virtual void Attach(TEntity entity)
139-
{
140-
Check.NotNull(entity, nameof(entity));
141-
142-
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
143132
}
144-
/// <summary>
145-
/// Marks the collection of entities as unchanged in the change tracker and starts tracking.
146-
/// </summary>
147-
/// <param name="entities"></param>
148-
public virtual void AttachRange(IEnumerable<TEntity> entities)
149-
{
150-
Check.NotNull(entities, nameof(entities));
151-
152-
foreach (var entity in entities)
153-
{
154-
Context.ChangeTracker.SetEntityState(entity, EntityEntryState.NoChanges);
155-
}
156-
}
157-
133+
158134
/// <summary>
159135
/// Marks the entity for updating.
160136
/// </summary>

src/MongoFramework/MongoDbTenantContext.cs

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
using MongoFramework.Infrastructure.Commands;
2-
using MongoFramework.Utilities;
3-
using System;
1+
using System.Collections.Generic;
2+
using MongoFramework.Infrastructure;
3+
using MongoFramework.Infrastructure.Commands;
4+
using MongoFramework.Utilities;
45

56
namespace MongoFramework
67
{
@@ -23,5 +24,52 @@ protected override WriteModelOptions GetWriteModelOptions()
2324
{
2425
return new WriteModelOptions { TenantId = TenantId };
2526
}
27+
28+
public virtual void CheckEntity(IHaveTenantId entity)
29+
{
30+
Check.NotNull(entity, nameof(entity));
31+
32+
if (entity.TenantId != TenantId)
33+
{
34+
throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {TenantId}, Entity has: {entity.TenantId}");
35+
}
36+
}
37+
38+
public virtual void CheckEntities(IEnumerable<IHaveTenantId> entities)
39+
{
40+
Check.NotNull(entities, nameof(entities));
41+
42+
foreach (var entity in entities)
43+
{
44+
CheckEntity(entity);
45+
}
46+
}
47+
48+
/// <summary>
49+
/// Marks the entity as unchanged in the change tracker and starts tracking.
50+
/// </summary>
51+
/// <param name="entity"></param>
52+
public override void Attach<TEntity>(TEntity entity) where TEntity : class
53+
{
54+
if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity)))
55+
{
56+
CheckEntity(entity as IHaveTenantId);
57+
}
58+
base.Attach(entity);
59+
}
60+
61+
/// <summary>
62+
/// Marks the collection of entities as unchanged in the change tracker and starts tracking.
63+
/// </summary>
64+
/// <param name="entities"></param>
65+
public override void AttachRange<TEntity>(IEnumerable<TEntity> entities) where TEntity : class
66+
{
67+
if (typeof(IHaveTenantId).IsAssignableFrom(typeof(TEntity)))
68+
{
69+
CheckEntities(entities as IEnumerable<IHaveTenantId>);
70+
}
71+
base.AttachRange(entities);
72+
}
73+
2674
}
2775
}

src/MongoFramework/MongoDbTenantSet.cs

Lines changed: 17 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,8 @@ public MongoDbTenantSet(IMongoDbContext context) : base(context)
2727
{
2828
Context = context as IMongoDbTenantContext ?? throw new ArgumentException("Context provided to a MongoDbTenantSet must be IMongoDbTenantContext",nameof(context));
2929
}
30-
31-
protected virtual void CheckEntity(TEntity entity)
32-
{
33-
Check.NotNull(entity, nameof(entity));
34-
35-
if (entity.TenantId != Context.TenantId)
36-
{
37-
throw new MultiTenantException($"Entity type {entity.GetType().Name}, tenant ID does not match. Expected: {Context.TenantId}, Entity has: {entity.TenantId}");
38-
}
39-
}
40-
41-
protected virtual void CheckEntities(IEnumerable<TEntity> entities)
42-
{
43-
Check.NotNull(entities, nameof(entities));
44-
45-
foreach (var entity in entities)
46-
{
47-
CheckEntity(entity);
48-
}
49-
}
50-
51-
/// <summary>
30+
31+
/// <summary>
5232
/// Finds an entity with the given primary key value. If an entity with the given primary key value
5333
/// is being tracked by the context, then it is returned immediately without making a request to the
5434
/// database. Otherwise, a query is made to the database for an entity with the given primary key value
@@ -138,41 +118,29 @@ public override void AddRange(IEnumerable<TEntity> entities)
138118
entity.TenantId = Context.TenantId;
139119
}
140120
base.AddRange(entities);
141-
}
142-
143-
public override void Attach(TEntity entity)
144-
{
145-
CheckEntity(entity);
146-
base.Attach(entity);
147-
}
148-
149-
public override void AttachRange(IEnumerable<TEntity> entities)
150-
{
151-
CheckEntities(entities);
152-
base.AttachRange(entities);
153-
}
154-
121+
}
122+
155123
public override void Update(TEntity entity)
156124
{
157-
CheckEntity(entity);
125+
Context.CheckEntity(entity);
158126
base.Update(entity);
159127
}
160128

161129
public override void UpdateRange(IEnumerable<TEntity> entities)
162130
{
163-
CheckEntities(entities);
131+
Context.CheckEntities(entities);
164132
base.UpdateRange(entities);
165133
}
166134

167135
public override void Remove(TEntity entity)
168136
{
169-
CheckEntity(entity);
137+
Context.CheckEntity(entity);
170138
base.Remove(entity);
171139
}
172140

173141
public override void RemoveRange(IEnumerable<TEntity> entities)
174142
{
175-
CheckEntities(entities);
143+
Context.CheckEntities(entities);
176144
base.RemoveRange(entities);
177145
}
178146

@@ -189,23 +157,23 @@ protected override IQueryable<TEntity> GetQueryable(bool trackEntities)
189157
{
190158
var key = Context.TenantId;
191159
var queryable = Context.Query<TEntity>().Where(c => c.TenantId == key);
192-
if (trackEntities)
193-
{
194-
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
195-
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
160+
if (trackEntities)
161+
{
162+
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
163+
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
196164
}
197165
return queryable;
198166
}
199167

200168
public IQueryable<TEntity> GetSearchTextQueryable(string search)
201169
{
202170
var key = Context.TenantId;
203-
var queryable = Context.Query<TEntity>().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key);
204-
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
205-
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
171+
var queryable = Context.Query<TEntity>().WhereFilter(b => b.Text(search)).Where(c => c.TenantId == key);
172+
var provider = queryable.Provider as IMongoFrameworkQueryProvider<TEntity>;
173+
provider.EntityProcessors.Add(new EntityTrackingProcessor<TEntity>(Context));
206174
return queryable;
207-
}
208-
175+
}
176+
209177
#endregion
210178
}
211179
}

0 commit comments

Comments
 (0)