package org.jeecg.common.modules.redis.writer; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Function; import lombok.extern.slf4j.Slf4j; import org.springframework.dao.PessimisticLockingFailureException; import org.springframework.data.redis.cache.RedisCacheWriter; import org.springframework.data.redis.connection.RedisConnection; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.connection.RedisStringCommands.SetOption; import org.springframework.data.redis.core.types.Expiration; import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** * 该类参照 DefaultRedisCacheWriter 重写了 remove 方法实现通配符*删除 */ @Slf4j public class JeecgRedisCacheWriter implements RedisCacheWriter { private final RedisConnectionFactory connectionFactory; private final Duration sleepTime; public JeecgRedisCacheWriter(RedisConnectionFactory connectionFactory) { this(connectionFactory, Duration.ZERO); } public JeecgRedisCacheWriter(RedisConnectionFactory connectionFactory, Duration sleepTime) { Assert.notNull(connectionFactory, "ConnectionFactory must not be null!"); Assert.notNull(sleepTime, "SleepTime must not be null!"); this.connectionFactory = connectionFactory; this.sleepTime = sleepTime; } public void put(String name, byte[] key, byte[] value, @Nullable Duration ttl) { Assert.notNull(name, "Name must not be null!"); Assert.notNull(key, "Key must not be null!"); Assert.notNull(value, "Value must not be null!"); this.execute(name, (connection) -> { if (shouldExpireWithin(ttl)) { connection.set(key, value, Expiration.from(ttl.toMillis(), TimeUnit.MILLISECONDS), SetOption.upsert()); } else { connection.set(key, value); } return "OK"; }); } public byte[] get(String name, byte[] key) { Assert.notNull(name, "Name must not be null!"); Assert.notNull(key, "Key must not be null!"); return (byte[])this.execute(name, (connection) -> { return connection.get(key); }); } public byte[] putIfAbsent(String name, byte[] key, byte[] value, @Nullable Duration ttl) { Assert.notNull(name, "Name must not be null!"); Assert.notNull(key, "Key must not be null!"); Assert.notNull(value, "Value must not be null!"); return (byte[])this.execute(name, (connection) -> { if (this.isLockingCacheWriter()) { this.doLock(name, connection); } Object var7; try { boolean put; if (shouldExpireWithin(ttl)) { put = connection.set(key, value, Expiration.from(ttl), SetOption.ifAbsent()); } else { put = connection.setNX(key, value); } if (!put) { byte[] var11 = connection.get(key); return var11; } var7 = null; } finally { if (this.isLockingCacheWriter()) { this.doUnlock(name, connection); } } return (byte[])var7; }); } public void remove(String name, byte[] key) { Assert.notNull(name, "Name must not be null!"); Assert.notNull(key, "Key must not be null!"); String keyString = new String(key); log.info("redis remove key:" + keyString); if(keyString!=null && keyString.endsWith("*")){ execute(name, connection -> { // 获取某个前缀所拥有的所有的键,某个前缀开头,后面肯定是* Set keys = connection.keys(key); int delNum = 0; for (byte[] keyByte : keys) { delNum += connection.del(keyByte); } return delNum; }); }else{ this.execute(name, (connection) -> { return connection.del(new byte[][]{key}); }); } } public void clean(String name, byte[] pattern) { Assert.notNull(name, "Name must not be null!"); Assert.notNull(pattern, "Pattern must not be null!"); this.execute(name, (connection) -> { boolean wasLocked = false; try { if (this.isLockingCacheWriter()) { this.doLock(name, connection); wasLocked = true; } byte[][] keys = (byte[][])((Set)Optional.ofNullable(connection.keys(pattern)).orElse(Collections.emptySet())).toArray(new byte[0][]); if (keys.length > 0) { connection.del(keys); } } finally { if (wasLocked && this.isLockingCacheWriter()) { this.doUnlock(name, connection); } } return "OK"; }); } void lock(String name) { this.execute(name, (connection) -> { return this.doLock(name, connection); }); } void unlock(String name) { this.executeLockFree((connection) -> { this.doUnlock(name, connection); }); } private Boolean doLock(String name, RedisConnection connection) { return connection.setNX(createCacheLockKey(name), new byte[0]); } private Long doUnlock(String name, RedisConnection connection) { return connection.del(new byte[][]{createCacheLockKey(name)}); } boolean doCheckLock(String name, RedisConnection connection) { return connection.exists(createCacheLockKey(name)); } private boolean isLockingCacheWriter() { return !this.sleepTime.isZero() && !this.sleepTime.isNegative(); } private T execute(String name, Function callback) { RedisConnection connection = this.connectionFactory.getConnection(); try { this.checkAndPotentiallyWaitUntilUnlocked(name, connection); return callback.apply(connection); } finally { connection.close(); } } private void executeLockFree(Consumer callback) { RedisConnection connection = this.connectionFactory.getConnection(); try { callback.accept(connection); } finally { connection.close(); } } private void checkAndPotentiallyWaitUntilUnlocked(String name, RedisConnection connection) { if (this.isLockingCacheWriter()) { try { while(this.doCheckLock(name, connection)) { Thread.sleep(this.sleepTime.toMillis()); } } catch (InterruptedException var4) { Thread.currentThread().interrupt(); throw new PessimisticLockingFailureException(String.format("Interrupted while waiting to unlock cache %s", name), var4); } } } private static boolean shouldExpireWithin(@Nullable Duration ttl) { return ttl != null && !ttl.isZero() && !ttl.isNegative(); } private static byte[] createCacheLockKey(String name) { return (name + "~lock").getBytes(StandardCharsets.UTF_8); } }