package com.framsticks.communication;

import com.framsticks.communication.queries.ApplicationRequest;
import com.framsticks.communication.queries.RegistrationRequest;
import com.framsticks.communication.queries.UseRequest;
import com.framsticks.communication.queries.VersionRequest;
import com.framsticks.communication.util.LoggingStateCallback;
import com.framsticks.params.ListSource;
import com.framsticks.util.*;
import com.framsticks.util.dispatching.AtOnceDispatcher;
import com.framsticks.util.dispatching.Dispatcher;
import com.framsticks.util.dispatching.Dispatching;
import com.framsticks.util.dispatching.JoinableState;
import com.framsticks.util.lang.Pair;
import com.framsticks.util.lang.Strings;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

import java.io.IOException;
import java.net.Socket;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.framsticks.util.dispatching.RunAt;

/**
 * Client side of a Framsticks {@link Connection}.
 * <p>
 * The address given to the constructor has the form {@code host[:port]}; when the port is omitted,
 * 9009 is used. {@link #receiverThreadRoutine()} keeps trying to open a socket to that address while
 * the connection is running. It then reads inbound messages line by line and routes each one either
 * to the request it answers or to an event subscription.
 * <p>
 * Requests are written out on the sender thread (see
 * {@link #send(Request, Dispatcher, ResponseCallback)}). Until the {@code request_id} feature has
 * been enabled (see {@link #sendQueryVersion(int, StateFunctor)}), only one request may be
 * outstanding at a time. After that, every request carries a numeric id and responses are matched
 * by that id.
 * <p>
 * NOTE(review): in this copy of the file the generic type arguments appear to have been stripped
 * (e.g. {@code Map> subscriptions}, raw {@code Pair}/{@code List}). Restore them from version
 * control before compiling.
 *
 * @author Piotr Sniegowski
 */
public class ClientConnection extends Connection {

    private final static Logger log = Logger.getLogger(ClientConnection.class);

    /**
     * Active event subscriptions keyed by their registered path. Written in {@code subscribe}
     * (under {@code synchronized (subscriptions)}) and read in {@code processEvent}.
     */
    protected final Map> subscriptions = new HashMap<>();

    /**
     * Sets the functor that {@link #receiverThreadRoutine()} calls once the socket is open and its
     * streams are set up.
     * NOTE(review): the receiver thread calls it without a null check, so it has to be set before
     * the receiver thread gets that far.
     *
     * @param connectedFunctor the connectedFunctor to set
     */
    public void setConnectedFunctor(StateFunctor connectedFunctor) {
        this.connectedFunctor = connectedFunctor;
    }

    /**
     * Creates a client connection to {@code address}, given as {@code host} or {@code host:port}
     * (the port defaults to 9009).
     * <p>
     * An address that does not match this form does not cause an exception. It is only logged at
     * FATAL level, and the connection is left with a {@code null} host name and port 0.
     *
     * @param address the server address
     */
    public ClientConnection(String address) {
        super(address, "client connection");
        Matcher matcher = addressPattern.matcher(address);
        if (!matcher.matches()) {
            log.fatal("invalid address: " + address);
            hostName = null;
            port = 0;
            return;
        }
        hostName = matcher.group(1);
        // group(3) is the optional numeric port after ':'.
        port = matcher.group(3) != null ? Integer.parseInt(matcher.group(3)) : 9009;
    }

    /** Called once the connection is established; see {@link #setConnectedFunctor(StateFunctor)}. */
    protected StateFunctor connectedFunctor;

    /**
     * Collects the body of one inbound message as a list of {@link File}s.
     * <p>
     * A subclass starts a file in {@link #startFile(String)}. Lines are appended with
     * {@link #addLine(String)}, and {@link #finishCurrentFile()} moves the completed file into
     * {@link #getFiles()}. {@link #eof()} is called when the message terminator line has been read.
     */
    private static abstract class InboundMessage {
        protected String currentFilePath;
        protected List currentFileContent;
        protected final List files = new ArrayList();

        /** Called after the terminating {@code eof} line of the message has been read. */
        public abstract void eof();

        /**
         * Starts collecting lines for a new file with the given path. A file that is still open is
         * replaced without being added to {@code files}.
         */
        protected void initCurrentFile(String path) {
            currentFileContent = new LinkedList();
            currentFilePath = path;
        }

        /**
         * Adds the file being collected (if any) to {@code files} and clears the current-file
         * state. Does nothing when no file is open.
         */
        protected void finishCurrentFile() {
            if (currentFileContent == null) {
                return;
            }
            files.add(new File(currentFilePath, new ListSource(currentFileContent)));
            currentFilePath = null;
            currentFileContent = null;
        }

        /** Begins a new file within this message; the meaning of {@code path} depends on the subclass. */
        public abstract void startFile(String path);

        /** Appends one line to the open file; a file must already have been started. */
        public final void addLine(String line) {
            assert line != null;
            assert currentFileContent != null;
            currentFileContent.add(line);
        }

        /** Returns the internal list of completed files (not a copy). */
        public List getFiles() {
            return files;
        }
    }

    /**
     * An inbound event notification. Its body is a single unnamed file, and the collected files are
     * handed to the subscription when {@link #eof()} is reached.
     */
    private static class EventFire extends InboundMessage {
        public final Subscription subscription;

        private EventFire(Subscription subscription) {
            this.subscription = subscription;
        }

        /** Events carry one unnamed file, so {@code path} is expected to be {@code null}. */
        public void startFile(String path) {
            assert path == null;
            initCurrentFile(null);
        }

        @Override
        public void eof() {
            finishCurrentFile();
            subscription.dispatchCall(getFiles());
        }
    }

    /**
     * A request that has been handed to the sender. It holds the callback that will receive the
     * response and the dispatcher used to deliver it. Files announced by {@code file} messages are
     * collected here until the final status line for the request arrives.
     */
    private static class SentQuery extends InboundMessage {
        Request request;
        ResponseCallback callback;
        Dispatcher dispatcher;

        /**
         * Finishes any file still open and starts a new one named {@code path}. When {@code path}
         * is empty (per {@code Strings.notEmpty}), the path of the originating
         * {@link ApplicationRequest} is used instead.
         */
        public void startFile(String path) {
            finishCurrentFile();
            if (!Strings.notEmpty(path)) {
                assert request instanceof ApplicationRequest;
                path = ((ApplicationRequest) request).getPath();
            }
            Strings.assureNotEmpty(path);
            initCurrentFile(path);
        }

        public void eof() {
            assert Strings.notEmpty(currentFilePath);
            finishCurrentFile();
            // Only the file is closed here. The response, with all collected files, is delivered
            // later, when the closing status line is handled in processMessageStartingWith.
        }

        @Override
        public String toString() {
            return request.toString();
        }

        /**
         * Passes {@code response} to this query's callback through
         * {@code Dispatching.dispatchIfNotActive} on this query's dispatcher.
         * NOTE(review): presumably this runs the callback inline when already on that dispatcher;
         * confirm against Dispatching.
         */
        public void dispatchResponseProcess(final Response response) {
            Dispatching.dispatchIfNotActive(dispatcher, new RunAt() {
                @Override
                public void run() {
                    callback.process(response);
                }
            });
        }
    }

    /** Outstanding queries keyed by request id; filled by send() while request ids are enabled. */
    private Map> queryMap = new HashMap<>();

    /** Target host and port, parsed from the address in the constructor. */
    protected final String hostName;
    protected final int port;

    /** Matches {@code host} optionally followed by {@code :port}; group 1 = host, group 3 = port. */
    private static Pattern addressPattern = Pattern.compile("^([^:]*)(:([0-9]+))?$");

    /**
     * The single outstanding query while request ids are disabled. Guarded by this object's
     * monitor; fetchQuery clears it and wakes senders that are waiting for it.
     */
    private SentQuery currentlySentQuery;

    /**
     * Sends {@code request} and delivers its response to {@code callback} via
     * {@code AtOnceDispatcher.getInstance()}.
     *
     * @param request  the request to send
     * @param callback receives the response
     */
    public void send(Request request, ResponseCallback callback) {
        //TODO RunAt
        send(request, AtOnceDispatcher.getInstance(), callback);
    }

    /**
     * Queues {@code request} for sending on the sender thread. Its response is later delivered to
     * {@code callback} through {@code dispatcher}.
     * <p>
     * On the sender thread, while request ids are disabled, the write waits until no other query is
     * outstanding. The line written is {@code <command> [<id>] } followed by whatever
     * {@code request.construct} appends.
     * <p>
     * NOTE(review): if the connection state is past RUNNING, this only logs "not connected" and
     * returns. {@code callback} is never invoked in that case.
     *
     * @param request    the request to send
     * @param dispatcher dispatcher used to deliver the response
     * @param callback   receives the response
     */
    public void send(Request request, Dispatcher dispatcher, ResponseCallback callback) {

        if (getState().ordinal() > JoinableState.RUNNING.ordinal()) {
            log.fatal("not connected");
            return;
        }

        final SentQuery sentQuery = new SentQuery();
        sentQuery.request = request;
        sentQuery.callback = callback;
        sentQuery.dispatcher = dispatcher;

        senderThread.dispatch(new RunAt() {
            @Override
            public void run() {
                Integer id;
                synchronized (ClientConnection.this) {

                    // Without request ids a response cannot be matched to its request, so wait
                    // until the previous query has been answered (fetchQuery calls notifyAll).
                    while (!(requestIdEnabled || currentlySentQuery == null)) {
                        try {
                            ClientConnection.this.wait();
                        } catch (InterruptedException ignored) {
                            // NOTE(review): the interrupt flag is not restored here. Execution also
                            // falls through and may overwrite a still-outstanding currentlySentQuery.
                            break;
                        }
                    }
                    if (requestIdEnabled) {
                        queryMap.put(nextQueryId, sentQuery);
                        id = nextQueryId++;
                    } else {
                        currentlySentQuery = sentQuery;
                        id = null;
                    }
                }
                // Build "<command> [<id>] <request-specific part>" and write it as one line.
                String command = sentQuery.request.getCommand();
                StringBuilder message = new StringBuilder();
                message.append(command);
                if (id != null) {
                    message.append(" ").append(id);
                }
                message.append(" ");
                sentQuery.request.construct(message);
                String out = message.toString();

                output.println(out);
                output.flush();
                log.debug("sending query: " + out);

            }
        });
        /*
        Leftover from an earlier queue-based sender; it refers to names (query, queryQueue)
        that are not declared in this class.
        synchronized (this) {
            log.debug("queueing query: " + query);
            queryQueue.offer(sentQuery);
            notifyAll();
        }
        */
    }

    @Override
    public String toString() {
        return "client connection " + address;
    }

    /**
     * Registers for events on {@code path} by sending a {@link RegistrationRequest}.
     * <p>
     * If the server rejects the request, {@code callback.subscribed(null)} is called. Otherwise a
     * Subscription is created from the response comment and stored under its registered path. The
     * value returned by {@code callback.subscribed(subscription)} becomes the subscription's event
     * callback. If that value is {@code null}, the subscription is unsubscribed again immediately.
     *
     * @param path       event path to register on
     * @param dispatcher dispatcher passed to the created Subscription
     * @param callback   notified with the new subscription (or {@code null} on failure)
     */
    public void subscribe(final String path, final Dispatcher dispatcher, final SubscriptionCallback callback) {
        send(new RegistrationRequest().path(path), new ResponseCallback() {
            @Override
            public void process(Response response) {
                if (!response.getOk()) {
                    log.error("failed to register on event: " + path);
                    callback.subscribed(null);
                    return;
                }
                assert response.getFiles().isEmpty();
                // NOTE(review): presumably the response comment carries the server-side registered
                // path (it is later read back via getRegisteredPath()); confirm in Subscription.
                Subscription subscription = new Subscription(ClientConnection.this, path, response.getComment(), dispatcher);
                log.debug("registered on event: " + subscription);
                synchronized (subscriptions) {
                    subscriptions.put(subscription.getRegisteredPath(), subscription);
                }
                subscription.setEventCallback(callback.subscribed(subscription));
                if (subscription.getEventCallback() == null) {
                    log.info("subscription for " + path + " aborted");
                    subscription.unsubscribe(new LoggingStateCallback(log, "abort subscription"));
                }
            }
        });
    }

    /**
     * Starts protocol negotiation at version 1. {@code protocolVersion} is reset to -1 until the
     * first version request succeeds.
     *
     * @param stateFunctor notified when negotiation finishes or fails
     */
    public void negotiateProtocolVersion(StateFunctor stateFunctor) {
        protocolVersion = -1;
        sendQueryVersion(1, stateFunctor);
    }

    /**
     * Sends a version request for {@code version} and, on success, records it in
     * {@code protocolVersion}.
     * <p>
     * Versions below 4 are followed by a request for {@code version + 1}. Once version 4 is
     * accepted, the {@code request_id} feature is requested: success sets {@code requestIdEnabled}
     * to true, and failure sets it to false and reports "protocol negotiation failed".
     *
     * @param version      protocol version to request
     * @param stateFunctor passed on (via {@code move()}) to the following step
     */
    public void sendQueryVersion(final int version, StateFunctor stateFunctor) {
        send(new VersionRequest().version(version), new StateCallback(stateFunctor) {
            @Override
            public void callImpl() {
                protocolVersion = version;
                if (version < 4) {
                    /* Implicit loop: each accepted version triggers a request for the next one,
                       up to version 4. */
                    sendQueryVersion(version + 1, move());
                    return;
                }
                send(new UseRequest().feature("request_id"), new StateCallback(move()) {
                    @Override
                    public void handle(FramsticksException exception) {
                        requestIdEnabled = false;
                        log.fatal("protocol negotiation failed");
                        super.handle(new FramsticksException().msg("protocol negotiation failed").cause(exception));
                    }

                    @Override
                    public void callImpl() {
                        requestIdEnabled = true;
                    }
                });

            }
        });
    }

    /**
     * Finds the query that an inbound line belongs to.
     * <p>
     * When {@code id} is {@code null}, returns {@code currentlySentQuery}. It returns {@code null}
     * instead while request ids are enabled, because an id-less line cannot be matched then. With
     * {@code remove} set, the slot is also cleared and waiting senders are woken. For a non-null
     * {@code id}, the query is looked up in {@code queryMap} and, with {@code remove} set, removed
     * from it.
     *
     * @param id     request id parsed from the line, or {@code null}
     * @param remove whether to remove the query (true for the final status line)
     * @return the matching query, or {@code null} if none
     */
    private synchronized SentQuery fetchQuery(Integer id, boolean remove) {
        if (id == null) {
            if (requestIdEnabled) {
                return null;
            }
            SentQuery result = currentlySentQuery;
            if (remove) {
                currentlySentQuery = null;
                notifyAll();
            }
            return result;
        }
        if (queryMap.containsKey(id)) {
            SentQuery result = queryMap.get(id);
            if (remove) {
                queryMap.remove(id);
            }
            return result;
        }
        return null;
    }

    /** Next request id to assign; read and incremented under this object's monitor in send(). */
    private int nextQueryId = 0;

    /**
     * Reads lines into {@code inboundMessage} until a line starting with {@code "eof"}, then calls
     * its {@code eof()}.
     * NOTE(review): when {@code inboundMessage} is null, the message body is not consumed, so its
     * lines would then be parsed as new messages.
     *
     * @param inboundMessage receiver of the message body
     */
    protected void processMessage(InboundMessage inboundMessage) {
        if (inboundMessage == null) {
            log.error("failed to use any inbound message");
            return;
        }

        String line;
        while (!(line = getLine()).startsWith("eof")) {
            // log.debug("line: " + line);
            inboundMessage.addLine(line);
        }
        inboundMessage.eof();
    }

    /**
     * Handles an {@code event} line. {@code rest} must match {@code Request.EVENT_PATTERN}, and
     * group 1 selects the subscription. The event body is then read as a single unnamed file and
     * dispatched to that subscription.
     * <p>
     * NOTE(review): {@code subscriptions} is read here without the
     * {@code synchronized (subscriptions)} lock that subscribe() uses when writing it. Also, on an
     * invalid or unknown event this method returns before reading the event body.
     *
     * @param rest the remainder of the line after the {@code event} keyword
     */
    protected void processEvent(String rest) {
        Matcher matcher = Request.EVENT_PATTERN.matcher(rest);
        if (!matcher.matches()) {
            log.error("invalid event line: " + rest);
            return;
        }
        Subscription subscription = subscriptions.get(matcher.group(1));
        if (subscription == null) {
            log.error("non subscribed event: " + matcher.group(1));
            return;
        }
        EventFire event = new EventFire(subscription);
        event.startFile(null);
        processMessage(event);
    }

    /**
     * Dispatches one inbound message based on its first identifier:
     * <ul>
     * <li>{@code event}: handled by processEvent;</li>
     * <li>{@code file [id] path}: the following body is read as one file of the matching query;</li>
     * <li>anything else is taken as the final status line of a query. It is successful only when
     * the identifier is {@code ok}, and the query's collected files are delivered with it. Lines
     * that match no outstanding query are ignored.</li>
     * </ul>
     * A {@code FramsticksException} is rethrown wrapped, with the offending line attached.
     * <p>
     * NOTE(review): in the {@code file} branch, {@code fetchQuery} can return null and is not
     * checked. The resulting NullPointerException ends the receive loop in receiverThreadRoutine.
     *
     * @param line the first line of the message
     */
    protected void processMessageStartingWith(String line) {
        try {
            Pair command = Request.takeIdentifier(line);
            if (command.first.equals("event")) {
                processEvent(command.second.toString());
                return;
            }
            Pair rest = takeRequestId(command.second);

            if (command.first.equals("file")) {
                SentQuery sentQuery = fetchQuery(rest.first, false);
                sentQuery.startFile(rest.second.toString());
                processMessage(sentQuery);
                return;
            }

            SentQuery sentQuery = fetchQuery(rest.first, true);
            if (sentQuery == null) {
                return;
            }
            log.debug("parsing response for request " + sentQuery);

            sentQuery.dispatchResponseProcess(new Response(command.first.equals("ok"), rest.second.toString(), sentQuery.getFiles()));
        } catch (FramsticksException e) {
            throw new FramsticksException().msg("failed to process message").arg("starting with line", line).cause(e);
        }
    }

    /**
     * Receiver thread body.
     * <ol>
     * <li>While running and not connected, tries to open a socket to {@code hostName:port},
     * calling {@code Dispatching.sleep(0.5)} after each failure.</li>
     * <li>Sets a 500 ms socket read timeout, sets up the streams and calls
     * {@code connectedFunctor}.</li>
     * <li>Processes inbound messages until the connection stops running or disconnects, or until
     * an exception occurs (logged at ERROR while still running, DEBUG otherwise).</li>
     * <li>Finally calls {@code interrupt()} and {@code finish()}.</li>
     * </ol>
     * NOTE(review): if the connect loop ends because isRunning() became false before a socket was
     * created, {@code socket} may be unset when {@code setSoTimeout} is called; confirm.
     */
    @Override
    protected void receiverThreadRoutine() {
        while (isRunning() && !isConnected()) {
            log.debug("connecting to " + address);
            try {
                socket = new Socket(hostName, port);
            } catch (IOException e) {
                log.info(this + " failed to connect (retrying): " + e);
                Dispatching.sleep(0.5);
            }
        }

        log.debug(this + " connected");
        try {
            socket.setSoTimeout(500);
            setupStreams();
        } catch (Exception e) {
            throw new FramsticksException().msg("failed to initialize socket").cause(e).arg("connection", this);
        }

        connectedFunctor.call();

        while (isRunning() && isConnected()) {
            try {
                processMessageStartingWith(getLine());
            } catch (Exception e) {
                log.log(isRunning() ? Level.ERROR : Level.DEBUG, "caught exception: ", e);
                break;
            }
        }

        interrupt();
        finish();
    }
}