正在显示
26 个修改的文件
包含
868 行增加
和
117 行删除
lh-common/ruoyi-common-security/src/main/java/com/zhonglai/luhui/security/dto/LoginToken.java
0 → 100644
| 1 | +package com.zhonglai.luhui.security.dto; | ||
| 2 | + | ||
| 3 | +import com.google.gson.JsonObject; | ||
| 4 | +import com.ruoyi.common.utils.DESUtil; | ||
| 5 | +import com.ruoyi.common.utils.DateUtils; | ||
| 6 | +import com.ruoyi.common.utils.GsonConstructor; | ||
| 7 | +import lombok.Data; | ||
| 8 | + | ||
| 9 | +import java.io.Serializable; | ||
| 10 | + | ||
| 11 | +/** | ||
| 12 | + * 登录令牌 | ||
| 13 | + */ | ||
| 14 | +@Data | ||
| 15 | +public class LoginToken implements Serializable { | ||
| 16 | + private static final long serialVersionUID = -8696564127500370479L; | ||
| 17 | + | ||
| 18 | + private Integer userId; //当前用户id | ||
| 19 | + private String userLoginName; //当前用户登录名 | ||
| 20 | + private String userNickName; //当前用户昵称 | ||
| 21 | + private Integer createTime; //生成时间 | ||
| 22 | + private String key = "LiuYuLeXX"; //密钥 | ||
| 23 | + private String userType; //用户类型(0普通用户,1管理员) | ||
| 24 | + | ||
| 25 | + public LoginToken(Integer userId, String userLoginName, String userNickName, String userType) { | ||
| 26 | + this.userId = userId; | ||
| 27 | + this.userLoginName = userLoginName; | ||
| 28 | + this.userNickName = userNickName; | ||
| 29 | + this.userType = userType; | ||
| 30 | + this.createTime = DateUtils.getNowTimeMilly(); | ||
| 31 | + } | ||
| 32 | + | ||
| 33 | + /** | ||
| 34 | + * 解密token | ||
| 35 | + * @param deLoginToken | ||
| 36 | + */ | ||
| 37 | + public LoginToken(String deLoginToken) | ||
| 38 | + { | ||
| 39 | + String loginTokenString = DESUtil.decode(deLoginToken,key); | ||
| 40 | + JsonObject jsonObject = GsonConstructor.get().fromJson(loginTokenString, JsonObject.class); | ||
| 41 | + if(jsonObject.has("userId")) | ||
| 42 | + { | ||
| 43 | + userId = jsonObject.get("userId").getAsInt(); | ||
| 44 | + } | ||
| 45 | + if(jsonObject.has("userLoginName")) | ||
| 46 | + { | ||
| 47 | + userLoginName = jsonObject.get("userLoginName").getAsString(); | ||
| 48 | + } | ||
| 49 | + if(jsonObject.has("userNickName")) | ||
| 50 | + { | ||
| 51 | + userNickName = jsonObject.get("userNickName").getAsString(); | ||
| 52 | + } | ||
| 53 | + if(jsonObject.has("createTime")) | ||
| 54 | + { | ||
| 55 | + createTime = jsonObject.get("createTime").getAsInt(); | ||
| 56 | + } | ||
| 57 | + if(jsonObject.has("userType")) | ||
| 58 | + { | ||
| 59 | + userType = jsonObject.get("userType").getAsString(); | ||
| 60 | + } | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + /** | ||
| 64 | + * 生成加密loginToken | ||
| 65 | + * @return | ||
| 66 | + */ | ||
| 67 | + public String get() | ||
| 68 | + { | ||
| 69 | + return DESUtil.encode(GsonConstructor.get().toJson(this),key); | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | +} |
| @@ -22,24 +22,24 @@ import java.io.IOException; | @@ -22,24 +22,24 @@ import java.io.IOException; | ||
| 22 | * | 22 | * |
| 23 | * @author ruoyi | 23 | * @author ruoyi |
| 24 | */ | 24 | */ |
| 25 | -@Component | ||
| 26 | -public class JwtAuthenticationTokenFilter extends OncePerRequestFilter | 25 | +public abstract class JwtAuthenticationTokenFilter extends OncePerRequestFilter |
| 27 | { | 26 | { |
| 28 | - @Autowired | ||
| 29 | - private TokenService tokenService; | ||
| 30 | - | ||
| 31 | @Override | 27 | @Override |
| 32 | protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) | 28 | protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain) |
| 33 | throws ServletException, IOException | 29 | throws ServletException, IOException |
| 34 | { | 30 | { |
| 35 | - BaseLoginUser loginUser = tokenService.getLoginUser(request); | ||
| 36 | - if (StringUtils.isNotNull(loginUser) && StringUtils.isNull(SecurityUtils.getAuthentication())) | 31 | + BaseLoginUser loginUser = getBaseLoginUser(request); |
| 32 | + if(verifyToken(loginUser)) | ||
| 37 | { | 33 | { |
| 38 | - tokenService.verifyToken(loginUser); | ||
| 39 | - UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken(loginUser, null, loginUser.getAuthorities()); | 34 | + UsernamePasswordAuthenticationToken authenticationToken = getUsernamePasswordAuthenticationToken(loginUser); |
| 40 | authenticationToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); | 35 | authenticationToken.setDetails(new WebAuthenticationDetailsSource().buildDetails(request)); |
| 41 | SecurityContextHolder.getContext().setAuthentication(authenticationToken); | 36 | SecurityContextHolder.getContext().setAuthentication(authenticationToken); |
| 37 | + chain.doFilter(request, response); | ||
| 42 | } | 38 | } |
| 43 | chain.doFilter(request, response); | 39 | chain.doFilter(request, response); |
| 44 | } | 40 | } |
| 41 | + public abstract BaseLoginUser getBaseLoginUser(HttpServletRequest request); | ||
| 42 | + public abstract boolean verifyToken(BaseLoginUser loginUser); | ||
| 43 | + | ||
| 44 | + public abstract UsernamePasswordAuthenticationToken getUsernamePasswordAuthenticationToken( BaseLoginUser loginUser); | ||
| 45 | } | 45 | } |
| 1 | package com.zhonglai.luhui.security.service; | 1 | package com.zhonglai.luhui.security.service; |
| 2 | 2 | ||
| 3 | -import com.ruoyi.common.tool.SysLogininforType; | ||
| 4 | import com.zhonglai.luhui.security.filter.JwtAuthenticationTokenFilter; | 3 | import com.zhonglai.luhui.security.filter.JwtAuthenticationTokenFilter; |
| 5 | import com.zhonglai.luhui.security.handle.AuthenticationEntryPointImpl; | 4 | import com.zhonglai.luhui.security.handle.AuthenticationEntryPointImpl; |
| 6 | import com.zhonglai.luhui.security.handle.LogoutSuccessHandlerImpl; | 5 | import com.zhonglai.luhui.security.handle.LogoutSuccessHandlerImpl; |
| @@ -10,16 +9,12 @@ import org.springframework.context.annotation.Bean; | @@ -10,16 +9,12 @@ import org.springframework.context.annotation.Bean; | ||
| 10 | import org.springframework.http.HttpMethod; | 9 | import org.springframework.http.HttpMethod; |
| 11 | import org.springframework.security.config.annotation.web.builders.HttpSecurity; | 10 | import org.springframework.security.config.annotation.web.builders.HttpSecurity; |
| 12 | import org.springframework.security.config.http.SessionCreationPolicy; | 11 | import org.springframework.security.config.http.SessionCreationPolicy; |
| 13 | -import org.springframework.security.core.context.SecurityContextHolder; | ||
| 14 | -import org.springframework.security.core.userdetails.UserDetailsService; | ||
| 15 | import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; | 12 | import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; |
| 16 | import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; | 13 | import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; |
| 17 | import org.springframework.security.web.authentication.logout.LogoutFilter; | 14 | import org.springframework.security.web.authentication.logout.LogoutFilter; |
| 18 | import org.springframework.stereotype.Service; | 15 | import org.springframework.stereotype.Service; |
| 19 | import org.springframework.web.filter.CorsFilter; | 16 | import org.springframework.web.filter.CorsFilter; |
| 20 | 17 | ||
| 21 | -import javax.servlet.*; | ||
| 22 | -import java.io.IOException; | ||
| 23 | 18 | ||
| 24 | @Service | 19 | @Service |
| 25 | public class SecurityConfigService { | 20 | public class SecurityConfigService { |
| @@ -41,7 +36,6 @@ public class SecurityConfigService { | @@ -41,7 +36,6 @@ public class SecurityConfigService { | ||
| 41 | @Autowired | 36 | @Autowired |
| 42 | private CorsFilter corsFilter; | 37 | private CorsFilter corsFilter; |
| 43 | 38 | ||
| 44 | - | ||
| 45 | /** | 39 | /** |
| 46 | * 退出处理类 | 40 | * 退出处理类 |
| 47 | */ | 41 | */ |
| @@ -60,7 +54,7 @@ public class SecurityConfigService { | @@ -60,7 +54,7 @@ public class SecurityConfigService { | ||
| 60 | // 添加JWT filter | 54 | // 添加JWT filter |
| 61 | httpSecurity.addFilterBefore(authenticationTokenFilter, UsernamePasswordAuthenticationFilter.class); | 55 | httpSecurity.addFilterBefore(authenticationTokenFilter, UsernamePasswordAuthenticationFilter.class); |
| 62 | // 添加CORS filter | 56 | // 添加CORS filter |
| 63 | - httpSecurity.addFilterBefore(corsFilter, JwtAuthenticationTokenFilter.class); | 57 | + httpSecurity.addFilterBefore(corsFilter, authenticationTokenFilter.getClass()); |
| 64 | httpSecurity.addFilterBefore(corsFilter, LogoutFilter.class); | 58 | httpSecurity.addFilterBefore(corsFilter, LogoutFilter.class); |
| 65 | } | 59 | } |
| 66 | 60 |
| @@ -12,6 +12,7 @@ import org.springframework.stereotype.Controller; | @@ -12,6 +12,7 @@ import org.springframework.stereotype.Controller; | ||
| 12 | import org.springframework.web.bind.annotation.*; | 12 | import org.springframework.web.bind.annotation.*; |
| 13 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | 13 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| 14 | 14 | ||
| 15 | +import javax.servlet.http.HttpServletRequest; | ||
| 15 | import javax.servlet.http.HttpServletResponse; | 16 | import javax.servlet.http.HttpServletResponse; |
| 16 | import java.util.Map; | 17 | import java.util.Map; |
| 17 | 18 | ||
| @@ -53,9 +54,9 @@ public class ChatController { | @@ -53,9 +54,9 @@ public class ChatController { | ||
| 53 | @CrossOrigin | 54 | @CrossOrigin |
| 54 | @PostMapping("/chat") | 55 | @PostMapping("/chat") |
| 55 | @ResponseBody | 56 | @ResponseBody |
| 56 | - public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) { | 57 | + public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletRequest request) { |
| 57 | String uid = getUid(headers); | 58 | String uid = getUid(headers); |
| 58 | - return sseService.sseChat(true,0,uid, chatRequest, ChatCompletion.Model.GPT_3_5_TURBO_0301,null); | 59 | + return sseService.sseChat(true,0,uid, chatRequest, ChatCompletion.Model.GPT_3_5_TURBO_0301,null,request); |
| 59 | } | 60 | } |
| 60 | 61 | ||
| 61 | /** | 62 | /** |
| @@ -3,6 +3,7 @@ package com.zhonglai.luhui.chatgpt.listener; | @@ -3,6 +3,7 @@ package com.zhonglai.luhui.chatgpt.listener; | ||
| 3 | import com.fasterxml.jackson.databind.ObjectMapper; | 3 | import com.fasterxml.jackson.databind.ObjectMapper; |
| 4 | import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; | 4 | import com.unfbx.chatgpt.entity.chat.ChatCompletionResponse; |
| 5 | import com.unfbx.chatgpt.entity.chat.Message; | 5 | import com.unfbx.chatgpt.entity.chat.Message; |
| 6 | +import com.zhonglai.luhui.chatgpt.entity.MyChatCompletionResponse; | ||
| 6 | import com.zhonglai.luhui.chatgpt.event.MyEvent; | 7 | import com.zhonglai.luhui.chatgpt.event.MyEvent; |
| 7 | import com.zhonglai.luhui.chatgpt.service.CompleteCallback; | 8 | import com.zhonglai.luhui.chatgpt.service.CompleteCallback; |
| 8 | import lombok.SneakyThrows; | 9 | import lombok.SneakyThrows; |
| @@ -14,6 +15,7 @@ import okhttp3.sse.EventSourceListener; | @@ -14,6 +15,7 @@ import okhttp3.sse.EventSourceListener; | ||
| 14 | import org.springframework.http.MediaType; | 15 | import org.springframework.http.MediaType; |
| 15 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | 16 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| 16 | 17 | ||
| 18 | +import javax.servlet.http.HttpServletRequest; | ||
| 17 | import java.net.URLEncoder; | 19 | import java.net.URLEncoder; |
| 18 | import java.util.Objects; | 20 | import java.util.Objects; |
| 19 | 21 | ||
| @@ -34,12 +36,15 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | @@ -34,12 +36,15 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | ||
| 34 | 36 | ||
| 35 | private CompleteCallback completeCallback; | 37 | private CompleteCallback completeCallback; |
| 36 | 38 | ||
| 39 | + private HttpServletRequest request; | ||
| 40 | + | ||
| 37 | private boolean isHaveData; | 41 | private boolean isHaveData; |
| 38 | private int recordId; | 42 | private int recordId; |
| 39 | - public OpenAISSEEventSourceListener(SseEmitter sseEmitter, CompleteCallback completeCallback,int recordId) { | 43 | + public OpenAISSEEventSourceListener(SseEmitter sseEmitter, CompleteCallback completeCallback,int recordId,HttpServletRequest request) { |
| 40 | this.sseEmitter = sseEmitter; | 44 | this.sseEmitter = sseEmitter; |
| 41 | this.completeCallback = completeCallback; | 45 | this.completeCallback = completeCallback; |
| 42 | this.recordId = recordId; | 46 | this.recordId = recordId; |
| 47 | + this.request = request; | ||
| 43 | } | 48 | } |
| 44 | 49 | ||
| 45 | /** | 50 | /** |
| @@ -77,23 +82,16 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | @@ -77,23 +82,16 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | ||
| 77 | return; | 82 | return; |
| 78 | } | 83 | } |
| 79 | ObjectMapper mapper = new ObjectMapper(); | 84 | ObjectMapper mapper = new ObjectMapper(); |
| 80 | - ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); // 读取Json | 85 | + MyChatCompletionResponse completionResponse = mapper.readValue(data, MyChatCompletionResponse.class); // 读取Json |
| 81 | 86 | ||
| 87 | + String content="无法获取返回内容!"; | ||
| 82 | try { | 88 | try { |
| 89 | + if(null !=completionResponse.getWarning() && !completionResponse.getWarning().equals("")) | ||
| 90 | + { | ||
| 83 | Message delta = completionResponse.getChoices().get(0).getDelta(); | 91 | Message delta = completionResponse.getChoices().get(0).getDelta(); |
| 84 | if(null != delta.getContent()) | 92 | if(null != delta.getContent()) |
| 85 | { | 93 | { |
| 86 | - String content = delta.getContent(); | ||
| 87 | - if(isHaveData) | ||
| 88 | - { | ||
| 89 | -// if(content.startsWith("\n") || content.endsWith("\n") ) | ||
| 90 | -// { | ||
| 91 | -// content = content.replace("\n","~n~"); | ||
| 92 | -// } | ||
| 93 | - sseEmitter.send(URLEncoder.encode(content,"utf-8").replaceAll("\\+", "%20"), MediaType.TEXT_EVENT_STREAM); | ||
| 94 | - }else{ | ||
| 95 | - sseEmitter.send(new MyEvent().data(content, MediaType.TEXT_EVENT_STREAM)); | ||
| 96 | - } | 94 | + content = delta.getContent(); |
| 97 | contents.append(content); | 95 | contents.append(content); |
| 98 | } | 96 | } |
| 99 | // sseEmitter.send(SseEmitter.event() | 97 | // sseEmitter.send(SseEmitter.event() |
| @@ -101,11 +99,23 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | @@ -101,11 +99,23 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | ||
| 101 | // .data(delta.getContent()) | 99 | // .data(delta.getContent()) |
| 102 | // .reconnectTime(3000) | 100 | // .reconnectTime(3000) |
| 103 | // ); | 101 | // ); |
| 102 | + }else{ | ||
| 103 | + content = completionResponse.getWarning(); | ||
| 104 | + } | ||
| 105 | + | ||
| 106 | + if(isHaveData) | ||
| 107 | + { | ||
| 108 | + sseEmitter.send(URLEncoder.encode(content,"utf-8").replaceAll("\\+", "%20"), MediaType.TEXT_EVENT_STREAM); | ||
| 109 | + }else{ | ||
| 110 | + sseEmitter.send(new MyEvent().data(content, MediaType.TEXT_EVENT_STREAM)); | ||
| 111 | + } | ||
| 112 | + | ||
| 104 | } catch (Exception e) { | 113 | } catch (Exception e) { |
| 105 | log.error("sse信息推送失败!"); | 114 | log.error("sse信息推送失败!"); |
| 106 | eventSource.cancel(); | 115 | eventSource.cancel(); |
| 107 | e.printStackTrace(); | 116 | e.printStackTrace(); |
| 108 | } | 117 | } |
| 118 | + | ||
| 109 | } | 119 | } |
| 110 | 120 | ||
| 111 | 121 | ||
| @@ -115,7 +125,7 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | @@ -115,7 +125,7 @@ public class OpenAISSEEventSourceListener extends EventSourceListener { | ||
| 115 | log.info("OpenAI关闭sse连接..."); | 125 | log.info("OpenAI关闭sse连接..."); |
| 116 | if(null != completeCallback) | 126 | if(null != completeCallback) |
| 117 | { | 127 | { |
| 118 | - completeCallback.sseChatEnd(recordId,tokens,contents.toString()); | 128 | + completeCallback.sseChatEnd(recordId,tokens,contents.toString(),request); |
| 119 | } | 129 | } |
| 120 | } | 130 | } |
| 121 | 131 |
| @@ -4,7 +4,9 @@ import com.unfbx.chatgpt.entity.chat.ChatCompletion; | @@ -4,7 +4,9 @@ import com.unfbx.chatgpt.entity.chat.ChatCompletion; | ||
| 4 | import com.zhonglai.luhui.chatgpt.controller.request.ChatRequest; | 4 | import com.zhonglai.luhui.chatgpt.controller.request.ChatRequest; |
| 5 | import com.zhonglai.luhui.chatgpt.listener.OpenAISSEEventSourceListener; | 5 | import com.zhonglai.luhui.chatgpt.listener.OpenAISSEEventSourceListener; |
| 6 | 6 | ||
| 7 | +import javax.servlet.http.HttpServletRequest; | ||
| 8 | + | ||
| 7 | public interface CompleteCallback { | 9 | public interface CompleteCallback { |
| 8 | - void sseChatEnd(int recordId,long tokens,String contents); | 10 | + void sseChatEnd(int recordId, long tokens, String contents, HttpServletRequest httpServletRequest); |
| 9 | int recordSseChat(Integer user_id, ChatRequest chatRequest, ChatCompletion chatCompletion); | 11 | int recordSseChat(Integer user_id, ChatRequest chatRequest, ChatCompletion chatCompletion); |
| 10 | } | 12 | } |
| @@ -5,6 +5,8 @@ import com.zhonglai.luhui.chatgpt.controller.request.ChatRequest; | @@ -5,6 +5,8 @@ import com.zhonglai.luhui.chatgpt.controller.request.ChatRequest; | ||
| 5 | import com.zhonglai.luhui.chatgpt.controller.response.ChatResponse; | 5 | import com.zhonglai.luhui.chatgpt.controller.response.ChatResponse; |
| 6 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | 6 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| 7 | 7 | ||
| 8 | +import javax.servlet.http.HttpServletRequest; | ||
| 9 | + | ||
| 8 | /** | 10 | /** |
| 9 | * 描述: | 11 | * 描述: |
| 10 | * | 12 | * |
| @@ -30,5 +32,5 @@ public interface SseService { | @@ -30,5 +32,5 @@ public interface SseService { | ||
| 30 | * @param uid | 32 | * @param uid |
| 31 | * @param chatRequest | 33 | * @param chatRequest |
| 32 | */ | 34 | */ |
| 33 | - ChatResponse sseChat(Boolean isHaveData,Integer user_id,String uid, ChatRequest chatRequest, ChatCompletion.Model model, CompleteCallback completeCallback); | 35 | + ChatResponse sseChat(Boolean isHaveData, Integer user_id, String uid, ChatRequest chatRequest, ChatCompletion.Model model, CompleteCallback completeCallback, HttpServletRequest request); |
| 34 | } | 36 | } |
| @@ -17,6 +17,7 @@ import lombok.extern.slf4j.Slf4j; | @@ -17,6 +17,7 @@ import lombok.extern.slf4j.Slf4j; | ||
| 17 | import org.springframework.stereotype.Service; | 17 | import org.springframework.stereotype.Service; |
| 18 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | 18 | import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; |
| 19 | 19 | ||
| 20 | +import javax.servlet.http.HttpServletRequest; | ||
| 20 | import java.io.IOException; | 21 | import java.io.IOException; |
| 21 | import java.util.ArrayList; | 22 | import java.util.ArrayList; |
| 22 | import java.util.List; | 23 | import java.util.List; |
| @@ -83,7 +84,7 @@ public class SseServiceImpl implements SseService { | @@ -83,7 +84,7 @@ public class SseServiceImpl implements SseService { | ||
| 83 | } | 84 | } |
| 84 | 85 | ||
| 85 | @Override | 86 | @Override |
| 86 | - public ChatResponse sseChat(Boolean isHaveData,Integer user_id,String uid, ChatRequest chatRequest, ChatCompletion.Model model, CompleteCallback completeCallback) { | 87 | + public ChatResponse sseChat(Boolean isHaveData, Integer user_id, String uid, ChatRequest chatRequest, ChatCompletion.Model model, CompleteCallback completeCallback, HttpServletRequest request) { |
| 87 | if (ArrayUtil.isEmpty(chatRequest.getMsg())) { | 88 | if (ArrayUtil.isEmpty(chatRequest.getMsg())) { |
| 88 | log.info("参数异常,msg为null", uid); | 89 | log.info("参数异常,msg为null", uid); |
| 89 | throw new BaseException("参数异常,msg不能为空~"); | 90 | throw new BaseException("参数异常,msg不能为空~"); |
| @@ -130,7 +131,7 @@ public class SseServiceImpl implements SseService { | @@ -130,7 +131,7 @@ public class SseServiceImpl implements SseService { | ||
| 130 | recordId = completeCallback.recordSseChat(user_id,chatRequest,completion); | 131 | recordId = completeCallback.recordSseChat(user_id,chatRequest,completion); |
| 131 | } | 132 | } |
| 132 | 133 | ||
| 133 | - OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter,completeCallback,recordId); | 134 | + OpenAISSEEventSourceListener openAIEventSourceListener = new OpenAISSEEventSourceListener(sseEmitter,completeCallback,recordId,request); |
| 134 | openAIEventSourceListener.setHaveData(isHaveData); | 135 | openAIEventSourceListener.setHaveData(isHaveData); |
| 135 | openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); | 136 | openAiStreamClient.streamChatCompletion(completion, openAIEventSourceListener); |
| 136 | // LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); | 137 | // LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); |
lh-modules/lh-login/src/main/java/com/zhonglai/luhui/login/service/LocalLoginService.java
0 → 100644
| 1 | +package com.zhonglai.luhui.login.service; | ||
| 2 | + | ||
| 3 | +import com.ruoyi.common.constant.Constants; | ||
| 4 | +import com.ruoyi.common.tool.SysLogininforType; | ||
| 5 | +import com.ruoyi.common.utils.MessageUtils; | ||
| 6 | +import com.ruoyi.common.utils.spring.SpringUtils; | ||
| 7 | +import com.zhonglai.luhui.security.dto.LoginToken; | ||
| 8 | +import com.zhonglai.luhui.security.dto.OpenAiLoginUser; | ||
| 9 | +import com.zhonglai.luhui.security.dto.OpenAiUserInfo; | ||
| 10 | +import com.zhonglai.luhui.sys.manager.AsyncManager; | ||
| 11 | +import com.zhonglai.luhui.sys.manager.factory.AsyncFactory; | ||
| 12 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 13 | +import org.springframework.security.core.Authentication; | ||
| 14 | +import org.springframework.stereotype.Service; | ||
| 15 | + | ||
| 16 | +@Service | ||
| 17 | +public class LocalLoginService { | ||
| 18 | + @Autowired | ||
| 19 | + private LoginService loginService; | ||
| 20 | + | ||
| 21 | + public String openaiLoginByPass(String user,String pass) { | ||
| 22 | + // 用户验证 | ||
| 23 | + Authentication authentication = loginService.userPasswordVerification(user,pass, SpringUtils.getBean("openAiConfigurerAdapter")); | ||
| 24 | + AsyncManager.me().execute(AsyncFactory.recordLogininfor(user, SysLogininforType.openAi, Constants.LOGIN_SUCCESS, MessageUtils.message("user.login.success"))); | ||
| 25 | + OpenAiLoginUser loginUser = (OpenAiLoginUser) authentication.getPrincipal(); | ||
| 26 | + loginUser.setSysLogininforType(SysLogininforType.openAi); | ||
| 27 | + return createToken(loginUser); | ||
| 28 | + } | ||
| 29 | + | ||
| 30 | + public String createToken(OpenAiLoginUser loginUser) | ||
| 31 | + { | ||
| 32 | + OpenAiUserInfo openAiUserInfo = loginUser.getOpenAiUserInfo(); | ||
| 33 | + LoginToken loginToken = new LoginToken(openAiUserInfo.getId(),openAiUserInfo.getPhone(),openAiUserInfo.getNickname(),loginUser.getSysLogininforType().name()); | ||
| 34 | + return loginToken.get(); | ||
| 35 | + } | ||
| 36 | +} |
| @@ -151,7 +151,7 @@ public class LoginService { | @@ -151,7 +151,7 @@ public class LoginService { | ||
| 151 | } | 151 | } |
| 152 | } | 152 | } |
| 153 | 153 | ||
| 154 | - private Authentication userPasswordVerification(String username, String password, DefaultSecurityConfig defaultSecurityConfig) | 154 | + public Authentication userPasswordVerification(String username, String password, DefaultSecurityConfig defaultSecurityConfig) |
| 155 | { | 155 | { |
| 156 | Authentication authentication = null; | 156 | Authentication authentication = null; |
| 157 | try | 157 | try |
| 1 | package com.zhonglai.luhui.openai; | 1 | package com.zhonglai.luhui.openai; |
| 2 | 2 | ||
| 3 | import com.ruoyi.common.utils.StringUtils; | 3 | import com.ruoyi.common.utils.StringUtils; |
| 4 | +import com.ruoyi.framework.config.ResourcesConfig; | ||
| 5 | +import com.zhonglai.luhui.security.filter.JwtAuthenticationTokenFilter; | ||
| 4 | import okhttp3.OkHttpClient; | 6 | import okhttp3.OkHttpClient; |
| 5 | import org.apache.tomcat.util.http.LegacyCookieProcessor; | 7 | import org.apache.tomcat.util.http.LegacyCookieProcessor; |
| 6 | import org.springframework.boot.SpringApplication; | 8 | import org.springframework.boot.SpringApplication; |
| @@ -11,6 +13,7 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor | @@ -11,6 +13,7 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor | ||
| 11 | import org.springframework.boot.web.server.WebServerFactoryCustomizer; | 13 | import org.springframework.boot.web.server.WebServerFactoryCustomizer; |
| 12 | import org.springframework.context.annotation.Bean; | 14 | import org.springframework.context.annotation.Bean; |
| 13 | import org.springframework.context.annotation.ComponentScan; | 15 | import org.springframework.context.annotation.ComponentScan; |
| 16 | +import org.springframework.context.annotation.FilterType; | ||
| 14 | 17 | ||
| 15 | import java.io.UnsupportedEncodingException; | 18 | import java.io.UnsupportedEncodingException; |
| 16 | import java.net.URL; | 19 | import java.net.URL; |
| 1 | +package com.zhonglai.luhui.openai.config; | ||
| 2 | + | ||
| 3 | +import com.ruoyi.common.tool.SysLogininforType; | ||
| 4 | +import com.ruoyi.common.utils.DateUtils; | ||
| 5 | +import com.ruoyi.common.utils.StringUtils; | ||
| 6 | +import com.zhonglai.luhui.dao.service.PublicService; | ||
| 7 | +import com.zhonglai.luhui.security.dto.BaseLoginUser; | ||
| 8 | +import com.zhonglai.luhui.security.dto.LoginToken; | ||
| 9 | +import com.zhonglai.luhui.security.dto.OpenAiLoginUser; | ||
| 10 | +import com.zhonglai.luhui.security.dto.OpenAiUserInfo; | ||
| 11 | +import com.zhonglai.luhui.security.filter.JwtAuthenticationTokenFilter; | ||
| 12 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 13 | +import org.springframework.beans.factory.annotation.Value; | ||
| 14 | +import org.springframework.context.annotation.Configuration; | ||
| 15 | +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; | ||
| 16 | +import org.springframework.stereotype.Component; | ||
| 17 | + | ||
| 18 | +import javax.servlet.http.HttpServletRequest; | ||
| 19 | + | ||
| 20 | +@Component | ||
| 21 | +public class JwtAuthenticationTokenFilterImpl extends JwtAuthenticationTokenFilter { | ||
| 22 | + @Value("${token.header}") | ||
| 23 | + private String header; | ||
| 24 | + @Value("${token.expireTime}") | ||
| 25 | + private Long expireTime; | ||
| 26 | + | ||
| 27 | + @Autowired | ||
| 28 | + private PublicService publicService; | ||
| 29 | + | ||
| 30 | + @Override | ||
| 31 | + public BaseLoginUser getBaseLoginUser(HttpServletRequest request) { | ||
| 32 | + String token = request.getHeader(header); | ||
| 33 | + if (StringUtils.isNotEmpty(token)) | ||
| 34 | + { | ||
| 35 | + try { | ||
| 36 | + LoginToken loginToken = new LoginToken(token); | ||
| 37 | + BaseLoginUser baseLoginUser = to(loginToken); | ||
| 38 | + baseLoginUser.setExpireTime(loginToken.getCreateTime()+expireTime); | ||
| 39 | + return baseLoginUser; | ||
| 40 | + }catch (Exception e) | ||
| 41 | + { | ||
| 42 | + logger.error("token验证失败",e); | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + return null; | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + return null; | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + @Override | ||
| 52 | + public boolean verifyToken(BaseLoginUser baseLoginUser) { | ||
| 53 | + if(null != baseLoginUser) | ||
| 54 | + { | ||
| 55 | + Integer currentTime = DateUtils.getNowTimeMilly(); | ||
| 56 | + Long expireTime = baseLoginUser.getExpireTime(); | ||
| 57 | + if (expireTime - currentTime > 0) | ||
| 58 | + { | ||
| 59 | + return true; | ||
| 60 | + } | ||
| 61 | + } | ||
| 62 | + return false; | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + @Override | ||
| 66 | + public UsernamePasswordAuthenticationToken getUsernamePasswordAuthenticationToken(BaseLoginUser loginUser) { | ||
| 67 | + return new UsernamePasswordAuthenticationToken(loginUser, null, null); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | + public BaseLoginUser to(LoginToken loginToken) { | ||
| 71 | + OpenAiUserInfo openAiUserInfo = publicService.getObjectForTableName(OpenAiUserInfo.class,"id", loginToken.getUserId()+"","`lk_openai`.`user_info`"); | ||
| 72 | + OpenAiLoginUser openAiLoginUser = new OpenAiLoginUser(); | ||
| 73 | + openAiLoginUser.setUserId(Long.parseLong(openAiUserInfo.getId()+"")); | ||
| 74 | + openAiLoginUser.setOpenAiUserInfo(openAiUserInfo); | ||
| 75 | + openAiLoginUser.setSysLogininforType(SysLogininforType.openAi); | ||
| 76 | + return openAiLoginUser; | ||
| 77 | + } | ||
| 78 | +} |
| @@ -31,6 +31,7 @@ public class OpenAiConfigurerAdapter extends DefaultSecurityConfig { | @@ -31,6 +31,7 @@ public class OpenAiConfigurerAdapter extends DefaultSecurityConfig { | ||
| 31 | @Override | 31 | @Override |
| 32 | public void configHttpSecurity(HttpSecurity httpSecurity) throws Exception { | 32 | public void configHttpSecurity(HttpSecurity httpSecurity) throws Exception { |
| 33 | securityConfigService.configHttpSecurity(httpSecurity); | 33 | securityConfigService.configHttpSecurity(httpSecurity); |
| 34 | + | ||
| 34 | } | 35 | } |
| 35 | 36 | ||
| 36 | @Override | 37 | @Override |
| @@ -12,6 +12,7 @@ import com.ruoyi.common.utils.GsonConstructor; | @@ -12,6 +12,7 @@ import com.ruoyi.common.utils.GsonConstructor; | ||
| 12 | import com.ruoyi.common.utils.StringUtils; | 12 | import com.ruoyi.common.utils.StringUtils; |
| 13 | import com.zhonglai.luhui.action.BaseController; | 13 | import com.zhonglai.luhui.action.BaseController; |
| 14 | import com.zhonglai.luhui.dao.service.PublicService; | 14 | import com.zhonglai.luhui.dao.service.PublicService; |
| 15 | +import com.zhonglai.luhui.login.service.LocalLoginService; | ||
| 15 | import com.zhonglai.luhui.login.service.LoginService; | 16 | import com.zhonglai.luhui.login.service.LoginService; |
| 16 | import com.zhonglai.luhui.security.dto.OpenAiUserInfo; | 17 | import com.zhonglai.luhui.security.dto.OpenAiUserInfo; |
| 17 | import io.swagger.annotations.Api; | 18 | import io.swagger.annotations.Api; |
| @@ -34,7 +35,7 @@ public class OpenAiUserLoginController extends BaseController { | @@ -34,7 +35,7 @@ public class OpenAiUserLoginController extends BaseController { | ||
| 34 | 35 | ||
| 35 | public static String ENCODE_KEY = "com/zhonglai"; | 36 | public static String ENCODE_KEY = "com/zhonglai"; |
| 36 | @Autowired | 37 | @Autowired |
| 37 | - private LoginService loginService; | 38 | + private LocalLoginService loginService; |
| 38 | 39 | ||
| 39 | @Autowired | 40 | @Autowired |
| 40 | private PublicService publicService; | 41 | private PublicService publicService; |
| @@ -2,17 +2,16 @@ package com.zhonglai.luhui.openai.controller; | @@ -2,17 +2,16 @@ package com.zhonglai.luhui.openai.controller; | ||
| 2 | 2 | ||
| 3 | import com.ruoyi.common.core.domain.AjaxResult; | 3 | import com.ruoyi.common.core.domain.AjaxResult; |
| 4 | import com.ruoyi.common.core.page.TableDataInfo; | 4 | import com.ruoyi.common.core.page.TableDataInfo; |
| 5 | +import com.ruoyi.common.utils.DateUtils; | ||
| 5 | import com.zhonglai.luhui.action.BaseController; | 6 | import com.zhonglai.luhui.action.BaseController; |
| 6 | import com.zhonglai.luhui.dao.service.PublicService; | 7 | import com.zhonglai.luhui.dao.service.PublicService; |
| 8 | +import com.zhonglai.luhui.openai.dto.UserRoom; | ||
| 7 | import com.zhonglai.luhui.security.dto.OpenAiUserInfo; | 9 | import com.zhonglai.luhui.security.dto.OpenAiUserInfo; |
| 8 | import com.zhonglai.luhui.security.utils.SecurityUtils; | 10 | import com.zhonglai.luhui.security.utils.SecurityUtils; |
| 9 | import io.swagger.annotations.Api; | 11 | import io.swagger.annotations.Api; |
| 10 | import io.swagger.annotations.ApiOperation; | 12 | import io.swagger.annotations.ApiOperation; |
| 11 | import org.springframework.beans.factory.annotation.Autowired; | 13 | import org.springframework.beans.factory.annotation.Autowired; |
| 12 | -import org.springframework.web.bind.annotation.GetMapping; | ||
| 13 | -import org.springframework.web.bind.annotation.PostMapping; | ||
| 14 | -import org.springframework.web.bind.annotation.RequestMapping; | ||
| 15 | -import org.springframework.web.bind.annotation.RestController; | 14 | +import org.springframework.web.bind.annotation.*; |
| 16 | 15 | ||
| 17 | import java.util.List; | 16 | import java.util.List; |
| 18 | import java.util.Map; | 17 | import java.util.Map; |
| @@ -29,15 +28,29 @@ public class UserRoomController extends BaseController { | @@ -29,15 +28,29 @@ public class UserRoomController extends BaseController { | ||
| 29 | public TableDataInfo getUserRoomList() | 28 | public TableDataInfo getUserRoomList() |
| 30 | { | 29 | { |
| 31 | startPage(); | 30 | startPage(); |
| 32 | - List<Map<String,Object>> list=publicService.getObjectListBySQL("SELECT * FROM `user_room` WHERE user_id="+SecurityUtils.getUserId().toString()+" AND is_delete=1"); | 31 | + List<Map<String,Object>> list=publicService.getObjectListBySQL("SELECT * FROM `lk_openai`.`user_room` WHERE user_id="+SecurityUtils.getUserId().toString()+" AND is_delete=1"); |
| 33 | return getDataTable(list); | 32 | return getDataTable(list); |
| 34 | } | 33 | } |
| 35 | 34 | ||
| 36 | -// @ApiOperation("添加话提") | ||
| 37 | -// @PostMapping("/addUserRoom") | ||
| 38 | -// public AjaxResult addUserRoom() | ||
| 39 | -// { | ||
| 40 | -// | ||
| 41 | -// } | 35 | + @ApiOperation("添加话提") |
| 36 | + @PostMapping("/addUserRoom") | ||
| 37 | + public AjaxResult addUserRoom(String roomTitle) | ||
| 38 | + { | ||
| 39 | + UserRoom userRoom = new UserRoom(); | ||
| 40 | + userRoom.setUser_id(SecurityUtils.getUserId().intValue()); | ||
| 41 | + userRoom.setTitle(roomTitle); | ||
| 42 | + userRoom.setIs_delete(1); | ||
| 43 | + userRoom.setCreate_time(DateUtils.getNowTimeMilly()); | ||
| 44 | + int i = publicService.insertToTable(userRoom,"`lk_openai`.`user_room`"); | ||
| 45 | + return toAjax(i).put("roomId",userRoom.getId()); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + @ApiOperation("删除话提") | ||
| 49 | + @DeleteMapping("/del/{id}") | ||
| 50 | + public AjaxResult del(@PathVariable Integer id) | ||
| 51 | + { | ||
| 52 | + return toAjax(publicService.updateBySql("UPDATE `lk_openai`.`user_room` SET is_delete=0 WHERE id="+id+" and user_id="+SecurityUtils.getUserId())); | ||
| 53 | + } | ||
| 54 | + | ||
| 42 | 55 | ||
| 43 | } | 56 | } |
| 1 | +package com.zhonglai.luhui.openai.dto; | ||
| 2 | + | ||
| 3 | +import lombok.Data; | ||
| 4 | + | ||
| 5 | +@Data | ||
| 6 | +public class UserRoom { | ||
| 7 | + private Integer id; // int(11) NOT NULL AUTO_INCREMENT COMMENT '主键', | ||
| 8 | + private String title; // varchar(150) DEFAULT NULL COMMENT '标题', | ||
| 9 | + private Integer user_id; // int(11) DEFAULT NULL COMMENT '用户', | ||
| 10 | + private Integer create_time; // int(11) DEFAULT NULL COMMENT '创建时间', | ||
| 11 | + private Integer is_delete; // int(11) DEFAULT NULL COMMENT '删除', | ||
| 12 | +} |
| 1 | -# 项目相关配置 jhlt: # 名称 name: zhonglai # 版本 version: 3.8.2 # 版权年份 copyrightYear: 2022 # 获取ip地址开关 addressEnabled: false profile: /www/wwwroot/lh-openai # 开发环境配置 server: # 服务器的HTTP端口,默认为8080 port: 8082 servlet: # 应用的访问路径 context-path: / tomcat: # tomcat的URI编码 uri-encoding: UTF-8 # 连接数满后的排队数,默认为100 accept-count: 1000 threads: # tomcat最大线程数,默认为200 max: 800 # Tomcat启动初始化的线程数,默认值10 min-spare: 100 # 日志配置 logging: level: com.ruoyi: debug org.springframework: warn # Spring配置 spring: # 资源信息 messages: # 国际化资源文件路径 basename: i18n/messages profiles: active: druid # 文件上传 servlet: multipart: # 单个文件大小 max-file-size: 10MB # 设置总上传的文件大小 max-request-size: 20MB # 服务模块 devtools: restart: # 热部署开关 enabled: true # redis 配置 redis: # 地址 host: 47.112.163.61 # 端口,默认为6379 port: 9527 # 数据库索引 database: 1 # 密码 password: Luhui586 # 连接超时时间 timeout: 10s lettuce: pool: # 连接池中的最小空闲连接 min-idle: 0 # 连接池中的最大空闲连接 max-idle: 8 # 连接池的最大数据库连接数 max-active: 8 # #连接池最大阻塞等待时间(使用负值表示没有限制) max-wait: -1ms # web: # resources: # static-locations: classpath:/static/, classpath:/templates/ # token配置 token: # 令牌自定义标识 header: Authorization # 令牌密钥 secret: abcdefghijklmnopqrstuvwxyz # 令牌有效期(默认30分钟) expireTime: 1440 rediskey: lh-openai # MyBatis配置 mybatis: # 搜索指定包别名 typeAliasesPackage: com.ruoyi.**.domain # 配置mapper的扫描,找到所有的mapper.xml映射文件 mapperLocations: classpath*:mapper/**/*Mapper.xml # 加载全局的配置文件 configLocation: classpath:mybatis/mybatis-config.xml # PageHelper分页插件 pagehelper: helperDialect: mysql supportMethodsArguments: true params: count=countSql # Swagger配置 swagger: # 是否开启swagger enabled: true # 请求前缀 pathMapping: /dev-api # 防止XSS攻击 xss: # 过滤开关 enabled: true # 排除链接(多个用逗号分隔) excludes: /system/notice # 匹配链接 urlPatterns: /system/*,/monitor/*,/tool/* sys: ## // 对于登录login 注册register 验证码captchaImage 允许匿名访问 antMatchers: /login,/register,/captchaImage,/getCacheObject,/v2/api-docs,/openAiUserLogin/*,/chatGPTStream/upUserFlowPacketRemain,/createSse,/chat,/closeSse chatgpt: token: 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 timeout: 5000 apiHost: https://api.openai.com/ proxy: isProxy: true host: 127.0.0.1 port: 7890 | ||
| 1 | +# 项目相关配置 jhlt: # 名称 name: zhonglai # 版本 version: 3.8.2 # 版权年份 copyrightYear: 2022 # 获取ip地址开关 addressEnabled: false profile: /www/wwwroot/lh-openai # 开发环境配置 server: # 服务器的HTTP端口,默认为8080 port: 8082 servlet: # 应用的访问路径 context-path: / tomcat: # tomcat的URI编码 uri-encoding: UTF-8 # 连接数满后的排队数,默认为100 accept-count: 1000 threads: # tomcat最大线程数,默认为200 max: 800 # Tomcat启动初始化的线程数,默认值10 min-spare: 100 # 日志配置 logging: level: com.ruoyi: debug org.springframework: warn # Spring配置 spring: # 资源信息 messages: # 国际化资源文件路径 basename: i18n/messages profiles: active: druid # 文件上传 servlet: multipart: # 单个文件大小 max-file-size: 10MB # 设置总上传的文件大小 max-request-size: 20MB # 服务模块 devtools: restart: # 热部署开关 enabled: true # redis 配置 redis: # 地址 host: 47.112.163.61 # 端口,默认为6379 port: 9527 # 数据库索引 database: 1 # 密码 password: Luhui586 # 连接超时时间 timeout: 10s lettuce: pool: # 连接池中的最小空闲连接 min-idle: 0 # 连接池中的最大空闲连接 max-idle: 8 # 连接池的最大数据库连接数 max-active: 8 # #连接池最大阻塞等待时间(使用负值表示没有限制) max-wait: -1ms # web: # resources: # static-locations: classpath:/static/, classpath:/templates/ # token配置 token: # 令牌自定义标识 header: Authorization # 令牌密钥 secret: abcdefghijklmnopqrstuvwxyz # 令牌有效期(默认30分钟) expireTime: 31536000 rediskey: lh-openai # MyBatis配置 mybatis: # 搜索指定包别名 typeAliasesPackage: com.ruoyi.**.domain # 配置mapper的扫描,找到所有的mapper.xml映射文件 mapperLocations: classpath*:mapper/**/*Mapper.xml # 加载全局的配置文件 configLocation: classpath:mybatis/mybatis-config.xml # PageHelper分页插件 pagehelper: helperDialect: mysql supportMethodsArguments: true params: count=countSql # Swagger配置 swagger: # 是否开启swagger enabled: true # 请求前缀 pathMapping: /dev-api # 防止XSS攻击 xss: # 过滤开关 enabled: true # 排除链接(多个用逗号分隔) excludes: /system/notice # 匹配链接 urlPatterns: /system/*,/monitor/*,/tool/* sys: ## // 对于登录login 注册register 验证码captchaImage 允许匿名访问 antMatchers: /login,/register,/captchaImage,/getCacheObject,/v2/api-docs,/openAiUserLogin/*,/chatGPTStream/upUserFlowPacketRemain,/createSse,/chat,/closeSse chatgpt: token: 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 timeout: 5000 apiHost: https://api.openai.com/ proxy: isProxy: true host: 127.0.0.1 port: 7890 |
| 1 | package com.zhonglai.luhui.smart.feeder; | 1 | package com.zhonglai.luhui.smart.feeder; |
| 2 | 2 | ||
| 3 | import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; | 3 | import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; |
| 4 | +import com.zhonglai.luhui.smart.feeder.service.OpenCVService; | ||
| 4 | import org.springframework.boot.SpringApplication; | 5 | import org.springframework.boot.SpringApplication; |
| 5 | import org.springframework.boot.autoconfigure.SpringBootApplication; | 6 | import org.springframework.boot.autoconfigure.SpringBootApplication; |
| 6 | 7 |
| 1 | +package com.zhonglai.luhui.smart.feeder.ai; | ||
| 2 | + | ||
| 3 | +import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; | ||
| 4 | +import org.opencv.core.*; | ||
| 5 | +import org.opencv.imgcodecs.Imgcodecs; | ||
| 6 | +import org.opencv.imgproc.Imgproc; | ||
| 7 | + | ||
| 8 | +import java.util.ArrayList; | ||
| 9 | +import java.util.List; | ||
| 10 | + | ||
| 11 | +public class AdaptiveThresholdTuning { | ||
| 12 | + | ||
| 13 | + public static void main(String[] args) { | ||
| 14 | + OpenCVConfig.loadOpenCv(args); | ||
| 15 | + // 加载标准水面区域图片 | ||
| 16 | + List<Mat> standardImages = loadStandardImages(); | ||
| 17 | + | ||
| 18 | + // 初始化最佳参数和最佳分割结果 | ||
| 19 | + int bestBlockSize = -1; | ||
| 20 | + double bestC = -1; | ||
| 21 | + double bestAccuracy = 0; | ||
| 22 | + | ||
| 23 | + // 遍历不同的参数组合进行测试 | ||
| 24 | + for (int blockSize = 3; blockSize <= 21; blockSize += 2) { | ||
| 25 | + for (double c = -20; c <= 20; c += 2) { | ||
| 26 | + double totalAccuracy = 0; | ||
| 27 | + | ||
| 28 | + // 在标准图片上进行测试并计算准确度 | ||
| 29 | + for (Mat standardImage : standardImages) { | ||
| 30 | + double accuracy = testAdaptiveThreshold(standardImage, blockSize, c); | ||
| 31 | + totalAccuracy += accuracy; | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + // 计算平均准确度 | ||
| 35 | + double averageAccuracy = totalAccuracy / standardImages.size(); | ||
| 36 | + | ||
| 37 | + // 更新最佳参数和最佳准确度 | ||
| 38 | + if (averageAccuracy > bestAccuracy) { | ||
| 39 | + bestAccuracy = averageAccuracy; | ||
| 40 | + bestBlockSize = blockSize; | ||
| 41 | + bestC = c; | ||
| 42 | + } | ||
| 43 | + } | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + // 输出最佳参数 | ||
| 47 | + System.out.println("Best block size: " + bestBlockSize); | ||
| 48 | + System.out.println("Best C: " + bestC); | ||
| 49 | + } | ||
| 50 | + | ||
| 51 | + // 加载标准水面区域图片 | ||
| 52 | + private static List<Mat> loadStandardImages() { | ||
| 53 | + List<Mat> standardImages = new ArrayList<>(); | ||
| 54 | + standardImages.add(Imgcodecs.imread("C:\\Users\\123\\Pictures\\2\\0.jpg", Imgcodecs.IMREAD_GRAYSCALE)); | ||
| 55 | + standardImages.add(Imgcodecs.imread("C:\\Users\\123\\Pictures\\2\\1.jpg", Imgcodecs.IMREAD_GRAYSCALE)); | ||
| 56 | + standardImages.add(Imgcodecs.imread("C:\\Users\\123\\Pictures\\2\\2.jpg", Imgcodecs.IMREAD_GRAYSCALE)); | ||
| 57 | + // ... 加载更多的标准图片 | ||
| 58 | + | ||
| 59 | + return standardImages; | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + // 测试适应性阈值分割并返回准确度 | ||
| 63 | + private static double testAdaptiveThreshold(Mat standardImage, int blockSize, double c) { | ||
| 64 | + // 创建用于分割结果的图像 | ||
| 65 | + Mat segmentedImg = new Mat(); | ||
| 66 | + | ||
| 67 | + // 应用适应性阈值分割 | ||
| 68 | + Imgproc.adaptiveThreshold(standardImage, segmentedImg, 255, Imgproc.ADAPTIVE_THRESH_MEAN_C, Imgproc.THRESH_BINARY_INV, blockSize, c); | ||
| 69 | + | ||
| 70 | + // 计算准确度 | ||
| 71 | + double accuracy = calculateAccuracy(segmentedImg, standardImage); | ||
| 72 | + | ||
| 73 | + return accuracy; | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + // 计算分割结果与标准图片的准确度 | ||
| 77 | + private static double calculateAccuracy(Mat segmentedImg, Mat standardImage) { | ||
| 78 | + // 调整分割结果图片的大小为标准图片的大小 | ||
| 79 | + Mat resizedSegmentedImg = new Mat(); | ||
| 80 | + Imgproc.resize(segmentedImg, resizedSegmentedImg, standardImage.size()); | ||
| 81 | + | ||
| 82 | + int totalPixels = standardImage.cols() * standardImage.rows(); // 图像总像素数 | ||
| 83 | + int matchedPixels = 0; // 匹配的像素数 | ||
| 84 | + | ||
| 85 | + for (int i = 0; i < standardImage.rows(); i++) { | ||
| 86 | + for (int j = 0; j < standardImage.cols(); j++) { | ||
| 87 | + double segmentedPixel = resizedSegmentedImg.get(i, j)[0]; | ||
| 88 | + double standardPixel = standardImage.get(i, j)[0]; | ||
| 89 | + | ||
| 90 | + // 获取透明度通道 | ||
| 91 | + double segmentedOpacity = resizedSegmentedImg.get(i, j)[0]; // 假设透明度在分割结果图片的第二通道 | ||
| 92 | + | ||
| 93 | + // 判断是否匹配像素 | ||
| 94 | + boolean isMatched = false; | ||
| 95 | + | ||
| 96 | + if (Math.abs(segmentedPixel - standardPixel) <= 10) { // 考虑反光的偏差范围 | ||
| 97 | + // 判断透明度准确度 | ||
| 98 | + double standardOpacity = standardImage.get(i, j)[0]; // 假设透明度在标准图片的第二通道 | ||
| 99 | + if (Math.abs(segmentedOpacity - standardOpacity) <= 10) { | ||
| 100 | + isMatched = true; | ||
| 101 | + } | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + if (isMatched) { | ||
| 105 | + matchedPixels++; | ||
| 106 | + } | ||
| 107 | + } | ||
| 108 | + } | ||
| 109 | + | ||
| 110 | + double accuracy = (double) matchedPixels / totalPixels; | ||
| 111 | + return accuracy; | ||
| 112 | + } | ||
| 113 | + | ||
| 114 | + | ||
| 115 | +} |
| 1 | +package com.zhonglai.luhui.smart.feeder.ai; | ||
| 2 | + | ||
| 3 | +import org.opencv.core.Core; | ||
| 4 | +import org.opencv.core.Mat; | ||
| 5 | +import org.opencv.core.Scalar; | ||
| 6 | +import org.opencv.imgproc.Imgproc; | ||
| 7 | + | ||
| 8 | +/** | ||
| 9 | + * 特征提取:使用适当的图像处理技术提取透明度、亮度和反光特性。 | ||
| 10 | + * 例如,可以使用像素值来计算透明度,使用图像直方图、颜色空间转换或滤波器来计算亮度,使用反光度量算法来计算反光特性。 | ||
| 11 | + */ | ||
| 12 | +public class ImgFeatureExtractionUtil { | ||
| 13 | + /** | ||
| 14 | + * 透明度特征提取 | ||
| 15 | + * @param image | ||
| 16 | + * @return | ||
| 17 | + */ | ||
| 18 | + public static double extractTransparencyFeature(Mat image) { | ||
| 19 | + // 提取透明度特征 | ||
| 20 | + Mat alphaChannel = new Mat(); | ||
| 21 | + Core.extractChannel(image, alphaChannel, 3); | ||
| 22 | + | ||
| 23 | + Scalar mean = Core.mean(alphaChannel); | ||
| 24 | + return mean.val[0]; | ||
| 25 | + } | ||
| 26 | + | ||
| 27 | + /** | ||
| 28 | + * 亮度特征提取 | ||
| 29 | + * @param image | ||
| 30 | + * @return | ||
| 31 | + */ | ||
| 32 | + public static double extractBrightnessFeature(Mat image) { | ||
| 33 | + // 转换为灰度图像 | ||
| 34 | + Mat grayImage = new Mat(); | ||
| 35 | + Imgproc.cvtColor(image, grayImage, Imgproc.COLOR_BGR2GRAY); | ||
| 36 | + | ||
| 37 | + Scalar mean = Core.mean(grayImage); | ||
| 38 | + return mean.val[0]; | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + /** | ||
| 42 | + * 反光特征提取 | ||
| 43 | + * @param image | ||
| 44 | + * @return | ||
| 45 | + */ | ||
| 46 | + public static double extractReflectionFeature(Mat image) { | ||
| 47 | + // 转换为灰度图像 | ||
| 48 | + Mat grayImage = new Mat(); | ||
| 49 | + Imgproc.cvtColor(image, grayImage, Imgproc.COLOR_BGR2GRAY); | ||
| 50 | + | ||
| 51 | + // 计算反光特征 | ||
| 52 | + Mat binaryImage = new Mat(); | ||
| 53 | + double thresholdValue = 100; // 调整阈值以根据图像适应性地检测反光 | ||
| 54 | + double maxValue = 255; | ||
| 55 | + Imgproc.threshold(grayImage, binaryImage, thresholdValue, maxValue, Imgproc.THRESH_BINARY); | ||
| 56 | + | ||
| 57 | + return Core.countNonZero(binaryImage); | ||
| 58 | + } | ||
| 59 | +} |
lh-modules/lh-smart-feeder/src/main/java/com/zhonglai/luhui/smart/feeder/draw/FishRegionPanel.java
0 → 100644
| 1 | +package com.zhonglai.luhui.smart.feeder.draw; | ||
| 2 | + | ||
| 3 | +import com.zhonglai.luhui.smart.feeder.util.OpenCVUtils; | ||
| 4 | +import org.opencv.core.Mat; | ||
| 5 | + | ||
| 6 | +import javax.swing.*; | ||
| 7 | +import java.awt.*; | ||
| 8 | +import java.awt.image.BufferedImage; | ||
| 9 | +import java.util.ArrayList; | ||
| 10 | +import java.util.List; | ||
| 11 | + | ||
| 12 | +public class FishRegionPanel | ||
| 13 | +{ | ||
| 14 | + private JFrame frame; | ||
| 15 | + private JLabel lblImage; | ||
| 16 | + private JLabel srcImage; | ||
| 17 | + private GraphPanel pnlGraph; | ||
| 18 | + | ||
| 19 | + private void init() { | ||
| 20 | + frame = new JFrame("Fish Detection"); | ||
| 21 | + frame.setSize(1000,1000); | ||
| 22 | + frame.setLayout(null); | ||
| 23 | + | ||
| 24 | + lblImage = new JLabel(); | ||
| 25 | + srcImage = new JLabel(); | ||
| 26 | + pnlGraph = new GraphPanel(new ArrayList<>()); | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + frame.setLayout(new BorderLayout()); | ||
| 30 | + frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); | ||
| 31 | + | ||
| 32 | + frame.add(lblImage, BorderLayout.NORTH); | ||
| 33 | + frame.add(srcImage, BorderLayout.SOUTH); | ||
| 34 | + frame.add(pnlGraph, BorderLayout.CENTER); | ||
| 35 | + | ||
| 36 | +// frame.pack(); | ||
| 37 | + frame.setVisible(true); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + | ||
| 41 | + public void end() | ||
| 42 | + { | ||
| 43 | + frame.setVisible(true); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + public FishRegionPanel() { | ||
| 47 | + init(); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + public void addFishCount(int size) { | ||
| 51 | + pnlGraph.getFishCountList().add(size); | ||
| 52 | + pnlGraph.repaint(); | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + public void displayImage(Mat image) { | ||
| 56 | + lblImage.setIcon(new ImageIcon(convertMatToImage(image))); | ||
| 57 | + frame.repaint(); | ||
| 58 | + } | ||
| 59 | + | ||
| 60 | + public void dispSrcImage(Mat image) { | ||
| 61 | + srcImage.setIcon(new ImageIcon(convertMatToImage(image))); | ||
| 62 | + frame.repaint(); | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + private Image convertMatToImage(Mat mat) { | ||
| 66 | + BufferedImage bufferedImage = OpenCVUtils.matToBufferedImage(mat); | ||
| 67 | + return bufferedImage.getScaledInstance(300, 300, Image.SCALE_SMOOTH); | ||
| 68 | + } | ||
| 69 | + | ||
| 70 | +} |
lh-modules/lh-smart-feeder/src/main/java/com/zhonglai/luhui/smart/feeder/draw/GraphPanel.java
0 → 100644
| 1 | +package com.zhonglai.luhui.smart.feeder.draw; | ||
| 2 | + | ||
| 3 | +import javax.swing.*; | ||
| 4 | +import java.awt.*; | ||
| 5 | +import java.util.List; | ||
| 6 | + | ||
| 7 | +public class GraphPanel extends JPanel { | ||
| 8 | + private List<Integer> fishCountList; | ||
| 9 | + | ||
| 10 | + Dimension preferredSize = new Dimension(300, 300); | ||
| 11 | + | ||
| 12 | + public GraphPanel(List<Integer> fishCountList) { | ||
| 13 | + this.fishCountList = fishCountList; | ||
| 14 | + } | ||
| 15 | + public List<Integer> getFishCountList() { | ||
| 16 | + return this.fishCountList; | ||
| 17 | + } | ||
| 18 | + @Override | ||
| 19 | + protected void paintComponent(Graphics g) { | ||
| 20 | + super.paintComponent(g); | ||
| 21 | + | ||
| 22 | + Graphics2D g2d = (Graphics2D) g; | ||
| 23 | + g2d.setColor(Color.BLUE); | ||
| 24 | + | ||
| 25 | +// g2d.setStroke(new BasicStroke(2)); | ||
| 26 | + if (!fishCountList.isEmpty()) { | ||
| 27 | + int panelWidth = 300; | ||
| 28 | + int panelHeight = 300; | ||
| 29 | + int countSize = fishCountList.size(); | ||
| 30 | + int[] counts = new int[countSize]; | ||
| 31 | + | ||
| 32 | + double max = 0; | ||
| 33 | + for (int i = 0; i < countSize; i++) { | ||
| 34 | + counts[i] = fishCountList.get(i); | ||
| 35 | + if(counts[i]>max) | ||
| 36 | + { | ||
| 37 | + max = counts[i]; | ||
| 38 | + } | ||
| 39 | + } | ||
| 40 | + | ||
| 41 | + int x = 0; | ||
| 42 | + int y = new Double((counts[0]/max)*panelHeight).intValue(); | ||
| 43 | + for (int i = 1; i < countSize; i++) { | ||
| 44 | + int nextX = (int) (panelWidth * i / countSize); | ||
| 45 | + int nextY = new Double((counts[i]/max)*panelHeight).intValue(); | ||
| 46 | + g2d.drawLine(x, y, nextX, nextY); | ||
| 47 | + x = nextX; | ||
| 48 | + y = nextY; | ||
| 49 | + } | ||
| 50 | + } | ||
| 51 | + } | ||
| 52 | + | ||
| 53 | +} |
lh-modules/lh-smart-feeder/src/main/java/com/zhonglai/luhui/smart/feeder/service/FileLook.java
0 → 100644
| 1 | +package com.zhonglai.luhui.smart.feeder.service; | ||
| 2 | + | ||
| 3 | +import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; | ||
| 4 | +import com.zhonglai.luhui.smart.feeder.draw.FishRegionPanel; | ||
| 5 | +import org.opencv.core.*; | ||
| 6 | +import org.opencv.imgproc.Imgproc; | ||
| 7 | +import org.opencv.videoio.VideoCapture; | ||
| 8 | + | ||
| 9 | +import java.util.ArrayList; | ||
| 10 | +import java.util.List; | ||
| 11 | + | ||
| 12 | +public class FileLook { | ||
| 13 | + public static void main(String[] args) { | ||
| 14 | + OpenCVConfig.loadOpenCv(args); | ||
| 15 | + VideoCapture videoCapture = new VideoCapture(); | ||
| 16 | + boolean isopen = videoCapture.open("C:/Users/123/Pictures/1.mp4"); | ||
| 17 | + if(isopen) | ||
| 18 | + { | ||
| 19 | + identifyFishRegionOnWaterSurface(videoCapture); | ||
| 20 | + }else { | ||
| 21 | + System.out.println("无法读取视频帧"); | ||
| 22 | + } | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + /** | ||
| 26 | + * 先识别水面,然后再在水面区域查找鱼群区域 | ||
| 27 | + */ | ||
| 28 | + private static void identifyFishRegionOnWaterSurface(VideoCapture videoCapture) { | ||
| 29 | + // 读取第一帧并获取视频大小 | ||
| 30 | + Mat previousFrame = new Mat(); | ||
| 31 | + if (!videoCapture.read(previousFrame)) { | ||
| 32 | + System.out.println("无法读取视频帧"); | ||
| 33 | + return; | ||
| 34 | + } | ||
| 35 | + | ||
| 36 | + FishRegionPanel fishRegionPanel = new FishRegionPanel(); | ||
| 37 | + | ||
| 38 | + // 转换为灰度图像 | ||
| 39 | + Mat previousGray = new Mat(); | ||
| 40 | + Imgproc.cvtColor(previousFrame, previousGray, Imgproc.COLOR_BGR2GRAY); | ||
| 41 | + | ||
| 42 | + // 识别水面区域 | ||
| 43 | + Rect waterSurfaceRegion = identifyWaterSurface(previousGray); | ||
| 44 | + | ||
| 45 | + // 逐帧处理视频 | ||
| 46 | + Mat frame = new Mat(); | ||
| 47 | + while (videoCapture.read(frame)) { | ||
| 48 | + // 在水面区域查找鱼群区域 | ||
| 49 | + Mat fishRegion = findFishRegion(frame, waterSurfaceRegion); | ||
| 50 | + | ||
| 51 | + // 绘制鱼群变化曲线 | ||
| 52 | +// fishRegionPanel.displayImage(fishRegion); | ||
| 53 | + fishRegionPanel.dispSrcImage(frame); | ||
| 54 | + } | ||
| 55 | + } | ||
| 56 | + | ||
| 57 | + private static Rect identifyWaterSurface(Mat grayImage) { | ||
| 58 | + // 进行自适应阈值分割,根据水面的亮度特征,将水面与其他区域分离 | ||
| 59 | + Mat binaryImage = new Mat(); | ||
| 60 | + Imgproc.adaptiveThreshold(grayImage, binaryImage, 255, Imgproc.ADAPTIVE_THRESH_MEAN_C, Imgproc.THRESH_BINARY_INV, 11, 5); | ||
| 61 | + | ||
| 62 | + // 进行形态学操作,去除噪点 | ||
| 63 | + int kernelSize = 5; // 调整内核大小 | ||
| 64 | + Mat kernel = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(kernelSize, kernelSize)); | ||
| 65 | + Imgproc.morphologyEx(binaryImage, binaryImage, Imgproc.MORPH_OPEN, kernel); | ||
| 66 | + | ||
| 67 | + // 查找水面区域的轮廓 | ||
| 68 | + List<MatOfPoint> contours = new ArrayList<>(); | ||
| 69 | + Mat hierarchy = new Mat(); | ||
| 70 | + Imgproc.findContours(binaryImage, contours, hierarchy, Imgproc.RETR_EXTERNAL, Imgproc.CHAIN_APPROX_SIMPLE); | ||
| 71 | + | ||
| 72 | + // 获取最大的轮廓区域作为水面区域 | ||
| 73 | + double maxContourArea = -1; | ||
| 74 | + Rect waterSurfaceRegion = null; | ||
| 75 | + for (MatOfPoint contour : contours) { | ||
| 76 | + double contourArea = Imgproc.contourArea(contour); | ||
| 77 | + if (contourArea > maxContourArea) { | ||
| 78 | + waterSurfaceRegion = Imgproc.boundingRect(contour); | ||
| 79 | + maxContourArea = contourArea; | ||
| 80 | + } | ||
| 81 | + } | ||
| 82 | + | ||
| 83 | + return waterSurfaceRegion; | ||
| 84 | + } | ||
| 85 | + | ||
| 86 | + | ||
| 87 | + private static Mat findFishRegion(Mat frame, Rect waterSurfaceRegion) { | ||
| 88 | + // 在水面区域查找鱼群区域 | ||
| 89 | + Mat fishRegion = frame.submat(waterSurfaceRegion); | ||
| 90 | + | ||
| 91 | + // 绘制绿色线框 | ||
| 92 | + Scalar green = new Scalar(0, 255, 0); // 绿色 | ||
| 93 | + Imgproc.rectangle(frame, waterSurfaceRegion, green, 2); // 绘制矩形框 | ||
| 94 | + | ||
| 95 | + // 转换为灰度图像 | ||
| 96 | + Mat grayImage = new Mat(); | ||
| 97 | + Imgproc.cvtColor(fishRegion, grayImage, Imgproc.COLOR_BGR2GRAY); | ||
| 98 | + | ||
| 99 | + // 进行亮度过滤 | ||
| 100 | + double brightnessThreshold = 100; // 亮度阈值 | ||
| 101 | + Mat binaryImage = new Mat(); | ||
| 102 | + Imgproc.threshold(grayImage, binaryImage, brightnessThreshold, 255, Imgproc.THRESH_BINARY); | ||
| 103 | + | ||
| 104 | + // 进行透明度过滤(如果需要) | ||
| 105 | + // 如果图像中有透明度通道(例如RGBA图像),可以提取透明度通道并进行阈值过滤 | ||
| 106 | + // 如果图像中没有透明度通道,则可以跳过这部分代码 | ||
| 107 | + Mat alphaChannel = new Mat(); | ||
| 108 | + if (fishRegion.channels() == 4) { | ||
| 109 | + Core.extractChannel(fishRegion, alphaChannel, 3); // 提取透明度通道 | ||
| 110 | + double alphaThreshold = 100; // 透明度阈值 | ||
| 111 | + Mat filteredImage = new Mat(); | ||
| 112 | + Imgproc.threshold(alphaChannel, filteredImage, alphaThreshold, 255, Imgproc.THRESH_BINARY); | ||
| 113 | + Core.bitwise_and(binaryImage, filteredImage, binaryImage); // 组合亮度过滤和透明度过滤的结果 | ||
| 114 | + } | ||
| 115 | + | ||
| 116 | + // 对二值图像进行形态学操作,去除噪点 | ||
| 117 | + int kernelSize = 3; // 调整内核大小 | ||
| 118 | + Mat kernel = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new org.opencv.core.Size(kernelSize, kernelSize)); | ||
| 119 | + Imgproc.morphologyEx(binaryImage, binaryImage, Imgproc.MORPH_OPEN, kernel); | ||
| 120 | + | ||
| 121 | + // 将二值图像转换回BGR图像 | ||
| 122 | + Mat filteredRegion = new Mat(); | ||
| 123 | + Imgproc.cvtColor(binaryImage, filteredRegion, Imgproc.COLOR_GRAY2BGR); | ||
| 124 | + | ||
| 125 | + return filteredRegion; | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | +} |
| 1 | package com.zhonglai.luhui.smart.feeder.service; | 1 | package com.zhonglai.luhui.smart.feeder.service; |
| 2 | 2 | ||
| 3 | import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; | 3 | import com.zhonglai.luhui.smart.feeder.config.OpenCVConfig; |
| 4 | -import com.zhonglai.luhui.smart.feeder.util.OpenCVUtils; | ||
| 5 | -import org.opencv.core.Core; | ||
| 6 | -import org.opencv.core.Mat; | 4 | +import com.zhonglai.luhui.smart.feeder.draw.FishRegionPanel; |
| 5 | +import org.opencv.core.*; | ||
| 7 | import org.opencv.videoio.VideoCapture; | 6 | import org.opencv.videoio.VideoCapture; |
| 8 | import org.opencv.imgproc.Imgproc; | 7 | import org.opencv.imgproc.Imgproc; |
| 9 | -import org.opencv.core.MatOfPoint; | ||
| 10 | -import org.opencv.core.Scalar; | ||
| 11 | -import org.opencv.core.CvType; | ||
| 12 | -import org.opencv.core.Rect; | ||
| 13 | 8 | ||
| 14 | -import javax.swing.*; | ||
| 15 | -import java.awt.image.BufferedImage; | ||
| 16 | import java.util.ArrayList; | 9 | import java.util.ArrayList; |
| 10 | +import java.util.Arrays; | ||
| 11 | +import java.util.Collections; | ||
| 17 | import java.util.List; | 12 | import java.util.List; |
| 18 | 13 | ||
| 19 | public class OpenCVService { | 14 | public class OpenCVService { |
| 15 | + /** | ||
| 16 | + * 反光阈值(reflectionThreshold)被设置为100。这意味着所有灰度值低于100的像素都会被设置为0(黑色),灰度值大于或等于100的像素都会被设置为255(白色)。如果你的图像中的对象或区域的灰度值接近或低于这个阈值,它们可能会被排除在二值图像之外。尝试调整这个阈值可能有助于改善结果 | ||
| 17 | + */ | ||
| 18 | + public static int reflectionThreshold = 100; // 反光阈值 | ||
| 19 | + public static int kernelSize = 3; // 去噪调整内核大小,用来消除小的物体或噪声 | ||
| 20 | + | ||
| 20 | public static void main(String[] args) { | 21 | public static void main(String[] args) { |
| 21 | OpenCVConfig.loadOpenCv(args); | 22 | OpenCVConfig.loadOpenCv(args); |
| 22 | - readVideoCaptureForVideo("C:/Users/123/Pictures/图片识别/6月30日.mp4"); | 23 | + readVideoCaptureForVideo("C:/Users/123/Pictures/1.mp4"); |
| 23 | } | 24 | } |
| 25 | + | ||
| 24 | public static void readVideoCaptureForVideo(String videoPath) | 26 | public static void readVideoCaptureForVideo(String videoPath) |
| 25 | { | 27 | { |
| 26 | // 创建VideoCapture对象 | 28 | // 创建VideoCapture对象 |
| @@ -32,86 +34,168 @@ public class OpenCVService { | @@ -32,86 +34,168 @@ public class OpenCVService { | ||
| 32 | System.out.println("无法打开视频文件"); | 34 | System.out.println("无法打开视频文件"); |
| 33 | return; | 35 | return; |
| 34 | } | 36 | } |
| 35 | - // 背景帧 | ||
| 36 | - Mat backgroundFrame = new Mat(); | 37 | + brightnessIdentifyFishRegion(videoCapture); |
| 37 | 38 | ||
| 38 | - // 初始阈值范围 | ||
| 39 | - double minAreaThreshold = Double.MAX_VALUE; | ||
| 40 | - double maxAreaThreshold = 0; | 39 | + // 释放资源 |
| 40 | + videoCapture.release(); | ||
| 41 | + } | ||
| 42 | + | ||
| 43 | + /** | ||
| 44 | + * 亮度查找水面,透明度过滤鱼群 | ||
| 45 | + */ | ||
| 46 | + private static void brightnessIdentifyFishRegion(VideoCapture videoCapture) | ||
| 47 | + { | ||
| 48 | + // 读取第一帧并获取视频大小 | ||
| 49 | + Mat previousFrame = new Mat(); | ||
| 50 | + if (!videoCapture.read(previousFrame)) { | ||
| 51 | + System.out.println("无法读取视频帧"); | ||
| 52 | + return; | ||
| 53 | + } | ||
| 54 | + // 获取水域轮廓 | ||
| 55 | + MatOfPoint largestContour = getDefaultMatOfPoint(previousFrame); | ||
| 56 | + | ||
| 57 | + //画板 | ||
| 58 | + FishRegionPanel fishRegionPanel = new FishRegionPanel(); | ||
| 41 | 59 | ||
| 42 | // 逐帧处理视频 | 60 | // 逐帧处理视频 |
| 43 | Mat frame = new Mat(); | 61 | Mat frame = new Mat(); |
| 44 | while (videoCapture.read(frame)) { | 62 | while (videoCapture.read(frame)) { |
| 45 | - // 背景差分 | ||
| 46 | - Mat diffFrame = new Mat(); | ||
| 47 | - Core.absdiff(frame, backgroundFrame, diffFrame); | 63 | + //抠图 |
| 64 | + Mat shuiyu = matting(frame,largestContour); | ||
| 65 | + | ||
| 66 | + // 2. 转换为灰度图像 | ||
| 67 | + Mat gray = new Mat(); | ||
| 68 | + Imgproc.cvtColor(shuiyu, gray, Imgproc.COLOR_BGR2GRAY); | ||
| 48 | 69 | ||
| 49 | - // 灰度转换 | ||
| 50 | - Mat grayFrame = new Mat(); | ||
| 51 | - Imgproc.cvtColor(diffFrame, grayFrame, Imgproc.COLOR_BGR2GRAY); | 70 | + // 3. 进行阈值分割以得到二值图像 |
| 71 | + Mat binaryImage = new Mat(); | ||
| 72 | + Imgproc.threshold(gray, binaryImage, 100, 255, Imgproc.THRESH_BINARY); | ||
| 52 | 73 | ||
| 53 | - // 阈值处理 | ||
| 54 | - Mat thresholdFrame = new Mat(); | ||
| 55 | - Imgproc.threshold(grayFrame, thresholdFrame, 30, 255, Imgproc.THRESH_BINARY); | 74 | + List<MatOfPoint> contours = new ArrayList<>(); // 用于存储找到的轮廓 |
| 75 | + Mat hierarchy = new Mat(); // 轮廓的层次结构 | ||
| 56 | 76 | ||
| 57 | - // 边缘检测 | ||
| 58 | - Mat edges = new Mat(); | ||
| 59 | - Imgproc.Canny(thresholdFrame, edges, 100, 200); | 77 | + // 在水域二值图像中找所有轮廓 |
| 78 | + Imgproc.findContours(binaryImage, contours, hierarchy, Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE); | ||
| 79 | + | ||
| 80 | + //计算大小 | ||
| 81 | + double area = getArea(contours); | ||
| 82 | + | ||
| 83 | + //标注识别对象 | ||
| 84 | + Imgproc.drawContours(frame, contours, -1, new Scalar(0, 0, 255), 2); | ||
| 85 | + Imgproc.drawContours(frame, Arrays.asList(new MatOfPoint[]{largestContour}), 0, new Scalar(0, 255, 0), 2); | ||
| 86 | + | ||
| 87 | + // 显示图像 | ||
| 88 | + // 在图像上显示结果 | ||
| 89 | + fishRegionPanel.displayImage(binaryImage); | ||
| 90 | + fishRegionPanel.dispSrcImage(frame); | ||
| 91 | + // 绘制鱼群变化曲线 | ||
| 92 | + fishRegionPanel.addFishCount(new Double(area).intValue()); | ||
| 93 | + | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + } | ||
| 60 | 97 | ||
| 61 | - // 轮廓检测 | 98 | + /** |
| 99 | + * 获取标准水域轮廓 | ||
| 100 | + * @param previousFrame | ||
| 101 | + * @return | ||
| 102 | + */ | ||
| 103 | + private static MatOfPoint getDefaultMatOfPoint(Mat previousFrame) | ||
| 104 | + { | ||
| 105 | + Mat firstBinaryImage = waterBybinary(previousFrame); | ||
| 106 | + // 绘制白色区域的轮廓 | ||
| 62 | List<MatOfPoint> contours = new ArrayList<>(); | 107 | List<MatOfPoint> contours = new ArrayList<>(); |
| 63 | Mat hierarchy = new Mat(); | 108 | Mat hierarchy = new Mat(); |
| 64 | - Imgproc.findContours(edges, contours, hierarchy, Imgproc.RETR_EXTERNAL, Imgproc.CHAIN_APPROX_SIMPLE); | 109 | + Imgproc.findContours(firstBinaryImage, contours, hierarchy, Imgproc.RETR_EXTERNAL, Imgproc.CHAIN_APPROX_SIMPLE); |
| 110 | + // 找到最大区域 | ||
| 111 | + double maxArea = 0; | ||
| 112 | + int maxAreaIndex = -1; | ||
| 113 | + for (int i = 0; i < contours.size(); i++) { | ||
| 114 | + double area = Imgproc.contourArea(contours.get(i)); | ||
| 115 | + if (area > maxArea) { | ||
| 116 | + maxArea = area; | ||
| 117 | + maxAreaIndex = i; | ||
| 118 | + } | ||
| 119 | + } | ||
| 120 | + // 获取最大区域的轮廓 | ||
| 121 | + MatOfPoint largestContour = contours.get(maxAreaIndex); | ||
| 122 | + return largestContour; | ||
| 123 | + } | ||
| 65 | 124 | ||
| 66 | - // 更新阈值范围 | ||
| 67 | - for (MatOfPoint contour : contours) { | ||
| 68 | - double area = Imgproc.contourArea(contour); | ||
| 69 | - if (area > maxAreaThreshold) { | ||
| 70 | - maxAreaThreshold = area; | 125 | + private static double getArea(List<MatOfPoint> contours) { |
| 126 | + // 找到最大区域 | ||
| 127 | + double maxArea = 0; | ||
| 128 | + int maxAreaIndex = -1; | ||
| 129 | + for (int i = 0; i < contours.size(); i++) { | ||
| 130 | + double area = Imgproc.contourArea(contours.get(i)); | ||
| 131 | + if (area > maxArea) { | ||
| 132 | + maxArea = area; | ||
| 133 | + maxAreaIndex = i; | ||
| 71 | } | 134 | } |
| 72 | - if (area < minAreaThreshold) { | ||
| 73 | - minAreaThreshold = area; | ||
| 74 | } | 135 | } |
| 136 | + if(-1 != maxAreaIndex) | ||
| 137 | + { | ||
| 138 | + contours.remove(maxAreaIndex); | ||
| 139 | + } | ||
| 140 | + // 返回总面积 | ||
| 141 | + return maxArea; | ||
| 75 | } | 142 | } |
| 76 | 143 | ||
| 77 | - // 根据阈值范围选择适当的阈值 | ||
| 78 | - double thresholdValue = (maxAreaThreshold + minAreaThreshold) / 2.0; | ||
| 79 | 144 | ||
| 80 | - // 绘制轮廓 | ||
| 81 | - Mat contourImage = new Mat(frame.size(), CvType.CV_8UC3, new Scalar(0, 0, 0)); | ||
| 82 | - Imgproc.drawContours(contourImage, contours, -1, new Scalar(0, 255, 0), 2); | 145 | + private static Mat matting(Mat frame,MatOfPoint largestContour) |
| 146 | + { | ||
| 147 | + // 创建一个与原始图像相同大小的新Mat,用于提取图像区域 | ||
| 148 | + Mat extractedRegion = Mat.zeros(frame.size(), frame.type()); | ||
| 83 | 149 | ||
| 84 | - // 提取鱼群区域 | ||
| 85 | - for (MatOfPoint contour : contours) { | ||
| 86 | - double area = Imgproc.contourArea(contour); | ||
| 87 | - if (area > thresholdValue) { | ||
| 88 | - // 对于满足面积阈值的轮廓,可以进一步处理或分析 | ||
| 89 | - // 例如,计算鱼群数量、中心位置等信息 | ||
| 90 | - // ... | 150 | + // 将指定的轮廓绘制到新的Mat上 |
| 151 | + Imgproc.drawContours(extractedRegion, Collections.singletonList(largestContour), 0, new Scalar(255, 255, 255), -1); | ||
| 91 | 152 | ||
| 92 | - // 在原图上绘制鱼群区域 | ||
| 93 | - Rect boundingRect = Imgproc.boundingRect(contour); | ||
| 94 | - Imgproc.rectangle(frame, boundingRect.tl(), boundingRect.br(), new Scalar(0, 255, 0), 2); | ||
| 95 | - } | ||
| 96 | - } | ||
| 97 | - // 重置阈值范围 | ||
| 98 | - minAreaThreshold = Double.MAX_VALUE; | ||
| 99 | - maxAreaThreshold = 0; | 153 | + // 使用按位与操作提取对应的图像区域 |
| 154 | + Mat extractedImage = new Mat(); | ||
| 155 | + Core.bitwise_and(frame, extractedRegion, extractedImage); | ||
| 100 | 156 | ||
| 101 | - // 在图像上显示结果 | ||
| 102 | - displayImage(frame); | 157 | + return extractedImage; |
| 103 | } | 158 | } |
| 104 | 159 | ||
| 160 | + /** | ||
| 161 | + * 根据反光查找水面 | ||
| 162 | + * @param frame | ||
| 163 | + * @return | ||
| 164 | + */ | ||
| 165 | + private static Mat waterBybinary(Mat frame) { | ||
| 166 | + // 将加载的图像转换为灰度图像,以便进行亮度或反光的分析 | ||
| 167 | + Mat grayImage = new Mat(); | ||
| 168 | + Imgproc.cvtColor(frame, grayImage, Imgproc.COLOR_BGR2GRAY); | ||
| 169 | + | ||
| 170 | + // 检测反光 | ||
| 171 | + Mat binaryImage = new Mat(); | ||
| 172 | + double maxValue = 255; | ||
| 173 | + Imgproc.threshold(grayImage, binaryImage, reflectionThreshold, maxValue, Imgproc.THRESH_BINARY); | ||
| 174 | + | ||
| 175 | + // 进行形态学操作,去除噪点 | ||
| 176 | + Mat kernel = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, new Size(kernelSize, kernelSize)); | ||
| 177 | + Imgproc.morphologyEx(binaryImage, binaryImage, Imgproc.MORPH_OPEN, kernel); | ||
| 178 | + | ||
| 179 | + return binaryImage; | ||
| 105 | } | 180 | } |
| 106 | 181 | ||
| 107 | - // 显示图像 | ||
| 108 | - private static void displayImage(Mat image) { | ||
| 109 | - // 将Mat图像转换为BufferedImage | ||
| 110 | - BufferedImage bufferedImage = OpenCVUtils.matToBufferedImage(image); | 182 | + private static List<MatOfPoint> fishByWater(Mat frame) |
| 183 | + { | ||
| 184 | + // 2. 转换为灰度图像 | ||
| 185 | + Mat gray = new Mat(); | ||
| 186 | + Imgproc.cvtColor(frame, gray, Imgproc.COLOR_BGR2GRAY); | ||
| 187 | + | ||
| 188 | + // 3. 进行阈值分割以得到二值图像 | ||
| 189 | + Mat binary = new Mat(); | ||
| 190 | + Imgproc.threshold(gray, binary, 100, 255, Imgproc.THRESH_BINARY); | ||
| 111 | 191 | ||
| 112 | - // 在标签上显示图像 | ||
| 113 | - new ImageIcon(bufferedImage); | 192 | + // 4. 查找轮廓 |
| 193 | + List<MatOfPoint> contours = new ArrayList<>(); | ||
| 194 | + Mat hierarchy = new Mat(); | ||
| 195 | + Imgproc.findContours(binary, contours, hierarchy, Imgproc.RETR_EXTERNAL, Imgproc.CHAIN_APPROX_SIMPLE); | ||
| 196 | + | ||
| 197 | + return contours; | ||
| 114 | 198 | ||
| 115 | - // 更新窗口 | ||
| 116 | } | 199 | } |
| 200 | + | ||
| 117 | } | 201 | } |
| @@ -18,7 +18,7 @@ public class VideoUtil { | @@ -18,7 +18,7 @@ public class VideoUtil { | ||
| 18 | * @throws Exception | 18 | * @throws Exception |
| 19 | */ | 19 | */ |
| 20 | public static void fetchPic(File file, String framefile, int second) throws Exception{ | 20 | public static void fetchPic(File file, String framefile, int second) throws Exception{ |
| 21 | - FFmpegFrameGrabber ff = new FFmpegFrameGrabber(file); | 21 | + FFmpegFrameGrabber ff = new FFmpegFrameGrabber( file); |
| 22 | ff.start(); | 22 | ff.start(); |
| 23 | int lenght = ff.getLengthInAudioFrames(); | 23 | int lenght = ff.getLengthInAudioFrames(); |
| 24 | System.out.println(ff.getFrameRate()); | 24 | System.out.println(ff.getFrameRate()); |
| @@ -32,6 +32,7 @@ public class VideoUtil { | @@ -32,6 +32,7 @@ public class VideoUtil { | ||
| 32 | int i = 0; | 32 | int i = 0; |
| 33 | Frame frame = null; | 33 | Frame frame = null; |
| 34 | while (i < lenght) { | 34 | while (i < lenght) { |
| 35 | + try { | ||
| 35 | frame = ff.grabImage(); | 36 | frame = ff.grabImage(); |
| 36 | if (i>=(int) (ff.getFrameRate()*second)&&frame.image != null) { | 37 | if (i>=(int) (ff.getFrameRate()*second)&&frame.image != null) { |
| 37 | System.out.print(i+","); | 38 | System.out.print(i+","); |
| @@ -42,6 +43,11 @@ public class VideoUtil { | @@ -42,6 +43,11 @@ public class VideoUtil { | ||
| 42 | second++; | 43 | second++; |
| 43 | } | 44 | } |
| 44 | i++; | 45 | i++; |
| 46 | + }catch (Exception e) | ||
| 47 | + { | ||
| 48 | + System.out.println(e); | ||
| 49 | + } | ||
| 50 | + | ||
| 45 | } | 51 | } |
| 46 | ff.stop(); | 52 | ff.stop(); |
| 47 | } | 53 | } |
| @@ -88,7 +94,7 @@ public class VideoUtil { | @@ -88,7 +94,7 @@ public class VideoUtil { | ||
| 88 | public static void main(String[] args){ | 94 | public static void main(String[] args){ |
| 89 | try { | 95 | try { |
| 90 | OpenCVConfig.loadOpenCv(args); | 96 | OpenCVConfig.loadOpenCv(args); |
| 91 | - File file = new File("C:\\Users\\123\\Pictures\\图片识别\\20210107_100743.mp4"); | 97 | + File file = new File("C:/Users/123/Pictures/1.mp4"); |
| 92 | VideoUtil.fetchPic(file,"C:\\Users\\123\\Pictures\\图片识别\\1\\",100); | 98 | VideoUtil.fetchPic(file,"C:\\Users\\123\\Pictures\\图片识别\\1\\",100); |
| 93 | System.out.println(VideoUtil.getVideoTime(file)); | 99 | System.out.println(VideoUtil.getVideoTime(file)); |
| 94 | } catch (Exception e) { | 100 | } catch (Exception e) { |
-
请 注册 或 登录 后发表评论